};
struct clip_hparams {
- bool has_vision = false;
- bool has_audio = false;
-
int32_t image_size;
int32_t patch_size;
int32_t n_embd;
int32_t n_layer;
    int32_t proj_scale_factor = 0; // idefics3, internvl, gemma3, llama4
+ float image_mean[3];
+ float image_std[3];
+
    // for models using dynamic image size, we need a smaller image size for warmup,
    // otherwise users will get OOM every time they load the model
int32_t warmup_image_size = 0;
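+    // 3000 mel frames is roughly 30s of audio under the whisper preprocessor (160-sample hop at 16 kHz, i.e. 100 frames/s)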
+ int32_t warmup_audio_size = 3000;
ffn_op_type ffn_op = FFN_GELU;
// audio
int32_t n_mel_bins = 0; // whisper preprocessor
int32_t proj_stack_factor = 0; // ultravox
+
+ // legacy
+ bool has_llava_projector = false;
+ int minicpmv_version = 0;
};
struct clip_layer {
ggml_tensor * ls_2_w = nullptr;
};
-struct clip_vision_model {
- struct clip_hparams hparams;
+struct clip_model {
+ clip_modality modality = CLIP_MODALITY_VISION;
+ projector_type proj_type = PROJECTOR_TYPE_MLP;
+ clip_hparams hparams;
// embeddings
ggml_tensor * class_embedding = nullptr;
};
struct clip_ctx {
- bool has_llava_projector = false;
- int minicpmv_version = 0;
-
- struct clip_vision_model vision_model;
- projector_type proj_type = PROJECTOR_TYPE_MLP;
-
- float image_mean[3];
- float image_std[3];
+ clip_model model;
gguf_context_ptr ctx_gguf;
ggml_context_ptr ctx_data;
ggml_backend_free(backend_cpu);
}
}
+
+    // this function is added so that we don't have to change too much of the existing code
+ projector_type proj_type() const {
+ return model.proj_type;
+ }
};
struct clip_graph {
clip_ctx * ctx;
- const clip_vision_model & model;
+ const clip_model & model;
const clip_hparams & hparams;
// we only support single image per batch
clip_graph(clip_ctx * ctx, const clip_image_f32 & img) :
ctx(ctx),
- model(ctx->vision_model),
+ model(ctx->model),
hparams(model.hparams),
img(img),
patch_size(hparams.patch_size),
model.position_embeddings,
nullptr);
- if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+ if (ctx->proj_type() == PROJECTOR_TYPE_GEMMA3) {
const int batch_size = 1;
GGML_ASSERT(n_patches_x == n_patches_y);
const int patches_per_image = n_patches_x;
ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
cur);
- } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+ } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) {
// https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
const int scale_factor = model.hparams.proj_scale_factor;
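            // pixel shuffle: every scale_factor x scale_factor block of patches is folded into the
            // channel dimension, so the number of output tokens shrinks by scale_factor^2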
const int n_pos = n_patches;
const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position
- norm_type norm_t = ctx->proj_type == PROJECTOR_TYPE_QWEN25VL
+ norm_type norm_t = ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL
? NORM_TYPE_RMS // qwen 2.5 vl
: NORM_TYPE_NORMAL; // qwen 2 vl
const int d_head = 128;
int n_head = n_embd/d_head;
int num_query = 96;
- if (ctx->minicpmv_version == 2) {
+ if (ctx->model.hparams.minicpmv_version == 2) {
num_query = 96;
- } else if (ctx->minicpmv_version == 3) {
+ } else if (ctx->model.hparams.minicpmv_version == 3) {
num_query = 64;
- } else if (ctx->minicpmv_version == 4) {
+ } else if (ctx->model.hparams.minicpmv_version == 4) {
num_query = 64;
}
int il_last = hparams.n_layer - 1;
int deepest_feature_layer = -1;
- if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
+ if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV || ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) {
il_last += 1;
}
}
// llava projector (also used by granite)
- if (ctx->has_llava_projector) {
+ if (ctx->model.hparams.has_llava_projector) {
embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
// print_tensor_info(embeddings, "embeddings");
// llava projector
- if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
+ if (ctx->proj_type() == PROJECTOR_TYPE_MLP) {
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
}
}
- else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
+ else if (ctx->proj_type() == PROJECTOR_TYPE_MLP_NORM) {
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
// ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w),
model.mm_4_b);
}
- else if (ctx->proj_type == PROJECTOR_TYPE_LDP) {
+ else if (ctx->proj_type() == PROJECTOR_TYPE_LDP) {
// MobileVLM projector
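            // n_patch = 24 assumes the 336px / patch-14 CLIP backbone used by MobileVLM (336 / 14 = 24)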
int n_patch = 24;
ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings);
}
embeddings = block_1;
}
- else if (ctx->proj_type == PROJECTOR_TYPE_LDPV2)
+ else if (ctx->proj_type() == PROJECTOR_TYPE_LDPV2)
{
int n_patch = 24;
ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
}
// glm projector
- else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
+ else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) {
size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3));
embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
cb(cur, "after_transformer", -1);
- if (ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX) {
+ if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) {
// StackAudioFrames
// https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
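        // stacks groups of proj_stack_factor consecutive frames along the embedding dim,
        // shrinking the sequence length by the same factor before the projection MLP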
{
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
}
- } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) {
+ } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) {
// projector
cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur);
cur = ggml_add(ctx0, cur, model.mm_fc_b);
}
// TODO @ngxson : find a way to move this outside
- if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) {
+ if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) {
ggml_tensor * cur = inpL;
cur = ggml_transpose(ctx0, cur);
cur = ggml_cont(ctx0, cur);
ggml_cgraph * res;
- switch (ctx->proj_type) {
+ switch (ctx->proj_type()) {
case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_IDEFICS3:
{
ggml_context_ptr ctx_meta;
gguf_context_ptr ctx_gguf;
- clip_ctx & ctx_clip;
std::string fname;
size_t model_size = 0; // in bytes
- // TODO @ngxson : we should not pass clip_ctx here, it should be clip_vision_model
- clip_model_loader(const char * fname, clip_ctx & ctx_clip) : ctx_clip(ctx_clip), fname(fname) {
+ bool has_vision = false;
+ bool has_audio = false;
+
+    // TODO @ngxson : load_tensors/alloc_compute_meta should take a clip_model instead of the whole clip_ctx
+ clip_model_loader(const char * fname) : fname(fname) {
struct ggml_context * meta = nullptr;
struct gguf_init_params params = {
LOG_INF("\n");
}
+ // modalities
+ {
+ get_bool(KEY_HAS_VISION_ENC, has_vision, false);
+ get_bool(KEY_HAS_AUDIO_ENC, has_audio, false);
+
+ if (has_vision) {
+ LOG_INF("%s: has vision encoder\n", __func__);
+ }
+ if (has_audio) {
+ LOG_INF("%s: has audio encoder\n", __func__);
+ }
+ }
+
// tensors
{
for (int i = 0; i < n_tensors; ++i) {
}
}
- void load_hparams() {
- auto & hparams = ctx_clip.vision_model.hparams;
+ void load_hparams(clip_model & model, clip_modality modality) {
+ auto & hparams = model.hparams;
std::string log_ffn_op; // for logging
+ // sanity check
+ if (modality == CLIP_MODALITY_VISION) {
+ GGML_ASSERT(has_vision);
+ } else if (modality == CLIP_MODALITY_AUDIO) {
+ GGML_ASSERT(has_audio);
+ }
+ model.modality = modality;
+
// projector type
std::string proj_type;
{
get_string(KEY_PROJ_TYPE, proj_type, false);
if (!proj_type.empty()) {
- ctx_clip.proj_type = clip_projector_type_from_string(proj_type);
+ model.proj_type = clip_projector_type_from_string(proj_type);
}
- if (ctx_clip.proj_type == PROJECTOR_TYPE_UNKNOWN) {
+ if (model.proj_type == PROJECTOR_TYPE_UNKNOWN) {
throw std::runtime_error(string_format("%s: unknown projector type: %s\n", __func__, proj_type.c_str()));
}
+
+        // for omni models (e.g. qwen2.5-omni), resolve the projector type for the requested modality
+ if (model.proj_type == PROJECTOR_TYPE_QWEN25O) {
+ model.proj_type = modality == CLIP_MODALITY_VISION
+ ? PROJECTOR_TYPE_QWEN25VL
+ : PROJECTOR_TYPE_QWEN2A;
+ }
}
+ const bool is_vision = model.modality == CLIP_MODALITY_VISION;
+ const bool is_audio = model.modality == CLIP_MODALITY_AUDIO;
+
// other hparams
{
- get_bool(KEY_HAS_AUDIO_ENC, hparams.has_audio, false);
- get_bool(KEY_HAS_VISION_ENC, hparams.has_vision, false);
-
- const char * prefix = hparams.has_vision ? "vision" : "audio";
+ const char * prefix = is_vision ? "vision" : "audio";
get_u32(string_format(KEY_N_EMBD, prefix), hparams.n_embd);
get_u32(string_format(KEY_N_HEAD, prefix), hparams.n_head);
get_u32(string_format(KEY_N_FF, prefix), hparams.n_ff);
get_u32(string_format(KEY_PROJ_DIM, prefix), hparams.projection_dim);
get_f32(string_format(KEY_LAYER_NORM_EPS, prefix), hparams.eps);
- if (hparams.has_vision) {
+ if (is_vision) {
get_u32(KEY_IMAGE_SIZE, hparams.image_size);
get_u32(KEY_PATCH_SIZE, hparams.patch_size);
get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
- get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); // legacy
+ get_i32(KEY_MINICPMV_VERSION, hparams.minicpmv_version, false); // legacy
- } else if (hparams.has_audio) {
+ } else if (is_audio) {
get_u32(KEY_A_NUM_MEL_BINS, hparams.n_mel_bins);
} else {
- throw std::runtime_error(string_format("%s: neither vision nor audio encoder is present\n", __func__));
+ GGML_ASSERT(false && "unknown modality");
}
// default warmup value
hparams.warmup_image_size = hparams.image_size;
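            // models with dynamic image sizes are expected to override this default with a smaller
            // warmup size (see the warmup_image_size comment in clip_hparams)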
- ctx_clip.has_llava_projector = ctx_clip.proj_type == PROJECTOR_TYPE_MLP
- || ctx_clip.proj_type == PROJECTOR_TYPE_MLP_NORM
- || ctx_clip.proj_type == PROJECTOR_TYPE_LDP
- || ctx_clip.proj_type == PROJECTOR_TYPE_LDPV2;
+ hparams.has_llava_projector = model.proj_type == PROJECTOR_TYPE_MLP
+ || model.proj_type == PROJECTOR_TYPE_MLP_NORM
+ || model.proj_type == PROJECTOR_TYPE_LDP
+ || model.proj_type == PROJECTOR_TYPE_LDPV2;
{
bool use_gelu = false;
}
}
- if (hparams.has_vision) {
+ if (is_vision) {
int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN);
int idx_std = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD);
GGML_ASSERT(idx_mean >= 0 && "image_mean not found");
const float * mean_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_mean);
const float * std_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_std);
for (int i = 0; i < 3; ++i) {
- ctx_clip.image_mean[i] = mean_data[i];
- ctx_clip.image_std[i] = std_data[i];
+ hparams.image_mean[i] = mean_data[i];
+ hparams.image_std[i] = std_data[i];
}
}
}
// model-specific params
- switch (ctx_clip.proj_type) {
+ switch (model.proj_type) {
case PROJECTOR_TYPE_MINICPMV:
{
- if (ctx_clip.minicpmv_version == 0) {
- ctx_clip.minicpmv_version = 2; // default to 2 if not set
+ if (hparams.minicpmv_version == 0) {
+ hparams.minicpmv_version = 2; // default to 2 if not set
}
} break;
case PROJECTOR_TYPE_IDEFICS3:
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_QWEN2A:
{
- bool require_stack = ctx_clip.proj_type == PROJECTOR_TYPE_ULTRAVOX;
+ bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX;
get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack);
if (hparams.n_mel_bins != 128) {
throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__));
}
LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str());
- LOG_INF("%s: has_vision_encoder: %d\n", __func__, hparams.has_vision);
- LOG_INF("%s: has_audio_encoder: %d\n", __func__, hparams.has_audio);
LOG_INF("%s: n_embd: %d\n", __func__, hparams.n_embd);
LOG_INF("%s: n_head: %d\n", __func__, hparams.n_head);
LOG_INF("%s: n_ff: %d\n", __func__, hparams.n_ff);
LOG_INF("%s: n_layer: %d\n", __func__, hparams.n_layer);
LOG_INF("%s: ffn_op: %s\n", __func__, log_ffn_op.c_str());
LOG_INF("%s: projection_dim: %d\n", __func__, hparams.projection_dim);
- LOG_INF("\n");
- if (hparams.has_vision) {
+ if (is_vision) {
+ LOG_INF("\n--- vision hparams ---\n");
LOG_INF("%s: image_size: %d\n", __func__, hparams.image_size);
LOG_INF("%s: patch_size: %d\n", __func__, hparams.patch_size);
- LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector);
- LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version);
+ LOG_INF("%s: has_llava_proj: %d\n", __func__, hparams.has_llava_projector);
+ LOG_INF("%s: minicpmv_version: %d\n", __func__, hparams.minicpmv_version);
LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor);
LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
- } else if (hparams.has_audio) {
+ } else if (is_audio) {
+ LOG_INF("\n--- audio hparams ---\n");
LOG_INF("%s: n_mel_bins: %d\n", __func__, hparams.n_mel_bins);
LOG_INF("%s: proj_stack_factor: %d\n", __func__, hparams.proj_stack_factor);
}
}
}
- void load_tensors() {
- auto & hparams = ctx_clip.vision_model.hparams;
+ void load_tensors(clip_ctx & ctx_clip) {
+ auto & model = ctx_clip.model;
+ auto & hparams = model.hparams;
std::map<std::string, size_t> tensor_offset;
std::vector<ggml_tensor *> tensors_to_load;
// TODO @ngxson : support both audio and video in the future
- const char * prefix = hparams.has_audio ? "a" : "v";
+ const char * prefix = model.modality == CLIP_MODALITY_AUDIO ? "a" : "v";
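+        // tensor names in the gguf are prefixed per modality: "v." for vision, "a." for audio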
// get offsets
for (int64_t i = 0; i < gguf_get_n_tensors(ctx_gguf.get()); ++i) {
return cur;
};
- auto & vision_model = ctx_clip.vision_model; // TODO: rename this to just "model"
-
- vision_model.class_embedding = get_tensor(TN_CLASS_EMBD, false);
+ model.class_embedding = get_tensor(TN_CLASS_EMBD, false);
- vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, prefix, "weight"), false);
- vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, prefix, "bias"), false);
+ model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, prefix, "weight"), false);
+ model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, prefix, "bias"), false);
- vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, prefix, "weight"), false);
- vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, prefix, "bias"), false);
+ model.post_ln_w = get_tensor(string_format(TN_LN_POST, prefix, "weight"), false);
+ model.post_ln_b = get_tensor(string_format(TN_LN_POST, prefix, "bias"), false);
- vision_model.patch_bias = get_tensor(TN_PATCH_BIAS, false);
- vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false);
- vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
+ model.patch_bias = get_tensor(TN_PATCH_BIAS, false);
+ model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false);
+ model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
- vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false);
+ model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false);
// layers
- vision_model.layers.resize(hparams.n_layer);
+ model.layers.resize(hparams.n_layer);
for (int il = 0; il < hparams.n_layer; ++il) {
- auto & layer = vision_model.layers[il];
+ auto & layer = model.layers[il];
layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight"));
layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight"));
layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight"));
}
}
- switch (ctx_clip.proj_type) {
+ switch (model.proj_type) {
case PROJECTOR_TYPE_MLP:
case PROJECTOR_TYPE_MLP_NORM:
{
// LLaVA projection
- vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"), false);
- vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false);
+ model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"), false);
+ model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false);
// Yi-type llava
- vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"), false);
- vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
+ model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"), false);
+ model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
// missing in Yi-type llava
- vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"), false);
- vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
+ model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"), false);
+ model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
// Yi-type llava
- vision_model.mm_3_w = get_tensor(string_format(TN_LLAVA_PROJ, 3, "weight"), false);
- vision_model.mm_3_b = get_tensor(string_format(TN_LLAVA_PROJ, 3, "bias"), false);
- vision_model.mm_4_w = get_tensor(string_format(TN_LLAVA_PROJ, 4, "weight"), false);
- vision_model.mm_4_b = get_tensor(string_format(TN_LLAVA_PROJ, 4, "bias"), false);
- if (vision_model.mm_3_w) {
+ model.mm_3_w = get_tensor(string_format(TN_LLAVA_PROJ, 3, "weight"), false);
+ model.mm_3_b = get_tensor(string_format(TN_LLAVA_PROJ, 3, "bias"), false);
+ model.mm_4_w = get_tensor(string_format(TN_LLAVA_PROJ, 4, "weight"), false);
+ model.mm_4_b = get_tensor(string_format(TN_LLAVA_PROJ, 4, "bias"), false);
+ if (model.mm_3_w) {
// TODO: this is a hack to support Yi-type llava
- ctx_clip.proj_type = PROJECTOR_TYPE_MLP_NORM;
+ model.proj_type = PROJECTOR_TYPE_MLP_NORM;
}
- vision_model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false);
+ model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false);
} break;
case PROJECTOR_TYPE_LDP:
{
// MobileVLM projection
- vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
- vision_model.mm_model_mlp_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias"));
- vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
- vision_model.mm_model_mlp_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
- vision_model.mm_model_block_1_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight"));
- vision_model.mm_model_block_1_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight"));
- vision_model.mm_model_block_1_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias"));
- vision_model.mm_model_block_1_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.weight"));
- vision_model.mm_model_block_1_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.bias"));
- vision_model.mm_model_block_1_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.weight"));
- vision_model.mm_model_block_1_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.bias"));
- vision_model.mm_model_block_1_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight"));
- vision_model.mm_model_block_1_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight"));
- vision_model.mm_model_block_1_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias"));
- vision_model.mm_model_block_2_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight"));
- vision_model.mm_model_block_2_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight"));
- vision_model.mm_model_block_2_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias"));
- vision_model.mm_model_block_2_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.weight"));
- vision_model.mm_model_block_2_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.bias"));
- vision_model.mm_model_block_2_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.weight"));
- vision_model.mm_model_block_2_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.bias"));
- vision_model.mm_model_block_2_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight"));
- vision_model.mm_model_block_2_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight"));
- vision_model.mm_model_block_2_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias"));
+ model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+ model.mm_model_mlp_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias"));
+ model.mm_model_mlp_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
+ model.mm_model_mlp_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
+ model.mm_model_block_1_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight"));
+ model.mm_model_block_1_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight"));
+ model.mm_model_block_1_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias"));
+ model.mm_model_block_1_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.weight"));
+ model.mm_model_block_1_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.bias"));
+ model.mm_model_block_1_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.weight"));
+ model.mm_model_block_1_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.bias"));
+ model.mm_model_block_1_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight"));
+ model.mm_model_block_1_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight"));
+ model.mm_model_block_1_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias"));
+ model.mm_model_block_2_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight"));
+ model.mm_model_block_2_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight"));
+ model.mm_model_block_2_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias"));
+ model.mm_model_block_2_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.weight"));
+ model.mm_model_block_2_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.bias"));
+ model.mm_model_block_2_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.weight"));
+ model.mm_model_block_2_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.bias"));
+ model.mm_model_block_2_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight"));
+ model.mm_model_block_2_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight"));
+ model.mm_model_block_2_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias"));
} break;
case PROJECTOR_TYPE_LDPV2:
{
                // MobileVLM_V2 projection
- vision_model.mm_model_mlp_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
- vision_model.mm_model_mlp_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias"));
- vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
- vision_model.mm_model_mlp_2_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "bias"));
- vision_model.mm_model_peg_0_w = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "weight"));
- vision_model.mm_model_peg_0_b = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "bias"));
+ model.mm_model_mlp_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
+ model.mm_model_mlp_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias"));
+ model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
+ model.mm_model_mlp_2_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "bias"));
+ model.mm_model_peg_0_w = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "weight"));
+ model.mm_model_peg_0_b = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "bias"));
} break;
case PROJECTOR_TYPE_MINICPMV:
{
- // vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD);
- vision_model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K);
- vision_model.mm_model_query = get_tensor(TN_MINICPMV_QUERY);
- vision_model.mm_model_proj = get_tensor(TN_MINICPMV_PROJ);
- vision_model.mm_model_kv_proj = get_tensor(TN_MINICPMV_KV_PROJ);
- vision_model.mm_model_attn_q_w = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "weight"));
- vision_model.mm_model_attn_k_w = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "weight"));
- vision_model.mm_model_attn_v_w = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "weight"));
- vision_model.mm_model_attn_q_b = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "bias"));
- vision_model.mm_model_attn_k_b = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "bias"));
- vision_model.mm_model_attn_v_b = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "bias"));
- vision_model.mm_model_attn_o_w = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "weight"));
- vision_model.mm_model_attn_o_b = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "bias"));
- vision_model.mm_model_ln_q_w = get_tensor(string_format(TN_MINICPMV_LN, "q", "weight"));
- vision_model.mm_model_ln_q_b = get_tensor(string_format(TN_MINICPMV_LN, "q", "bias"));
- vision_model.mm_model_ln_kv_w = get_tensor(string_format(TN_MINICPMV_LN, "kv", "weight"));
- vision_model.mm_model_ln_kv_b = get_tensor(string_format(TN_MINICPMV_LN, "kv", "bias"));
- vision_model.mm_model_ln_post_w = get_tensor(string_format(TN_MINICPMV_LN, "post", "weight"));
- vision_model.mm_model_ln_post_b = get_tensor(string_format(TN_MINICPMV_LN, "post", "bias"));
+ // model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD);
+ model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K);
+ model.mm_model_query = get_tensor(TN_MINICPMV_QUERY);
+ model.mm_model_proj = get_tensor(TN_MINICPMV_PROJ);
+ model.mm_model_kv_proj = get_tensor(TN_MINICPMV_KV_PROJ);
+ model.mm_model_attn_q_w = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "weight"));
+ model.mm_model_attn_k_w = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "weight"));
+ model.mm_model_attn_v_w = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "weight"));
+ model.mm_model_attn_q_b = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "bias"));
+ model.mm_model_attn_k_b = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "bias"));
+ model.mm_model_attn_v_b = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "bias"));
+ model.mm_model_attn_o_w = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "weight"));
+ model.mm_model_attn_o_b = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "bias"));
+ model.mm_model_ln_q_w = get_tensor(string_format(TN_MINICPMV_LN, "q", "weight"));
+ model.mm_model_ln_q_b = get_tensor(string_format(TN_MINICPMV_LN, "q", "bias"));
+ model.mm_model_ln_kv_w = get_tensor(string_format(TN_MINICPMV_LN, "kv", "weight"));
+ model.mm_model_ln_kv_b = get_tensor(string_format(TN_MINICPMV_LN, "kv", "bias"));
+ model.mm_model_ln_post_w = get_tensor(string_format(TN_MINICPMV_LN, "post", "weight"));
+ model.mm_model_ln_post_b = get_tensor(string_format(TN_MINICPMV_LN, "post", "bias"));
} break;
case PROJECTOR_TYPE_GLM_EDGE:
{
- vision_model.mm_model_adapter_conv_w = get_tensor(string_format(TN_GLM_ADAPER_CONV, "weight"));
- vision_model.mm_model_adapter_conv_b = get_tensor(string_format(TN_GLM_ADAPER_CONV, "bias"));
- vision_model.mm_model_mlp_0_w = get_tensor(string_format(TN_GLM_ADAPTER_LINEAR, "weight"));
- vision_model.mm_model_ln_q_w = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "weight"));
- vision_model.mm_model_ln_q_b = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "bias"));
- vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H, "weight"));
- vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE, "weight"));
- vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H, "weight"));
- vision_model.mm_glm_tok_boi = get_tensor(string_format(TN_TOK_GLM_BOI, "weight"));
- vision_model.mm_glm_tok_eoi = get_tensor(string_format(TN_TOK_GLM_EOI, "weight"));
+ model.mm_model_adapter_conv_w = get_tensor(string_format(TN_GLM_ADAPER_CONV, "weight"));
+ model.mm_model_adapter_conv_b = get_tensor(string_format(TN_GLM_ADAPER_CONV, "bias"));
+ model.mm_model_mlp_0_w = get_tensor(string_format(TN_GLM_ADAPTER_LINEAR, "weight"));
+ model.mm_model_ln_q_w = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "weight"));
+ model.mm_model_ln_q_b = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "bias"));
+ model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H, "weight"));
+ model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE, "weight"));
+ model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H, "weight"));
+ model.mm_glm_tok_boi = get_tensor(string_format(TN_TOK_GLM_BOI, "weight"));
+ model.mm_glm_tok_eoi = get_tensor(string_format(TN_TOK_GLM_EOI, "weight"));
} break;
case PROJECTOR_TYPE_QWEN2VL:
case PROJECTOR_TYPE_QWEN25VL:
{
- vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
- vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
- vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
- vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
+ model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
+ model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
+ model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
+ model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
} break;
case PROJECTOR_TYPE_GEMMA3:
{
- vision_model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);
- vision_model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N);
+ model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);
+ model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N);
} break;
case PROJECTOR_TYPE_IDEFICS3:
{
- vision_model.projection = get_tensor(TN_MM_PROJECTOR);
+ model.projection = get_tensor(TN_MM_PROJECTOR);
} break;
case PROJECTOR_TYPE_PIXTRAL:
{
- vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
- vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
- vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
- vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
+ model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
+ model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
+ model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
+ model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
// [IMG_BREAK] token embedding
- vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
+ model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
// for mistral small 3.1
- vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
- vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
+ model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
+ model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
} break;
case PROJECTOR_TYPE_ULTRAVOX:
{
- vision_model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
- vision_model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
- vision_model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
- vision_model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
- vision_model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
- vision_model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
- vision_model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
- vision_model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight"));
+ model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
+ model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
+ model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
+ model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
+ model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
+ model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
+ model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
+ model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight"));
} break;
case PROJECTOR_TYPE_QWEN2A:
{
- vision_model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
- vision_model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
- vision_model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
- vision_model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
- vision_model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight"));
- vision_model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias"));
+ model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
+ model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
+ model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
+ model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
+ model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight"));
+ model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias"));
} break;
case PROJECTOR_TYPE_INTERNVL:
{
- vision_model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
- vision_model.mm_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias"));
- vision_model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
- vision_model.mm_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias"));
- vision_model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
- vision_model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
+ model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
+ model.mm_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias"));
+ model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+ model.mm_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias"));
+ model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
+ model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
} break;
case PROJECTOR_TYPE_LLAMA4:
{
- vision_model.mm_model_proj = get_tensor(TN_MM_PROJECTOR);
- vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
- vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
+ model.mm_model_proj = get_tensor(TN_MM_PROJECTOR);
+ model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+ model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
} break;
default:
GGML_ASSERT(false && "unknown projector type");
}
}
- void alloc_compute_meta() {
- const auto & hparams = ctx_clip.vision_model.hparams;
+ void alloc_compute_meta(clip_ctx & ctx_clip) {
+ const auto & hparams = ctx_clip.model.hparams;
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
// create a fake batch
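        // the fake input is only used to build a graph so the compute buffers can be sized up front;
        // its pixel data is never read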
clip_image_f32_batch batch;
clip_image_f32_ptr img(clip_image_f32_init());
- if (hparams.has_vision) {
+ if (ctx_clip.model.modality == CLIP_MODALITY_VISION) {
img->nx = hparams.warmup_image_size;
img->ny = hparams.warmup_image_size;
} else {
- img->nx = 1024; // TODO @ngxson : use a better default
+ img->nx = hparams.warmup_audio_size;
img->ny = hparams.n_mel_bins;
}
- img->buf.resize(img->nx * img->ny * 3);
batch.entries.push_back(std::move(img));
ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch);
}
};
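+// clip_init now returns both contexts; either may be null. Sketch of the expected usage:
+//   clip_init_result res = clip_init(fname, params);
+//   if (res.ctx_v) { /* vision encoder loaded */ }
+//   if (res.ctx_a) { /* audio encoder loaded */ }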
-struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params) {
+struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params) {
g_logger_state.verbosity_thold = ctx_params.verbosity;
- clip_ctx * ctx_clip = nullptr;
+ clip_ctx * ctx_vision = nullptr;
+ clip_ctx * ctx_audio = nullptr;
try {
- ctx_clip = new clip_ctx(ctx_params);
- clip_model_loader loader(fname, *ctx_clip);
- loader.load_hparams();
- loader.load_tensors();
- loader.alloc_compute_meta();
+ clip_model_loader loader(fname);
+
+ if (loader.has_vision) {
+ ctx_vision = new clip_ctx(ctx_params);
+ loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION);
+ loader.load_tensors(*ctx_vision);
+ loader.alloc_compute_meta(*ctx_vision);
+ }
+
+ if (loader.has_audio) {
+ ctx_audio = new clip_ctx(ctx_params);
+ loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
+ loader.load_tensors(*ctx_audio);
+ loader.alloc_compute_meta(*ctx_audio);
+ }
+
} catch (const std::exception & e) {
LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what());
- delete ctx_clip;
- return nullptr;
+ if (ctx_vision) {
+ delete ctx_vision;
+ }
+ if (ctx_audio) {
+ delete ctx_audio;
+ }
+ return {nullptr, nullptr};
}
- return ctx_clip;
+ return {ctx_vision, ctx_audio};
}
struct clip_image_size * clip_image_size_init() {
const float ratio = (float)original_width * original_height / (slice_size * slice_size);
const int multiple = fmin(ceil(ratio), max_slice_nums);
const bool has_slices = (multiple > 1);
- const bool has_pinpoints = !ctx->vision_model.hparams.image_grid_pinpoints.empty();
+ const bool has_pinpoints = !ctx->model.hparams.image_grid_pinpoints.empty();
if (has_pinpoints) {
// has pinpoints, use them to calculate the grid size (e.g. llava-1.6)
auto refine_size = llava_uhd::select_best_resolution(
- ctx->vision_model.hparams.image_grid_pinpoints,
+ ctx->model.hparams.image_grid_pinpoints,
original_size);
res.overview_size = clip_image_size{slice_size, slice_size};
res.refined_size = refine_size;
bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
clip_image_size original_size{img->nx, img->ny};
bool pad_to_square = true;
- auto & params = ctx->vision_model.hparams;
+ auto & params = ctx->model.hparams;
// The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
if (params.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD) {
pad_to_square = false;
for (size_t i = 0; i < imgs.size(); ++i) {
// clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
clip_image_f32_ptr res(clip_image_f32_init());
- normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
+ normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
res_imgs->entries.push_back(std::move(res));
}
res_imgs->grid_y = inst.grid_size.height;
return true;
- } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
+ } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
clip_image_u8 resized;
auto patch_size = params.patch_size * 2;
auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, patch_size, params.image_size);
clip_image_f32_ptr img_f32(clip_image_f32_init());
// clip_image_f32_ptr res(clip_image_f32_init());
- normalize_image_u8_to_f32(resized, *img_f32, ctx->image_mean, ctx->image_std);
+ normalize_image_u8_to_f32(resized, *img_f32, params.image_mean, params.image_std);
// res_imgs->data[0] = *res;
res_imgs->entries.push_back(std::move(img_f32));
return true;
}
- else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE
- || ctx->proj_type == PROJECTOR_TYPE_GEMMA3
- || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3
- || ctx->proj_type == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution
+ else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE
+ || ctx->proj_type() == PROJECTOR_TYPE_GEMMA3
+ || ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3
+ || ctx->proj_type() == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution
) {
clip_image_u8 resized_image;
int sz = params.image_size;
image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz});
clip_image_f32_ptr img_f32(clip_image_f32_init());
//clip_image_save_to_bmp(resized_image, "resized.bmp");
- normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
+ normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
res_imgs->entries.push_back(std::move(img_f32));
return true;
- } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
+ } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL) {
clip_image_u8 resized_image;
auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
clip_image_f32_ptr img_f32(clip_image_f32_init());
- normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
+ normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
res_imgs->entries.push_back(std::move(img_f32));
return true;
- } else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
+ } else if (ctx->proj_type() == PROJECTOR_TYPE_LLAMA4) {
GGML_ASSERT(!params.image_grid_pinpoints.empty());
auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
for (size_t i = 0; i < imgs.size(); ++i) {
clip_image_f32_ptr res(clip_image_f32_init());
- normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
+ normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
res_imgs->entries.push_back(std::move(res));
}
image_manipulation::resize_and_pad_image(*img, *temp, clip_image_size{params.image_size, params.image_size}, pad_color);
clip_image_f32_ptr res(clip_image_f32_init());
- normalize_image_u8_to_f32(*temp, *res, ctx->image_mean, ctx->image_std);
+ normalize_image_u8_to_f32(*temp, *res, params.image_mean, params.image_std);
res_imgs->entries.push_back(std::move(res));
return true;
for (size_t i = 0; i < imgs.size(); ++i) {
// clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
clip_image_f32_ptr res(clip_image_f32_init());
- normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
+ normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
res_imgs->entries.push_back(std::move(res));
}
}
ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) {
- return ctx->vision_model.image_newline;
+ return ctx->model.image_newline;
}
void clip_free(clip_ctx * ctx) {
// deprecated
size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
- const int32_t nx = ctx->vision_model.hparams.image_size;
- const int32_t ny = ctx->vision_model.hparams.image_size;
+ const int32_t nx = ctx->model.hparams.image_size;
+ const int32_t ny = ctx->model.hparams.image_size;
return clip_embd_nbytes_by_img(ctx, nx, ny);
}
}
int32_t clip_get_image_size(const struct clip_ctx * ctx) {
- return ctx->vision_model.hparams.image_size;
+ return ctx->model.hparams.image_size;
}
int32_t clip_get_patch_size(const struct clip_ctx * ctx) {
- return ctx->vision_model.hparams.patch_size;
+ return ctx->model.hparams.patch_size;
}
int32_t clip_get_hidden_size(const struct clip_ctx * ctx) {
- return ctx->vision_model.hparams.n_embd;
+ return ctx->model.hparams.n_embd;
}
const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
- return ctx->vision_model.hparams.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD ? "spatial_unpad" : "flat";
+ return ctx->model.hparams.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD ? "spatial_unpad" : "flat";
}
const int32_t * clip_image_grid(const struct clip_ctx * ctx) {
- if (ctx->vision_model.hparams.image_grid_pinpoints.size()) {
- return &ctx->vision_model.hparams.image_grid_pinpoints.front();
+ if (ctx->model.hparams.image_grid_pinpoints.size()) {
+ return &ctx->model.hparams.image_grid_pinpoints.front();
}
return nullptr;
}
size_t get_clip_image_grid_size(const struct clip_ctx * ctx) {
- return ctx->vision_model.hparams.image_grid_pinpoints.size();
+ return ctx->model.hparams.image_grid_pinpoints.size();
}
int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
- const auto & params = ctx->vision_model.hparams;
+ const auto & params = ctx->model.hparams;
const int n_total = clip_n_output_tokens(ctx, img);
- if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
+ if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
return img->nx / (params.patch_size * 2) + (int)(img->nx % params.patch_size > 0);
}
return n_total;
}
int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
- const auto & params = ctx->vision_model.hparams;
- if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
+ const auto & params = ctx->model.hparams;
+ if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
return img->ny / (params.patch_size * 2) + (int)(img->ny % params.patch_size > 0);
}
return 1;
}
int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
- const auto & params = ctx->vision_model.hparams;
+ const auto & params = ctx->model.hparams;
- int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
- int scale_factor = ctx->vision_model.hparams.proj_scale_factor;
+    // only valid for models using fixed-size square images; projectors with dynamic input sizes overwrite it below
+ int n_patches_sq = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
- if (ctx->proj_type == PROJECTOR_TYPE_LDP
- || ctx->proj_type == PROJECTOR_TYPE_LDPV2
- || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
- n_patches /= 4;
- if (ctx->vision_model.mm_glm_tok_boi) {
- n_patches += 2; // for BOI and EOI token embeddings
- }
- } else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
- if (ctx->minicpmv_version == 2) {
- n_patches = 96;
- }
- else if (ctx->minicpmv_version == 3) {
- n_patches = 64;
- }
- else if (ctx->minicpmv_version == 4) {
- n_patches = 64;
- }
- else {
- GGML_ABORT("Unknown minicpmv version");
- }
- } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
- int patch_size = params.patch_size * 2;
- int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
- int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
- n_patches = x_patch * y_patch;
- } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
- int n_per_side = params.image_size / params.patch_size;
- int n_per_side_2d_pool = n_per_side / params.proj_scale_factor;
- n_patches = n_per_side_2d_pool * n_per_side_2d_pool;
- } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type == PROJECTOR_TYPE_INTERNVL) {
- // both W and H are divided by proj_scale_factor
- n_patches /= (params.proj_scale_factor * params.proj_scale_factor);
- } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
- int n_merge = params.spatial_merge_size;
- int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
- int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
- n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
- } else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
- n_patches /= (scale_factor * scale_factor);
- } else if (ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX) {
- const int proj_stack_factor = ctx->vision_model.hparams.proj_stack_factor;
- const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
- n_patches = n_len / proj_stack_factor / 2;
- } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) {
- // divide by 2 because of whisper
- // another divide by 2 because of nn.AvgPool1d(2, stride=2)
- n_patches = img->nx / 4;
- }
-
- return n_patches;
+ projector_type proj = ctx->proj_type();
+
+ switch (proj) {
+ case PROJECTOR_TYPE_MLP:
+ case PROJECTOR_TYPE_MLP_NORM:
+ {
+ // do nothing
+ } break;
+ case PROJECTOR_TYPE_LDP:
+ case PROJECTOR_TYPE_LDPV2:
+ case PROJECTOR_TYPE_GLM_EDGE:
+ {
+ n_patches_sq /= 4;
+ if (ctx->model.mm_glm_tok_boi) {
+ n_patches_sq += 2; // for BOI and EOI token embeddings
+ }
+ } break;
+ case PROJECTOR_TYPE_MINICPMV:
+ {
+ if (params.minicpmv_version == 2) {
+ n_patches_sq = 96;
+ } else if (params.minicpmv_version == 3) {
+ n_patches_sq = 64;
+ } else if (params.minicpmv_version == 4) {
+ n_patches_sq = 64;
+ } else {
+ GGML_ABORT("Unknown minicpmv version");
+ }
+ } break;
+ case PROJECTOR_TYPE_QWEN2VL:
+ case PROJECTOR_TYPE_QWEN25VL:
+ {
+ // dynamic size
+ int patch_size = params.patch_size * 2;
+ int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
+ int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
+ n_patches_sq = x_patch * y_patch;
+ } break;
+ case PROJECTOR_TYPE_GEMMA3:
+ {
+ int n_per_side = params.image_size / params.patch_size;
+ int n_per_side_2d_pool = n_per_side / params.proj_scale_factor;
+ n_patches_sq = n_per_side_2d_pool * n_per_side_2d_pool;
+ } break;
+ case PROJECTOR_TYPE_IDEFICS3:
+ case PROJECTOR_TYPE_INTERNVL:
+ {
+ // both W and H are divided by proj_scale_factor
+ n_patches_sq /= (params.proj_scale_factor * params.proj_scale_factor);
+ } break;
+ case PROJECTOR_TYPE_PIXTRAL:
+ {
+ // dynamic size
+ int n_merge = params.spatial_merge_size;
+ int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
+ int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
+ n_patches_sq = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
+ } break;
+ case PROJECTOR_TYPE_LLAMA4:
+ {
+ int scale_factor = ctx->model.hparams.proj_scale_factor;
+ n_patches_sq /= (scale_factor * scale_factor);
+ } break;
+ case PROJECTOR_TYPE_ULTRAVOX:
+ {
+ const int proj_stack_factor = ctx->model.hparams.proj_stack_factor;
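+                // pad the frame count up to a multiple of the stack factor, then account for
+                // frame stacking and whisper's stride-2 conv (hence / proj_stack_factor / 2)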
+ const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
+ n_patches_sq = n_len / proj_stack_factor / 2;
+ } break;
+ case PROJECTOR_TYPE_QWEN2A:
+ {
+ // divide by 2 because of whisper
+ // another divide by 2 because of nn.AvgPool1d(2, stride=2)
+ n_patches_sq = img->nx / 4;
+ } break;
+ default:
+ GGML_ABORT("unsupported projector type");
+ }
+
+ return n_patches_sq;
}
static std::vector<std::vector<std::vector<float>>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector<std::vector<float>> & pos) {
ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);
// set inputs
- const auto & model = ctx->vision_model;
+ const auto & model = ctx->model;
const auto & hparams = model.hparams;
const int image_size_width = imgs.entries[0]->nx;
}
// set input per projector
- switch (ctx->proj_type) {
+ switch (ctx->model.proj_type) {
case PROJECTOR_TYPE_MINICPMV:
{
            // inspired by siglip:
}
int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
- switch (ctx->proj_type) {
+ const auto & hparams = ctx->model.hparams;
+ switch (ctx->model.proj_type) {
case PROJECTOR_TYPE_LDP:
- return ctx->vision_model.mm_model_block_1_block_2_1_b->ne[0];
+ return ctx->model.mm_model_block_1_block_2_1_b->ne[0];
case PROJECTOR_TYPE_LDPV2:
- return ctx->vision_model.mm_model_peg_0_b->ne[0];
+ return ctx->model.mm_model_peg_0_b->ne[0];
case PROJECTOR_TYPE_MLP:
case PROJECTOR_TYPE_PIXTRAL:
- return ctx->vision_model.mm_2_w->ne[1];
+ return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_MLP_NORM:
- return ctx->vision_model.mm_3_b->ne[0];
+ return ctx->model.mm_3_b->ne[0];
case PROJECTOR_TYPE_MINICPMV:
- if (ctx->minicpmv_version == 2) {
+ if (hparams.minicpmv_version == 2) {
return 4096;
- } else if (ctx->minicpmv_version == 3) {
+ } else if (hparams.minicpmv_version == 3) {
return 3584;
- } else if (ctx->minicpmv_version == 4) {
+ } else if (hparams.minicpmv_version == 4) {
return 3584;
}
GGML_ABORT("Unknown minicpmv version");
case PROJECTOR_TYPE_GLM_EDGE:
- return ctx->vision_model.mm_model_mlp_3_w->ne[1];
+ return ctx->model.mm_model_mlp_3_w->ne[1];
case PROJECTOR_TYPE_QWEN2VL:
case PROJECTOR_TYPE_QWEN25VL:
- return ctx->vision_model.mm_1_b->ne[0];
+ return ctx->model.mm_1_b->ne[0];
case PROJECTOR_TYPE_GEMMA3:
- return ctx->vision_model.mm_input_proj_w->ne[0];
+ return ctx->model.mm_input_proj_w->ne[0];
case PROJECTOR_TYPE_IDEFICS3:
- return ctx->vision_model.projection->ne[1];
+ return ctx->model.projection->ne[1];
case PROJECTOR_TYPE_ULTRAVOX:
- return ctx->vision_model.mm_2_w->ne[1];
+ return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_INTERNVL:
- return ctx->vision_model.mm_3_w->ne[1];
+ return ctx->model.mm_3_w->ne[1];
case PROJECTOR_TYPE_LLAMA4:
- return ctx->vision_model.mm_model_proj->ne[1];
+ return ctx->model.mm_model_proj->ne[1];
case PROJECTOR_TYPE_QWEN2A:
- return ctx->vision_model.mm_fc_w->ne[1];
+ return ctx->model.mm_fc_w->ne[1];
default:
GGML_ABORT("Unknown projector type");
}
}
int clip_is_minicpmv(const struct clip_ctx * ctx) {
- if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
- return ctx->minicpmv_version;
+ if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV) {
+ return ctx->model.hparams.minicpmv_version;
}
return 0;
}
bool clip_is_glm(const struct clip_ctx * ctx) {
- return ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE;
+ return ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE;
}
bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
- return ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL;
+ return ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL
+ || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL;
}
bool clip_is_llava(const struct clip_ctx * ctx) {
- return ctx->has_llava_projector;
+ return ctx->model.hparams.has_llava_projector;
}
bool clip_is_gemma3(const struct clip_ctx * ctx) {
- return ctx->proj_type == PROJECTOR_TYPE_GEMMA3;
+ return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3;
}
bool clip_has_vision_encoder(const struct clip_ctx * ctx) {
- return ctx->vision_model.hparams.has_vision;
+ return ctx->model.modality == CLIP_MODALITY_VISION;
}
bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
- return ctx->vision_model.hparams.has_audio;
+ return ctx->model.modality == CLIP_MODALITY_AUDIO;
}
bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
- return ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX || ctx->proj_type == PROJECTOR_TYPE_QWEN2A;
+ return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX
+ || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A;
}
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
//
projector_type clip_get_projector_type(const struct clip_ctx * ctx) {
- return ctx->proj_type;
+ return ctx->proj_type();
}
void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel) {
}
struct mtmd_context {
- struct clip_ctx * ctx_clip;
+ struct clip_ctx * ctx_v; // vision
+ struct clip_ctx * ctx_a; // audio
const struct llama_model * text_model;
std::vector<float> image_embd_v; // image embedding vector
bool print_timings;
int n_threads;
std::string media_marker;
- bool has_vision;
- bool has_audio;
+ const int n_embd_text;
+
+    // these are not tokens, but strings used to mark the beginning and end of image/audio embeddings
+ std::string img_beg;
+ std::string img_end;
+ std::string aud_beg;
+ std::string aud_end;
// for llava-uhd style models, we need special tokens in-between slices
// minicpmv calls them "slices", llama 4 calls them "tiles"
text_model (text_model),
print_timings(ctx_params.print_timings),
n_threads (ctx_params.n_threads),
- media_marker (ctx_params.media_marker)
+ media_marker (ctx_params.media_marker),
+ n_embd_text (llama_model_n_embd(text_model))
{
if (std::string(ctx_params.image_marker) != MTMD_DEFAULT_IMAGE_MARKER) {
throw std::runtime_error("custom image_marker is not supported anymore, use media_marker instead");
}
+ if (media_marker.empty()) {
+ throw std::runtime_error("media_marker must not be empty");
+ }
+
clip_context_params ctx_clip_params;
ctx_clip_params.use_gpu = ctx_params.use_gpu;
ctx_clip_params.verbosity = ctx_params.verbosity;
- ctx_clip = clip_init(mmproj_fname, ctx_clip_params);
- if (!ctx_clip) {
+ auto res = clip_init(mmproj_fname, ctx_clip_params);
+ ctx_v = res.ctx_v;
+ ctx_a = res.ctx_a;
+ if (!ctx_v && !ctx_a) {
throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
}
- if (llama_model_n_embd(text_model) != clip_n_mmproj_embd(ctx_clip)) {
+ // if both vision and audio mmproj are present, we need to validate their n_embd
+ if (ctx_v && ctx_a) {
+ int n_embd_v = clip_n_mmproj_embd(ctx_v);
+ int n_embd_a = clip_n_mmproj_embd(ctx_a);
+ if (n_embd_v != n_embd_a) {
+ throw std::runtime_error(string_format(
+ "mismatch between vision and audio mmproj (n_embd_v = %d, n_embd_a = %d)\n",
+ n_embd_v, n_embd_a));
+ }
+ }
+
+        // since we have already validated the n_embd of the vision and audio mmproj,
+        // we can safely assume that they are the same
+ int n_embd_clip = clip_n_mmproj_embd(ctx_v ? ctx_v : ctx_a);
+ if (n_embd_text != n_embd_clip) {
throw std::runtime_error(string_format(
"mismatch between text model (n_embd = %d) and mmproj (n_embd = %d)\n"
"hint: you may be using wrong mmproj\n",
- llama_model_n_embd(text_model), clip_n_mmproj_embd(ctx_clip)));
+ n_embd_text, n_embd_clip));
+ }
+ if (ctx_v) {
+ init_vision();
}
+ if (ctx_a) {
+ init_audio();
+ }
+ }
- has_vision = clip_has_vision_encoder(ctx_clip);
- has_audio = clip_has_audio_encoder(ctx_clip);
- use_mrope = clip_is_qwen2vl(ctx_clip);
+ void init_vision() {
+ GGML_ASSERT(ctx_v != nullptr);
+ use_mrope = clip_is_qwen2vl(ctx_v);
- projector_type proj = clip_get_projector_type(ctx_clip);
- int minicpmv_version = clip_is_minicpmv(ctx_clip);
+ projector_type proj = clip_get_projector_type(ctx_v);
+ int minicpmv_version = clip_is_minicpmv(ctx_v);
if (minicpmv_version == 2) {
// minicpmv 2.5 format:
// <image> (overview) </image><slice><image> (slice) </image><image> (slice) </image>\n ... </slice>
ov_img_first = false; // overview image is last
}
- if (clip_has_whisper_encoder(ctx_clip)) {
+ // set boi/eoi
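+        // illustrative example: for gemma 3, a prompt like "describe <__media__>" effectively becomes
+        //   "describe <start_of_image>" + (image embeddings) + "<end_of_image>"
+        // the begin/end strings are added as text chunks around the IMAGE chunk by mtmd_tokenizer::add_media()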
+ if (proj == PROJECTOR_TYPE_GEMMA3) {
+ // <start_of_image> ... (image embeddings) ... <end_of_image>
+ img_beg = "<start_of_image>";
+ img_end = "<end_of_image>";
+
+ } else if (proj == PROJECTOR_TYPE_IDEFICS3) {
+ // https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
+ img_beg = "<fake_token_around_image><global-img>";
+ img_end = "<fake_token_around_image>";
+
+ } else if (proj == PROJECTOR_TYPE_PIXTRAL) {
+ // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
+ img_end = "[IMG_END]";
+
+ } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL) {
+ // <|vision_start|> ... (image embeddings) ... <|vision_end|>
+ img_beg = "<|vision_start|>";
+ img_end = "<|vision_end|>";
+
+ } else if (proj == PROJECTOR_TYPE_LLAMA4) {
+            // <|image_start|> ... (image embeddings) ... <|image_end|>
+ img_beg = "<|image_start|>";
+ img_end = "<|image_end|>";
+ LOG_WRN("%s: llama 4 vision is known to have degraded quality:\n"
+ " https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
+
+ } else if (proj == PROJECTOR_TYPE_INTERNVL) {
+ // <img> ... (image embeddings) ... </img>
+ img_beg = "<img>";
+ img_end = "</img>";
+
+ }
+ }
+
+ void init_audio() {
+ GGML_ASSERT(ctx_a != nullptr);
+ projector_type proj = clip_get_projector_type(ctx_a);
+
+ if (clip_has_whisper_encoder(ctx_a)) {
// TODO @ngxson : check if model n_mel is 128 or 80
w_filters = whisper_precalc_filters::get_128_bins();
}
- // warning messages
- if (proj == PROJECTOR_TYPE_LLAMA4) {
- LOG_WRN("%s: llama 4 vision is known to have degraded quality:\n"
- " https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
+ LOG_WRN("%s: audio input is in experimental stage and may have reduced quality:\n"
+ " https://github.com/ggml-org/llama.cpp/discussions/13759\n", __func__);
+
+ if (proj == PROJECTOR_TYPE_QWEN2A) {
+ // <|audio_bos|> ... (embeddings) ... <|audio_eos|>
+ aud_beg = "<|audio_bos|>";
+ aud_end = "<|audio_eos|>";
+
}
- if (has_audio) {
- LOG_WRN("%s: audio input is in experimental stage and may have reduced quality:\n"
- " https://github.com/ggml-org/llama.cpp/discussions/13759\n", __func__);
+ }
+
+ // get clip ctx based on chunk type
+ clip_ctx * get_clip_ctx(const mtmd_input_chunk * chunk) const {
+ if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+ return ctx_v;
+ } else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
+ return ctx_a;
}
+ GGML_ABORT("unknown chunk type");
+ }
+
+ projector_type proj_type_v() const {
+ return ctx_v ? clip_get_projector_type(ctx_v) : PROJECTOR_TYPE_UNKNOWN;
+ }
+
+ projector_type proj_type_a() const {
+ return ctx_a ? clip_get_projector_type(ctx_a) : PROJECTOR_TYPE_UNKNOWN;
}
~mtmd_context() {
- clip_free(ctx_clip);
+ clip_free(ctx_a);
+ clip_free(ctx_v);
}
private:
}
}
-// copied from common_tokenize
-static std::vector<llama_token> mtmd_tokenize_text_internal(
- const struct llama_vocab * vocab,
- const std::string & text,
- bool add_special,
- bool parse_special) {
- // upper limit for the number of tokens
- int n_tokens = text.length() + 2 * add_special;
- std::vector<llama_token> result(n_tokens);
- n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
- if (n_tokens < 0) {
- result.resize(-n_tokens);
- int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
- GGML_ASSERT(check == -n_tokens);
- } else {
- result.resize(n_tokens);
- }
- return result;
-}
+struct mtmd_tokenizer {
+ mtmd_context * ctx;
+ std::vector<const mtmd_bitmap *> bitmaps;
-int32_t mtmd_tokenize(mtmd_context * ctx,
- mtmd_input_chunks * output,
+ std::string input_text;
+ bool add_special;
+ bool parse_special;
+ const llama_vocab * vocab;
+
+ mtmd_input_chunks cur;
+
+ mtmd_tokenizer(mtmd_context * ctx,
const mtmd_input_text * text,
const mtmd_bitmap ** bitmaps,
- size_t n_bitmaps) {
- auto vocab = llama_model_get_vocab(ctx->text_model);
-
- std::string prompt_modified(text->text);
- std::string marker_modified(ctx->media_marker);
- projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
-
- // for compatibility, we convert image marker to media marker
- string_replace_all(prompt_modified, MTMD_DEFAULT_IMAGE_MARKER, ctx->media_marker);
-
- // a bit hacky here, but works for now
- // for some models, we need to add prefix and suffix to the image embeddings
- if (clip_is_gemma3(ctx->ctx_clip)) {
- // gemma 3
- // <start_of_image> ... (image embeddings) ... <end_of_image>
- marker_modified = "<start_of_image>" + ctx->media_marker + "<end_of_image>";
- string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
-
- } else if (proj_type == PROJECTOR_TYPE_IDEFICS3) {
- // https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
- marker_modified = "<fake_token_around_image><global-img>" + ctx->media_marker + "<fake_token_around_image>";
- string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
-
- } else if (proj_type == PROJECTOR_TYPE_PIXTRAL) {
- // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
- marker_modified = ctx->media_marker + "[IMG_END]";
- string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
-
- } else if (proj_type == PROJECTOR_TYPE_QWEN2VL || proj_type == PROJECTOR_TYPE_QWEN25VL) {
- // <|vision_start|> ... (image embeddings) ... <|vision_end|>
- marker_modified = "<|vision_start|>" + ctx->media_marker + "<|vision_end|>";
- string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
-
- } else if (proj_type == PROJECTOR_TYPE_LLAMA4) {
- // (more details in mtmd_context constructor)
- marker_modified = "<|image_start|>" + ctx->media_marker + "<|image_end|>";
- string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
-
- } else if (proj_type == PROJECTOR_TYPE_INTERNVL) {
- // <img> ... (image embeddings) ... </img>
- marker_modified = "<img>" + ctx->media_marker + "</img>";
- string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
-
- } else if (proj_type == PROJECTOR_TYPE_QWEN2A) {
- // <|audio_bos|> ... (embeddings) ... <|audio_eos|>
- marker_modified = "<|audio_bos|>" + ctx->media_marker + "<|audio_eos|>";
- string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
-
- }
-
- // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
- // for glm-edge, BOI and EOI token's embeddings are not present in the text model
-
- std::vector<std::string> parts = string_split_str(prompt_modified, ctx->media_marker);
- output->entries.clear();
- output->entries.reserve(parts.size());
-
- size_t i_bm = 0;
-
- // utility for adding raw tokens
- auto add_text_chunk = [&output](std::vector<llama_token> && tokens) {
- mtmd_input_chunk chunk{
- MTMD_INPUT_CHUNK_TYPE_TEXT,
- std::move(tokens),
- nullptr, // image tokens
- nullptr, // audio tokens
- };
- output->entries.emplace_back(std::move(chunk));
- };
+ size_t n_bitmaps) : ctx(ctx), bitmaps(bitmaps, bitmaps + n_bitmaps) {
+ add_special = text->add_special;
+ parse_special = text->parse_special;
+ input_text = text->text;
+ vocab = llama_model_get_vocab(ctx->text_model);
+
+ // for compatibility, we convert image marker to media marker
+ string_replace_all(input_text, MTMD_DEFAULT_IMAGE_MARKER, ctx->media_marker);
+ }
- // utility for splitting batch of multiple images into chunks of batch having single images
- auto split_batch_to_chunk = [&ctx](clip_image_f32_batch && batch_f32, const std::string & id) {
- std::vector<mtmd_input_chunk> chunks;
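+    // entry point: convert the input text + bitmaps into an ordered list of chunks;
+    // e.g. (a sketch, assuming the model defines img_beg/img_end) tokenizing
+    // "a <__media__> b" with one image yields:
+    //   [TEXT: "a " + img_beg] [IMAGE] [TEXT: img_end + " b"]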
+ int32_t tokenize(mtmd_input_chunks * output) {
+ cur.entries.clear();
+ std::vector<std::string> parts = split_text(input_text, ctx->media_marker);
+ size_t i_bm = 0; // index of the current bitmap
+ for (auto & part : parts) {
+ if (part == ctx->media_marker) {
+ // this is a marker, we should add the next bitmap
+ if (i_bm >= bitmaps.size()) {
+ LOG_ERR("%s: error: number of bitmaps (%zu) does not match number of markers (%zu)\n",
+ __func__, bitmaps.size(), parts.size() - 1);
+ return 1;
+ }
+ const mtmd_bitmap * bitmap = bitmaps[i_bm++];
+ int32_t res = add_media(bitmap);
+ if (res != 0) {
+ return res;
+ }
+ } else {
+ // this is a text part, we should add it as text
+ add_text(part, parse_special);
+ }
+ }
- for (auto & entry : batch_f32.entries) {
- mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
- image_tokens->nx = clip_n_output_tokens(ctx->ctx_clip, entry.get());
- image_tokens->ny = 1;
- image_tokens->batch_f32.entries.push_back(std::move(entry));
- image_tokens->id = id;
+ if (add_special && llama_vocab_get_add_bos(vocab)) {
+ // if first chunk is text, we add BOS token to first text chunk
+ // otherwise, create a new text chunk with BOS token
+ if (!cur.entries.empty() && cur.entries[0].type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+ // add BOS token to the beginning of first text chunk
+ cur.entries[0].tokens_text.insert(cur.entries[0].tokens_text.begin(), llama_vocab_bos(vocab));
+ } else {
+ // create a new text chunk with BOS token at the beginning
+ mtmd_input_chunk bos_chunk{
+ MTMD_INPUT_CHUNK_TYPE_TEXT,
+ {llama_vocab_bos(vocab)},
+ nullptr, // image tokens
+ nullptr, // audio tokens
+ };
+ cur.entries.insert(cur.entries.begin(), std::move(bos_chunk));
+ }
+ }
- mtmd_input_chunk chunk{
- MTMD_INPUT_CHUNK_TYPE_IMAGE,
- {}, // text tokens
- std::move(image_tokens),
- nullptr, // audio tokens
- };
- chunks.emplace_back(std::move(chunk));
+ if (add_special && llama_vocab_get_add_eos(vocab)) {
+            // add EOS token; add_text() appends it to the last text chunk, or creates a new one if needed
+ add_text({llama_vocab_eos(vocab)});
}
- return chunks;
- };
+ if (i_bm != bitmaps.size()) {
+ LOG_ERR("%s: error: number of bitmaps (%zu) does not match number of markers (%zu)\n",
+ __func__, bitmaps.size(), parts.size() - 1);
+ return 1;
+ }
+
+ *output = std::move(cur);
+
+ return 0;
+ }
+
+ void add_text(const std::string & txt, bool parse_special) {
+ LOG_DBG("%s: %s\n", __func__, txt.c_str());
+ auto tokens = mtmd_tokenize_text_internal(vocab, txt, /* add_special */ false, parse_special);
+ add_text(tokens);
+ }
- for (const auto & part : parts) {
- // printf("tokenizing part: %s\n", part.c_str());
-        bool add_bos = &parts.front() == &part;
- auto tokens = mtmd_tokenize_text_internal(vocab, part, text->add_special && add_bos, text->parse_special);
+ void add_text(const std::vector<llama_token> & tokens) {
if (tokens.empty()) {
- continue;
+ return;
}
- mtmd_input_chunk chunk{
- MTMD_INPUT_CHUNK_TYPE_TEXT,
- std::move(tokens),
- nullptr, // image tokens
- nullptr, // audio tokens
- };
- output->entries.emplace_back(std::move(chunk));
-
- // only add image/audio tokens to middle of 2 parts
- // therefore, we skip handling image/audio if this is the last part
- if (&parts.back() == &part) {
- continue;
+        // if the last entry is also a text chunk, add the tokens to it instead of creating a new chunk
+ if (!cur.entries.empty() && cur.entries.back().type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+ cur.entries.back().tokens_text.insert(
+ cur.entries.back().tokens_text.end(),
+ tokens.begin(),
+ tokens.end());
+ } else {
+ mtmd_input_chunk chunk{
+ MTMD_INPUT_CHUNK_TYPE_TEXT,
+ tokens,
+ nullptr, // image tokens
+ nullptr, // audio tokens
+ };
+ cur.entries.emplace_back(std::move(chunk));
}
+ }
- if (!bitmaps[i_bm]->is_audio) {
+ int32_t add_media(const mtmd_bitmap * bitmap) {
+ if (!bitmap->is_audio) {
// handle image
- if (i_bm >= n_bitmaps) {
- LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
- return 1;
- }
-
- if (!ctx->has_vision) {
+ if (!ctx->ctx_v) {
LOG_ERR("%s: error: model does not support vision input\n", __func__);
return 2;
}
+ if (!ctx->img_beg.empty()) {
+                add_text(ctx->img_beg, true); // add image begin tokens
+ }
+
// convert mtmd_bitmap to clip_image_u8
clip_image_u8_ptr img_u8(clip_image_u8_init());
- img_u8->nx = bitmaps[i_bm]->nx;
- img_u8->ny = bitmaps[i_bm]->ny;
- img_u8->buf.resize(bitmaps[i_bm]->data.size());
- std::memcpy(img_u8->buf.data(), bitmaps[i_bm]->data.data(), img_u8->nx * img_u8->ny * 3);
+ img_u8->nx = bitmap->nx;
+ img_u8->ny = bitmap->ny;
+ img_u8->buf.resize(bitmap->data.size());
+ std::memcpy(img_u8->buf.data(), bitmap->data.data(), img_u8->nx * img_u8->ny * 3);
// preprocess image
clip_image_f32_batch batch_f32;
- bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), &batch_f32);
+ bool ok = clip_image_preprocess(ctx->ctx_v, img_u8.get(), &batch_f32);
if (!ok) {
LOG_ERR("Unable to preprocess image\n");
return 2;
|| ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4
) {
// split batch into chunks of single images
- auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_bm]->id);
+ auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmap->id);
GGML_ASSERT(chunks.size() > 0);
auto ov_chunk = std::move(chunks.front());
// add overview image (first)
if (ctx->ov_img_first) {
if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) {
- add_text_chunk({ctx->tok_ov_img_start});
+ add_text({ctx->tok_ov_img_start});
}
- output->entries.emplace_back(std::move(ov_chunk));
+ cur.entries.emplace_back(std::move(ov_chunk));
if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) {
- add_text_chunk({ctx->tok_ov_img_end});
+ add_text({ctx->tok_ov_img_end});
}
}
const int n_col = batch_f32.grid_x;
const int n_row = batch_f32.grid_y;
if (ctx->tok_slices_start != LLAMA_TOKEN_NULL) {
- add_text_chunk({ctx->tok_slices_start});
+ add_text({ctx->tok_slices_start});
}
for (int y = 0; y < n_row; y++) {
for (int x = 0; x < n_col; x++) {
const bool is_last_in_row = (x == n_col - 1);
if (ctx->tok_sli_img_start != LLAMA_TOKEN_NULL) {
- add_text_chunk({ctx->tok_sli_img_start});
+ add_text({ctx->tok_sli_img_start});
}
- output->entries.emplace_back(std::move(chunks[y * n_col + x]));
+ cur.entries.emplace_back(std::move(chunks[y * n_col + x]));
if (ctx->tok_sli_img_end != LLAMA_TOKEN_NULL) {
- add_text_chunk({ctx->tok_sli_img_end});
+ add_text({ctx->tok_sli_img_end});
}
if (!is_last_in_row && ctx->tok_sli_img_mid != LLAMA_TOKEN_NULL) {
- add_text_chunk({ctx->tok_sli_img_mid});
+ add_text({ctx->tok_sli_img_mid});
}
}
if ((y != n_row - 1 || ctx->tok_row_end_trail) && ctx->tok_row_end != LLAMA_TOKEN_NULL) {
- add_text_chunk({ctx->tok_row_end});
+ add_text({ctx->tok_row_end});
}
}
if (ctx->tok_slices_end != LLAMA_TOKEN_NULL) {
- add_text_chunk({ctx->tok_slices_end});
+ add_text({ctx->tok_slices_end});
}
}
// add overview image (last)
if (!ctx->ov_img_first) {
if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) {
- add_text_chunk({ctx->tok_ov_img_start});
+ add_text({ctx->tok_ov_img_start});
}
- output->entries.emplace_back(std::move(ov_chunk));
+ cur.entries.emplace_back(std::move(ov_chunk));
if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) {
- add_text_chunk({ctx->tok_ov_img_end});
+ add_text({ctx->tok_ov_img_end});
}
}
} else {
size_t n_tokens = 0;
for (const auto & entry : batch_f32.entries) {
- n_tokens += clip_n_output_tokens(ctx->ctx_clip, entry.get());
+ n_tokens += clip_n_output_tokens(ctx->ctx_v, entry.get());
}
mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
if (ctx->use_mrope) {
// for Qwen2VL, we need this information for M-RoPE decoding positions
- image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_clip, batch_f32.entries[0].get());
- image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_clip, batch_f32.entries[0].get());
+ image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_v, batch_f32.entries[0].get());
+ image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_v, batch_f32.entries[0].get());
image_tokens->use_mrope_pos = true;
} else {
// other models, we only need the total number of tokens
image_tokens->ny = 1;
}
image_tokens->batch_f32 = std::move(batch_f32);
- image_tokens->id = bitmaps[i_bm]->id; // optional
+ image_tokens->id = bitmap->id; // optional
LOG_DBG("image_tokens->nx = %d\n", image_tokens->nx);
LOG_DBG("image_tokens->ny = %d\n", image_tokens->ny);
std::move(image_tokens),
nullptr, // audio tokens
};
- output->entries.emplace_back(std::move(chunk));
+ cur.entries.emplace_back(std::move(chunk));
}
- i_bm++; // move to next image
- continue;
+ if (!ctx->img_end.empty()) {
+            add_text(ctx->img_end, true); // add image end tokens
+ }
} else {
// handle audio
- if (i_bm >= n_bitmaps) {
- LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
- return 1;
- }
-
- if (!ctx->has_audio) {
+ if (!ctx->ctx_a) {
LOG_ERR("%s: error: model does not support audio input\n", __func__);
return 2;
}
- if (bitmaps[i_bm]->data.size() == 0) {
+ if (bitmap->data.size() == 0) {
LOG_ERR("%s: error: empty audio data\n", __func__);
return 2;
}
+ if (!ctx->aud_beg.empty()) {
+            add_text(ctx->aud_beg, true); // add audio begin tokens
+ }
+
// preprocess audio
GGML_ASSERT(ctx->w_filters.n_mel); // make sure we have filter preloaded
std::vector<whisper_preprocessor::whisper_mel> mel_spec_chunks;
- const float * samples = (const float *)bitmaps[i_bm]->data.data();
- size_t n_samples = bitmaps[i_bm]->data.size() / sizeof(float);
+ const float * samples = (const float *)bitmap->data.data();
+ size_t n_samples = bitmap->data.size() / sizeof(float);
bool ok = whisper_preprocessor::preprocess_audio(samples, n_samples, ctx->w_filters, mel_spec_chunks);
if (!ok) {
LOG_ERR("Unable to preprocess audio\n");
mel_f32->nx = mel_spec.n_len;
mel_f32->ny = mel_spec.n_mel;
mel_f32->buf = std::move(mel_spec.data);
- size_t n_tokens = clip_n_output_tokens(ctx->ctx_clip, mel_f32.get());
+ size_t n_tokens = clip_n_output_tokens(ctx->ctx_a, mel_f32.get());
clip_image_f32_batch batch_f32;
batch_f32.is_audio = true;
mtmd_audio_tokens_ptr audio_tokens(new mtmd_audio_tokens);
audio_tokens->n_tokens = n_tokens;
audio_tokens->batch_f32 = std::move(batch_f32);
- audio_tokens->id = bitmaps[i_bm]->id; // optional
+ audio_tokens->id = bitmap->id; // optional
LOG_DBG("audio_tokens->n_tokens = %d\n", audio_tokens->n_tokens);
nullptr, // image tokens
std::move(audio_tokens),
};
- output->entries.emplace_back(std::move(chunk));
+ cur.entries.emplace_back(std::move(chunk));
}
- i_bm++;
- continue;
+ if (!ctx->aud_end.empty()) {
+            add_text(ctx->aud_end, true); // add audio end tokens
+ }
}
+
+ return 0;
}
- return 0;
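+    // split a batch of preprocessed slices into one IMAGE chunk per slice, so that the
+    // slice separator tokens (tok_sli_img_*, tok_row_end, ...) can be interleaved between them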
+ std::vector<mtmd_input_chunk> split_batch_to_chunk(clip_image_f32_batch && batch_f32, const std::string & id) {
+ std::vector<mtmd_input_chunk> chunks;
+
+ for (auto & entry : batch_f32.entries) {
+ mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
+ image_tokens->nx = clip_n_output_tokens(ctx->ctx_v, entry.get());
+ image_tokens->ny = 1;
+ image_tokens->batch_f32.entries.push_back(std::move(entry));
+ image_tokens->id = id;
+
+ mtmd_input_chunk chunk{
+ MTMD_INPUT_CHUNK_TYPE_IMAGE,
+ {}, // text tokens
+ std::move(image_tokens),
+ nullptr, // audio tokens
+ };
+ chunks.emplace_back(std::move(chunk));
+ }
+
+ return chunks;
+ }
+
+ // for example: "a <__media__> b <__media__> c" --> "a", "<__media__>", "b", "<__media__>", "c"
+ static std::vector<std::string> split_text(const std::string & input, const std::string & delimiter) {
+ std::vector<std::string> result;
+ if (input.empty()) {
+ return result;
+ }
+ size_t start = 0;
+ size_t pos = 0;
+ while ((pos = input.find(delimiter, start)) != std::string::npos) {
+ if (pos > start) {
+ result.push_back(input.substr(start, pos - start));
+ }
+ result.push_back(delimiter);
+ start = pos + delimiter.length();
+ }
+ if (start < input.length()) {
+ result.push_back(input.substr(start));
+ }
+ return result;
+ }
+
+ // copied from common_tokenize
+ static std::vector<llama_token> mtmd_tokenize_text_internal(
+ const struct llama_vocab * vocab,
+ const std::string & text,
+ bool add_special,
+ bool parse_special) {
+ // upper limit for the number of tokens
+ int n_tokens = text.length() + 2 * add_special;
+ std::vector<llama_token> result(n_tokens);
+ n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+ if (n_tokens < 0) {
+ result.resize(-n_tokens);
+ int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+ GGML_ASSERT(check == -n_tokens);
+ } else {
+ result.resize(n_tokens);
+ }
+ return result;
+ }
+};
+
+int32_t mtmd_tokenize(mtmd_context * ctx,
+ mtmd_input_chunks * output,
+ const mtmd_input_text * text,
+ const mtmd_bitmap ** bitmaps,
+ size_t n_bitmaps) {
+ mtmd_tokenizer tokenizer(ctx, text, bitmaps, n_bitmaps);
+ return tokenizer.tokenize(output);
}
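+// usage sketch (illustrative only; helper names taken from mtmd.h):
+//   mtmd_input_text txt = { prompt.c_str(), /* add_special */ true, /* parse_special */ true };
+//   mtmd_input_chunks * chunks = mtmd_input_chunks_init();
+//   std::vector<const mtmd_bitmap *> bitmaps = { bmp }; // one entry per media marker in the prompt
+//   int32_t res = mtmd_tokenize(ctx, chunks, &txt, bitmaps.data(), bitmaps.size());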
int32_t mtmd_encode_chunk(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
LOG_WRN("mtmd_encode_chunk has no effect for text chunks\n");
return 0;
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+ if (!ctx->ctx_v) {
+ LOG_ERR("%s: model does not support vision input\n", __func__);
+ return 1;
+ }
return mtmd_encode(ctx, chunk->tokens_image.get());
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
- int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
+ if (!ctx->ctx_a) {
+ LOG_ERR("%s: model does not support audio input\n", __func__);
+ return 1;
+ }
+ int n_mmproj_embd = ctx->n_embd_text;
ctx->image_embd_v.resize(chunk->tokens_audio->n_tokens * n_mmproj_embd);
bool ok = clip_image_batch_encode(
- ctx->ctx_clip,
+ ctx->ctx_a,
ctx->n_threads,
&chunk->tokens_audio->batch_f32,
ctx->image_embd_v.data());
return ok ? 0 : 1;
}
- LOG_ERR("mtmd_encode_chunk: unknown chunk type %d\n", (int)chunk->type);
+ LOG_ERR("%s: unknown chunk type %d\n", __func__, (int)chunk->type);
return 1;
}
int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
- int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
+ clip_ctx * ctx_clip = ctx->ctx_v;
+ if (!ctx_clip) {
+ LOG_ERR("%s: this API does not support non-vision input, please use mtmd_encode_chunk instead\n", __func__);
+ return 1;
+ }
+ int n_mmproj_embd = clip_n_mmproj_embd(ctx_clip);
ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
bool ok = false;
- if (clip_is_llava(ctx->ctx_clip) || clip_is_minicpmv(ctx->ctx_clip) || clip_is_glm(ctx->ctx_clip)) {
+ if (clip_is_llava(ctx_clip) || clip_is_minicpmv(ctx_clip) || clip_is_glm(ctx_clip)) {
// TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
const auto & entries = image_tokens->batch_f32.entries;
for (size_t i = 0; i < entries.size(); i++) {
- int n_tokens_per_image = clip_n_output_tokens(ctx->ctx_clip, entries[i].get());
+ int n_tokens_per_image = clip_n_output_tokens(ctx_clip, entries[i].get());
ok = clip_image_encode(
- ctx->ctx_clip,
+ ctx_clip,
ctx->n_threads,
entries[i].get(),
ctx->image_embd_v.data() + i*n_mmproj_embd*n_tokens_per_image);
}
} else {
ok = clip_image_batch_encode(
- ctx->ctx_clip,
+ ctx_clip,
ctx->n_threads,
&image_tokens->batch_f32,
ctx->image_embd_v.data());
}
bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
- projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
- if (proj_type == PROJECTOR_TYPE_GEMMA3) {
+ if (ctx->ctx_v && clip_get_projector_type(ctx->ctx_v) == PROJECTOR_TYPE_GEMMA3) {
return true;
}
return false;
}
bool mtmd_support_vision(mtmd_context * ctx) {
- return ctx->has_vision;
+ return ctx->ctx_v != nullptr;
}
bool mtmd_support_audio(mtmd_context * ctx) {
- return ctx->has_audio;
+ return ctx->ctx_a != nullptr;
}
// these 2 helpers below use internal clip_image_u8_ptr,