From: Xuan-Son Nguyen Date: Mon, 28 Apr 2025 10:18:59 +0000 (+0200) Subject: clip : refactor set input for cgraph + fix qwen2.5vl input (#13136) X-Git-Tag: upstream/0.0.5318~110 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=5fa9e63be82225fb3249c76f39ddda3e5bdec0a3;p=pkg%2Fggml%2Fsources%2Fllama.cpp clip : refactor set input for cgraph + fix qwen2.5vl input (#13136) * clip : refactor set input for cgraph * more strict assert * minicpmv : use clip_n_mmproj_embd instead of copying the same code everywhere * split qwen2 and qwen2.5 code blocks * minor style fix --- diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 3cd27d5b..8c5d56cc 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -170,8 +170,8 @@ struct clip_hparams { std::vector image_grid_pinpoints; int32_t image_crop_resolution; std::unordered_set vision_feature_layer; - int32_t attn_window_size; - int32_t n_wa_pattern; + int32_t attn_window_size = 0; + int32_t n_wa_pattern = 0; }; struct clip_layer { @@ -325,7 +325,6 @@ struct clip_ctx { float image_std[3]; bool use_gelu = false; bool use_silu = false; - int32_t ftype = 1; gguf_context_ptr ctx_gguf; ggml_context_ptr ctx_data; @@ -776,7 +775,6 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_ const int image_size_width = imgs.entries[0]->nx; const int image_size_height = imgs.entries[0]->ny; - const bool use_mrope = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL; const bool use_window_attn = hparams.n_wa_pattern > 0; const int n_wa_pattern = hparams.n_wa_pattern; @@ -785,10 +783,11 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_ const int patches_w = image_size_width / patch_size; const int patches_h = image_size_height / patch_size; const int num_positions = num_patches + (model.class_embedding ? 1 : 0); - const int num_position_ids = use_mrope ? num_positions * 4 : num_positions; + const int num_position_ids = num_positions * 4; // m-rope requires 4 dim per position const int hidden_size = hparams.hidden_size; const int n_head = hparams.n_head; const int d_head = hidden_size / n_head; + const int n_layer = hparams.n_layer; const float eps = hparams.eps; int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; @@ -870,7 +869,7 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_ } // loop over layers - for (int il = 0; il < ctx->max_feature_layer; il++) { + for (int il = 0; il < n_layer; il++) { struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states // rmsnorm1 @@ -1115,15 +1114,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) { int pos_w = image_size_width/patch_size; int pos_h = image_size_height/patch_size; - if (ctx->minicpmv_version == 2) { - pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1); - } - else if (ctx->minicpmv_version == 3) { - pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); - } - else if (ctx->minicpmv_version == 4) { - pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); - } + int n_output_dim = clip_n_mmproj_embd(ctx); + pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_output_dim, pos_w * pos_h, 1); ggml_set_name(pos_embed, "pos_embed"); ggml_set_input(pos_embed); } @@ -1461,23 +1453,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im } { // attention - int hidden_size = 4096; + int hidden_size = clip_n_mmproj_embd(ctx); const int d_head = 128; int n_head = hidden_size/d_head; int num_query = 96; if (ctx->minicpmv_version == 2) { - hidden_size = 4096; - n_head = hidden_size/d_head; num_query = 96; } else if (ctx->minicpmv_version == 3) { - hidden_size = 3584; - n_head = hidden_size/d_head; num_query = 64; } else if (ctx->minicpmv_version == 4) { - hidden_size = 3584; - n_head = hidden_size/d_head; num_query = 64; } @@ -1760,6 +1746,8 @@ struct clip_model_loader { LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str()); LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector); LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version); + LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor); + LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern); LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0); LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0); } @@ -3038,15 +3026,43 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima const int patch_size = hparams.patch_size; const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); const int num_positions = num_patches + (model.class_embedding ? 1 : 0); - const int pos_w = ctx->load_image_size.width / patch_size; + const int pos_w = ctx->load_image_size.width / patch_size; const int pos_h = ctx->load_image_size.height / patch_size; const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl + auto get_inp_tensor = [&gf](const char * name) { + struct ggml_tensor * inp = ggml_graph_get_tensor(gf, name); + if (inp == nullptr) { + GGML_ABORT("Failed to get tensor %s", name); + } + if (!(inp->flags & GGML_TENSOR_FLAG_INPUT)) { + GGML_ABORT("Tensor %s is not an input tensor", name); + } + return inp; + }; + + auto set_input_f32 = [&get_inp_tensor](const char * name, std::vector & values) { + ggml_tensor * cur = get_inp_tensor(name); + GGML_ASSERT(cur->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_nelements(cur) == (int64_t)values.size()); + ggml_backend_tensor_set(cur, values.data(), 0, ggml_nbytes(cur)); + }; + + auto set_input_i32 = [&get_inp_tensor](const char * name, std::vector & values) { + ggml_tensor * cur = get_inp_tensor(name); + GGML_ASSERT(cur->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_nelements(cur) == (int64_t)values.size()); + ggml_backend_tensor_set(cur, values.data(), 0, ggml_nbytes(cur)); + }; + + // set input pixel values { - struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw"); - std::vector inp_data(ggml_nelements(inp_raw)); - float * data = inp_data.data(); + size_t nelem = 0; + for (const auto & img : imgs.entries) { + nelem += img->nx * img->ny * 3; + } + std::vector inp_raw(nelem); // layout of data (note: the channel dim is unrolled to better visualize the layout): // @@ -3065,7 +3081,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima const int n = nx * ny; for (int b = 0; b < batch_size; b++) { - float * batch_entry = data + b * (3*n); + float * batch_entry = inp_raw.data() + b * (3*n); for (int y = 0; y < ny; y++) { for (int x = 0; x < nx; x++) { size_t base_src = 3*(y * nx + x); // idx of the first channel @@ -3077,266 +3093,207 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } } } - ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw)); + set_input_f32("inp_raw", inp_raw); } - if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) { - { - // inspired from siglip: - // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit - // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 - struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); - std::vector pos_data(ggml_nelements(positions)); - int * data = pos_data.data(); - int bucket_coords_h[1024]; - int bucket_coords_w[1024]; - for (int i = 0; i < pos_h; i++){ - bucket_coords_h[i] = std::floor(70.0*i/pos_h); - } - for (int i = 0; i < pos_w; i++){ - bucket_coords_w[i] = std::floor(70.0*i/pos_w); - } - for (int i = 0, id = 0; i < pos_h; i++){ - for (int j = 0; j < pos_w; j++){ - data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j]; + // set input per projector + switch (ctx->proj_type) { + case PROJECTOR_TYPE_MINICPMV: + { + // inspired from siglip: + // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit + // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 + std::vector positions(pos_h * pos_w); + int bucket_coords_h[1024]; + int bucket_coords_w[1024]; + for (int i = 0; i < pos_h; i++){ + bucket_coords_h[i] = std::floor(70.0*i/pos_h); } - } - ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions)); - } + for (int i = 0; i < pos_w; i++){ + bucket_coords_w[i] = std::floor(70.0*i/pos_w); + } + for (int i = 0, id = 0; i < pos_h; i++){ + for (int j = 0; j < pos_w; j++){ + positions[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j]; + } + } + set_input_i32("positions", positions); - { - // inspired from resampler of Qwen-VL: - // -> https://huggingface.co/Qwen/Qwen-VL/tree/main - // -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23 - struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed"); - int embed_dim = 4096; - if (ctx->minicpmv_version == 2) { - embed_dim = 4096; - } - else if (ctx->minicpmv_version == 3) { - embed_dim = 3584; - } - else if (ctx->minicpmv_version == 4) { - embed_dim = 3584; - } - else { - GGML_ABORT("Unknown minicpmv version"); - } + // inspired from resampler of Qwen-VL: + // -> https://huggingface.co/Qwen/Qwen-VL/tree/main + // -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23 + int embed_dim = clip_n_mmproj_embd(ctx); - // TODO @ngxson : this is very inefficient, can we do this using ggml_sin and ggml_cos? - auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); + // TODO @ngxson : this is very inefficient, can we do this using ggml_sin and ggml_cos? + auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); - std::vector pos_data(ggml_nelements(pos_embed)); - float * data = pos_data.data(); - for(int i = 0; i < pos_w * pos_h; ++i){ - for(int j = 0; j < embed_dim; ++j){ - data[i * embed_dim + j] = pos_embed_t[i][j]; + std::vector pos_embed(embed_dim * pos_w * pos_h); + for(int i = 0; i < pos_w * pos_h; ++i){ + for(int j = 0; j < embed_dim; ++j){ + pos_embed[i * embed_dim + j] = pos_embed_t[i][j]; + } } - } - ggml_backend_tensor_set(pos_embed, data, 0, ggml_nbytes(pos_embed)); - } - } - else { - // non-minicpmv models - - if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) { - // pw * ph = number of tokens output by ViT after apply patch merger - // ipw * ipw = number of vision token been processed inside ViT - const int merge_ratio = 2; - const int pw = image_size_width / patch_size / merge_ratio; - const int ph = image_size_height / patch_size / merge_ratio; - const int ipw = image_size_width / patch_size; - const int iph = image_size_height / patch_size; - - std::vector idx (ph * pw); - std::vector inv_idx(ph * pw); - - if (use_window_attn) { - const int attn_window_size = 112; - struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx"); - struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx"); - struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask"); - - const int grid_window = attn_window_size / patch_size / merge_ratio; - int dst = 0; - // [num_vision_tokens, num_vision_tokens] attention mask tensor - std::vector mask(pow(ipw * iph, 2), std::numeric_limits::lowest()); - int mask_row = 0; - - for (int y = 0; y < ph; y += grid_window) - { - for (int x = 0; x < pw; x += grid_window) - { - const int win_h = std::min(grid_window, ph - y); - const int win_w = std::min(grid_window, pw - x); - const int dst_0 = dst; - // group all tokens belong to the same window togather (to a continue range) - for (int dy = 0; dy < win_h; dy++) { - for (int dx = 0; dx < win_w; dx++) { - const int src = (y + dy) * pw + (x + dx); - assert(src < (int)idx.size()); - assert(dst < (int)inv_idx.size()); - idx [src] = dst; - inv_idx[dst] = src; - dst++; + set_input_f32("pos_embed", pos_embed); + } break; + case PROJECTOR_TYPE_QWEN2VL: + { + const int merge_ratio = 2; + const int pw = image_size_width / patch_size; + const int ph = image_size_height / patch_size; + std::vector positions(num_positions * 4); + int ptr = 0; + for (int y = 0; y < ph; y += merge_ratio) { + for (int x = 0; x < pw; x += merge_ratio) { + for (int dy = 0; dy < 2; dy++) { + for (int dx = 0; dx < 2; dx++) { + positions[ ptr] = y + dy; + positions[ num_patches + ptr] = x + dx; + positions[2 * num_patches + ptr] = y + dy; + positions[3 * num_patches + ptr] = x + dx; + ptr++; } } - - for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) { - int row_offset = mask_row * (ipw * iph); - std::fill( - mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio), - mask.begin() + row_offset + (dst * merge_ratio * merge_ratio), - 0.0); - mask_row++; - } } } - ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx)); - ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx)); - ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask)); - } else { - std::iota(idx.begin(), idx.end(), 0); - std::iota(inv_idx.begin(), inv_idx.end(), 0); - } + set_input_i32("positions", positions); + } break; + case PROJECTOR_TYPE_QWEN25VL: + { + // pw * ph = number of tokens output by ViT after apply patch merger + // ipw * ipw = number of vision token been processed inside ViT + const int merge_ratio = 2; + const int pw = image_size_width / patch_size / merge_ratio; + const int ph = image_size_height / patch_size / merge_ratio; + const int ipw = image_size_width / patch_size; + const int iph = image_size_height / patch_size; + + std::vector idx (ph * pw); + std::vector inv_idx(ph * pw); + + if (use_window_attn) { + const int attn_window_size = 112; + const int grid_window = attn_window_size / patch_size / merge_ratio; + int dst = 0; + // [num_vision_tokens, num_vision_tokens] attention mask tensor + std::vector mask(pow(ipw * iph, 2), std::numeric_limits::lowest()); + int mask_row = 0; + + for (int y = 0; y < ph; y += grid_window) { + for (int x = 0; x < pw; x += grid_window) { + const int win_h = std::min(grid_window, ph - y); + const int win_w = std::min(grid_window, pw - x); + const int dst_0 = dst; + // group all tokens belong to the same window togather (to a continue range) + for (int dy = 0; dy < win_h; dy++) { + for (int dx = 0; dx < win_w; dx++) { + const int src = (y + dy) * pw + (x + dx); + GGML_ASSERT(src < (int)idx.size()); + GGML_ASSERT(dst < (int)inv_idx.size()); + idx [src] = dst; + inv_idx[dst] = src; + dst++; + } + } - struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); - const int mpow = merge_ratio * merge_ratio; - std::vector positions_data(ggml_nelements(positions)); - int * data = positions_data.data(); + for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) { + int row_offset = mask_row * (ipw * iph); + std::fill( + mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio), + mask.begin() + row_offset + (dst * merge_ratio * merge_ratio), + 0.0); + mask_row++; + } + } + } - int ptr = 0; - for (int y = 0; y < iph; y += merge_ratio) - { - for (int x = 0; x < ipw; x += merge_ratio) - { - for (int dy = 0; dy < 2; dy++) { - for (int dx = 0; dx < 2; dx++) { - auto remap = idx[ptr / mpow]; - remap = remap * mpow + (ptr % mpow); - - data[ remap] = y + dy; - data[ num_patches + remap] = x + dx; - data[2 * num_patches + remap] = y + dy; - data[3 * num_patches + remap] = x + dx; - ptr++; + set_input_i32("window_idx", idx); + set_input_i32("inv_window_idx", inv_idx); + set_input_f32("window_mask", mask); + } else { + for (int i = 0; i < ph * pw; i++) { + idx[i] = i; + } + } + + const int mpow = merge_ratio * merge_ratio; + std::vector positions(num_positions * 4); + + int ptr = 0; + for (int y = 0; y < iph; y += merge_ratio) { + for (int x = 0; x < ipw; x += merge_ratio) { + for (int dy = 0; dy < 2; dy++) { + for (int dx = 0; dx < 2; dx++) { + auto remap = idx[ptr / mpow]; + remap = (remap * mpow) + (ptr % mpow); + + positions[ remap] = y + dy; + positions[ num_patches + remap] = x + dx; + positions[2 * num_patches + remap] = y + dy; + positions[3 * num_patches + remap] = x + dx; + ptr++; + } } } } - } - ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions)); - } - else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { - // do nothing - } - else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) { - // do nothing - } - else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) { - // set the 2D positions - int n_patches_per_col = image_size_width / patch_size; - std::vector pos_data(num_positions); - struct ggml_tensor * pos; - // dimension H - pos = ggml_graph_get_tensor(gf, "pos_h"); - for (int i = 0; i < num_positions; i++) { - pos_data[i] = i / n_patches_per_col; - } - ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos)); - // dimension W - pos = ggml_graph_get_tensor(gf, "pos_w"); - for (int i = 0; i < num_positions; i++) { - pos_data[i] = i % n_patches_per_col; - } - ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos)); - } - else { + set_input_i32("positions", positions); + } break; + case PROJECTOR_TYPE_PIXTRAL: + { + // set the 2D positions + int n_patches_per_col = image_size_width / patch_size; + std::vector pos_data(num_positions); + // dimension H + for (int i = 0; i < num_positions; i++) { + pos_data[i] = i / n_patches_per_col; + } + set_input_i32("pos_h", pos_data); + // dimension W + for (int i = 0; i < num_positions; i++) { + pos_data[i] = i % n_patches_per_col; + } + set_input_i32("pos_w", pos_data); + } break; + case PROJECTOR_TYPE_GLM_EDGE: + { // llava and other models - struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); - - int* positions_data = (int*)malloc(ggml_nbytes(positions)); + std::vector positions(num_positions); for (int i = 0; i < num_positions; i++) { - positions_data[i] = i; + positions[i] = i; } - ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); - free(positions_data); + set_input_i32("positions", positions); + } break; + case PROJECTOR_TYPE_MLP: + case PROJECTOR_TYPE_MLP_NORM: + case PROJECTOR_TYPE_LDP: + case PROJECTOR_TYPE_LDPV2: + { + // llava and other models + std::vector positions(num_positions); + for (int i = 0; i < num_positions; i++) { + positions[i] = i; + } + set_input_i32("positions", positions); - if (ctx->proj_type != PROJECTOR_TYPE_GLM_EDGE) { - struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); // The patches vector is used to get rows to index into the embeds with; // we should skip dim 0 only if we have CLS to avoid going out of bounds // when retrieving the rows. int patch_offset = model.class_embedding ? 1 : 0; - int* patches_data = (int*)malloc(ggml_nbytes(patches)); + std::vector patches(num_patches); for (int i = 0; i < num_patches; i++) { - patches_data[i] = i + patch_offset; + patches[i] = i + patch_offset; } - ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches)); - free(patches_data); - } - } - } - - if (use_window_attn && (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL)) { - struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx"); - struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx"); - struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask"); - - const int merge_ratio = 2; - const int attn_window_size = 112; - const int pw = image_size_width / patch_size / merge_ratio; - const int ph = image_size_height / patch_size / merge_ratio; - const int grid_window = attn_window_size / patch_size / merge_ratio; - const int ipw = image_size_width / patch_size; - const int iph = image_size_height / patch_size; - /* - pw * ph = number of tokens output by ViT after apply patch merger - ipw * ipw = number of vision token been processed inside ViT - */ - - std::vector idx(ph * pw); - std::vector inv_idx(ph * pw); - int dst = 0; - // [num_vision_tokens, num_vision_tokens] attention mask tensor - std::vector mask(pow(ipw * iph, 2), std::numeric_limits::lowest()); - int mask_row = 0; - - for (int y = 0; y < ph; y+=grid_window) - { - for (int x = 0; x < pw; x+=grid_window) + set_input_i32("patches", patches); + } break; + case PROJECTOR_TYPE_GEMMA3: + case PROJECTOR_TYPE_IDEFICS3: { - const int win_h = std::min(grid_window, ph - y); - const int win_w = std::min(grid_window, pw - x); - const int dst_0 = dst; - // group all tokens belong to the same window togather (to a continue range) - for (int dy = 0; dy < win_h; dy++) { - for (int dx = 0; dx < win_w; dx++) { - const int src = (y + dy) * pw + (x + dx); - assert(src < (int)idx.size()); - assert(dst < (int)inv_idx.size()); - idx[src] = dst; - inv_idx[dst] = src; - dst++; - } - } - - for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) { - int row_offset = mask_row * (ipw * iph); - std::fill( - mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio), - mask.begin() + row_offset + (dst * merge_ratio * merge_ratio), - 0.0); - mask_row++; - } - } - } - - ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx)); - ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx)); - ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask)); + // do nothing + } break; + default: + GGML_ABORT("Unknown projector type"); } ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads); @@ -3537,7 +3494,7 @@ bool clip_is_glm(const struct clip_ctx * ctx) { } bool clip_is_qwen2vl(const struct clip_ctx * ctx) { - return ctx->proj_type == PROJECTOR_TYPE_QWEN2VL; + return ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL; } bool clip_is_llava(const struct clip_ctx * ctx) {