#include <cinttypes>
#include <limits>
#include <array>
+#include <numeric>
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
std::vector<int32_t> image_grid_pinpoints;
int32_t image_crop_resolution;
std::unordered_set<int32_t> vision_feature_layer;
+ int32_t attn_window_size; // qwen2.5vl: attention window size in pixels
+ int32_t n_wa_pattern;     // qwen2.5vl: every n-th layer uses full attention; 0 = window attention disabled
};
struct clip_layer {
struct ggml_tensor * ff_down_w = nullptr;
struct ggml_tensor * ff_down_b = nullptr;
+ struct ggml_tensor * ff_g_w = nullptr;
+ struct ggml_tensor * ff_g_b = nullptr;
+
// layernorm 2
struct ggml_tensor * ln_2_w = nullptr;
struct ggml_tensor * ln_2_b = nullptr;
float image_std[3];
bool use_gelu = false;
bool use_silu = false;
+ int32_t ftype = 1;
gguf_context_ptr ctx_gguf;
ggml_context_ptr ctx_data;
return gf;
}
+static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
+ const auto & model = ctx->vision_model;
+ const auto & hparams = model.hparams;
+
+ const int image_size_width = imgs.entries[0]->nx;
+ const int image_size_height = imgs.entries[0]->ny;
+
+ const bool use_mrope = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL;
+ const bool use_window_attn = hparams.n_wa_pattern > 0;
+
+ const int n_wa_pattern = hparams.n_wa_pattern;
+ const int patch_size = hparams.patch_size;
+ const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
+ const int patches_w = image_size_width / patch_size;
+ const int patches_h = image_size_height / patch_size;
+ const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
+ const int num_position_ids = use_mrope ? num_positions * 4 : num_positions;
+ const int hidden_size = hparams.hidden_size;
+ const int n_head = hparams.n_head;
+ const int d_head = hidden_size / n_head;
+ const float eps = hparams.eps;
+
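+ // M-RoPE: each head's rotary dims are split into four equal sections, one per
+ // position-id channel (the "positions" tensor below carries 4 ids per token)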
+ int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
+
+ const int batch_size = imgs.entries.size();
+ GGML_ASSERT(batch_size == 1);
+
+ struct ggml_init_params params = {
+ /*.mem_size =*/ ctx->buf_compute_meta.size(),
+ /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
+ /*.no_alloc =*/ true,
+ };
+
+ ggml_context_ptr ctx0_ptr(ggml_init(params));
+ auto ctx0 = ctx0_ptr.get();
+
+ struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+ struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size);
+ ggml_set_name(inp_raw, "inp_raw");
+ ggml_set_input(inp_raw);
+
+ struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+
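+ // the 2x2 patch merger requires both image dimensions to be multiples of 2 * patch_size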
+ GGML_ASSERT(image_size_width % (patch_size * 2) == 0);
+ GGML_ASSERT(image_size_height % (patch_size * 2) == 0);
+
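+ // second temporal frame: the checkpoint's Conv3D (temporal_patch_size = 2) is stored
+ // as two Conv2Ds whose outputs are summed (see the conversion script below)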
+ auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+ inp = ggml_add(ctx0, inp, inp_1);
+
+ inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
+ inp = ggml_reshape_4d(
+ ctx0, inp,
+ hidden_size * 2, patches_w / 2, patches_h, batch_size);
+ inp = ggml_reshape_4d(
+ ctx0, inp,
+ hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
+ inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
+ inp = ggml_reshape_3d(
+ ctx0, inp,
+ hidden_size, patches_w * patches_h, batch_size);
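+ // the reshape/permute sequence above reorders tokens so that each 2x2 merger window
+ // occupies 4 consecutive rows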
+
+ if (model.patch_bias) {
+ // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
+ inp = ggml_add(ctx0, inp, model.patch_bias);
+ }
+ struct ggml_tensor * embeddings = inp;
+ struct ggml_tensor * window_mask = nullptr;
+ struct ggml_tensor * window_idx = nullptr;
+ struct ggml_tensor * inv_window_idx = nullptr;
+
+ struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
+ ggml_set_name(positions, "positions");
+ ggml_set_input(positions);
+
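+ // Qwen2.5VL uses weight-only RMSNorm in place of LayerNorm throughout the vision tower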
+ // pre-layernorm
+ if (model.pre_ln_w) {
+ embeddings = ggml_rms_norm(ctx0, embeddings, eps);
+ ggml_set_name(embeddings, "pre_ln");
+
+ embeddings = ggml_mul(ctx0, embeddings, model.pre_ln_w);
+ }
+
+ if (use_window_attn) {
+ // handle window attention inputs
+ inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
+ ggml_set_name(inv_window_idx, "inv_window_idx");
+ ggml_set_input(inv_window_idx);
+ // mask for window attention
+ window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, num_positions, num_positions);
+ ggml_set_name(window_mask, "window_mask");
+ ggml_set_input(window_mask);
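+ // filled at encode time: 0 for token pairs inside the same window, -inf elsewhere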
+
+ // embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
+ GGML_ASSERT(batch_size == 1);
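+ // shuffle whole 2x2 merger groups into window order so that every attention window
+ // becomes a contiguous range of tokens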
+ embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4);
+ embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx);
+ embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size);
+ }
+
+ // loop over layers
+ for (int il = 0; il < ctx->max_feature_layer; il++) {
+ struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
+
+ // rmsnorm1
+ cur = ggml_rms_norm(ctx0, cur, eps);
+ cur = ggml_mul(ctx0, cur, model.layers[il].ln_1_w);
+
+ // self-attention
+ {
+
+ struct ggml_tensor * Q =
+ ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
+
+ Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
+ Q = ggml_rope_multi(
+ ctx0, Q, positions, nullptr,
+ d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+ Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+ Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
+
+ struct ggml_tensor * K =
+ ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
+
+ K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+ K = ggml_rope_multi(
+ ctx0, K, positions, nullptr,
+ d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+ K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+ K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
+
+ struct ggml_tensor * V =
+ ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
+
+ V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
+ V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+ V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
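+ // every n_wa_pattern-th layer attends globally; the others are restricted
+ // to their window by window_mask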
+ const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true;
+ if (full_attn) {
+ KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
+ } else {
+ KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f / sqrtf((float)d_head), 0.0f);
+ }
+
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+ KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
+ KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+ cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
+ }
+
+ // attention output
+ cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
+
+ // re-add the layer input, e.g., residual
+ cur = ggml_add(ctx0, cur, embeddings);
+
+ embeddings = cur; // embeddings = residual, cur = hidden_states
+
+ // rms norm2
+ cur = ggml_rms_norm(ctx0, cur, eps);
+ cur = ggml_mul(ctx0, cur, model.layers[il].ln_2_w);
+
+ // mlp
+ // ffn_up
+ auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
+ cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b);
+
+ auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur);
+ cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b);
+ // TODO: only 2 of these 3 activation branches are actually used; should we remove one of them?
+ if (ctx->use_gelu) {
+ cur_gate = ggml_gelu_inplace(ctx0, cur_gate);
+ } else if (ctx->use_silu) {
+ cur_gate = ggml_silu_inplace(ctx0, cur_gate);
+ } else {
+ cur_gate = ggml_gelu_quick_inplace(ctx0, cur_gate);
+ }
+ cur = ggml_mul(ctx0, cur_gate, cur_up);
+
+ // ffn_down
+ cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
+ cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+
+ // residual 2
+ cur = ggml_add(ctx0, embeddings, cur);
+
+ embeddings = cur;
+ }
+
+ // post-layernorm
+ if (model.post_ln_w) {
+ embeddings = ggml_rms_norm(ctx0, embeddings, eps);
+ ggml_set_name(embeddings, "post_ln");
+
+ embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w);
+ }
+
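+ // patch merger: concatenate each group of 4 consecutive tokens (one 2x2 window)
+ // before the two-layer MLP projector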
+ embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
+
+ embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+ embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+
+ // GELU activation
+ embeddings = ggml_gelu(ctx0, embeddings);
+
+ // Second linear layer
+ embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
+ embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+
+ if (use_window_attn) {
+ window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
+ ggml_set_name(window_idx, "window_idx");
+ ggml_set_input(window_idx);
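+ // undo the input-side shuffle: window_idx gathers tokens back into raster order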
+
+ // embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
+ GGML_ASSERT(batch_size == 1);
+ embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4);
+ embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
+ embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4, batch_size);
+ }
+
+ // build the graph
+ ggml_build_forward_expand(gf, embeddings);
+
+ return gf;
+}
+
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
GGML_ASSERT(imgs.entries.size() == 1);
res = clip_image_build_graph_pixtral(ctx, *imgs.entries[0]);
} break;
+ case PROJECTOR_TYPE_QWEN25VL:
+ {
+ res = clip_image_build_graph_qwen25vl(ctx, imgs);
+ } break;
default:
{
// TODO: we should have one build_* function per model
{
hparams.rope_theta = 10000.0f;
} break;
+ case PROJECTOR_TYPE_QWEN25VL:
+ {
+ get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
+ } break;
default:
break;
}
// legacy naming (the in and out is reversed! don't ask me why)
layer.ff_i_w = layer.ff_down_w;
layer.ff_o_w = layer.ff_up_w;
+ layer.ff_g_w = layer.ff_gate_w;
layer.ff_i_b = layer.ff_down_b;
layer.ff_o_b = layer.ff_up_b;
+ layer.ff_g_b = layer.ff_gate_b;
}
switch (ctx_clip.proj_type) {
vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight"));
} break;
case PROJECTOR_TYPE_QWEN2VL:
+ case PROJECTOR_TYPE_QWEN25VL:
{
vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
else {
GGML_ABORT("Unknown minicpmv version");
}
- } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
+ } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
int patch_size = params.patch_size * 2;
int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
const int pos_w = ctx->load_image_size.width / patch_size;
const int pos_h = ctx->load_image_size.height / patch_size;
+ const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl
+
{
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
std::vector<float> inp_data(ggml_nelements(inp_raw));
// non-minicpmv models
- if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
+ if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
- struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+ // pw * ph = number of tokens output by the ViT after the patch merger
+ // ipw * iph = number of vision tokens processed inside the ViT
+ const int merge_ratio = 2;
+ const int pw = image_size_width / patch_size / merge_ratio;
+ const int ph = image_size_height / patch_size / merge_ratio;
+ const int ipw = image_size_width / patch_size;
+ const int iph = image_size_height / patch_size;
+
+ std::vector<int> idx (ph * pw);
+ std::vector<int> inv_idx(ph * pw);
+
+ if (use_window_attn) {
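+ // Qwen2.5VL uses a fixed 112x112-pixel attention window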
+ const int attn_window_size = 112;
+ struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
+ struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
+ struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
+
+ const int grid_window = attn_window_size / patch_size / merge_ratio;
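+ // e.g. 112 / 14 / 2 = 4 merged tokens per window side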
+ int dst = 0;
+ // [num_vision_tokens, num_vision_tokens] attention mask tensor
+ std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
+ int mask_row = 0;
+
+ for (int y = 0; y < ph; y += grid_window)
+ {
+ for (int x = 0; x < pw; x += grid_window)
+ {
+ const int win_h = std::min(grid_window, ph - y);
+ const int win_w = std::min(grid_window, pw - x);
+ const int dst_0 = dst;
+ // group all tokens belonging to the same window together (into one contiguous range)
+ for (int dy = 0; dy < win_h; dy++) {
+ for (int dx = 0; dx < win_w; dx++) {
+ const int src = (y + dy) * pw + (x + dx);
+ assert(src < (int)idx.size());
+ assert(dst < (int)inv_idx.size());
+ idx [src] = dst;
+ inv_idx[dst] = src;
+ dst++;
+ }
+ }
- const int pw = image_size_width / patch_size;
- const int ph = image_size_height / patch_size;
- int* positions_data = (int*)malloc(ggml_nbytes(positions));
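+ // each merged token spans merge_ratio^2 ViT tokens; for every row this window owns,
+ // unmask the window's contiguous column range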
+ for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
+ int row_offset = mask_row * (ipw * iph);
+ std::fill(
+ mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
+ mask.begin() + row_offset + (dst * merge_ratio * merge_ratio),
+ 0.0f);
+ mask_row++;
+ }
+ }
+ }
+
+ ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
+ ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
+ ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
+ } else {
+ std::iota(idx.begin(), idx.end(), 0);
+ std::iota(inv_idx.begin(), inv_idx.end(), 0);
+ }
+
+ struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+ const int mpow = merge_ratio * merge_ratio;
+ std::vector<int> positions_data(ggml_nelements(positions));
+ int * data = positions_data.data();
int ptr = 0;
- for (int y = 0; y < ph; y+=2)
+ for (int y = 0; y < iph; y += merge_ratio)
{
- for (int x = 0; x < pw; x+=2)
+ for (int x = 0; x < ipw; x += merge_ratio)
{
for (int dy = 0; dy < 2; dy++) {
for (int dx = 0; dx < 2; dx++) {
- positions_data[ptr] = y + dy;
- positions_data[num_patches + ptr] = x + dx;
- positions_data[num_patches * 2 + ptr] = y + dy;
- positions_data[num_patches * 3 + ptr] = x + dx;
+ auto remap = idx[ptr / mpow];
+ remap = remap * mpow + (ptr % mpow);
+
+ data[ remap] = y + dy;
+ data[ num_patches + remap] = x + dx;
+ data[2 * num_patches + remap] = y + dy;
+ data[3 * num_patches + remap] = x + dx;
ptr++;
}
}
}
}
- ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
- free(positions_data);
+ ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions));
}
else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
// do nothing
}
}
ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);
auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf);
case PROJECTOR_TYPE_GLM_EDGE:
return ctx->vision_model.mm_model_mlp_3_w->ne[1];
case PROJECTOR_TYPE_QWEN2VL:
+ case PROJECTOR_TYPE_QWEN25VL:
return ctx->vision_model.mm_1_b->ne[0];
case PROJECTOR_TYPE_GEMMA3:
return ctx->vision_model.mm_input_proj_w->ne[0];
import argparse
-from typing import Dict
+from typing import Dict, List, Optional
import torch
import numpy as np
from gguf import *
from transformers import (
- Qwen2VLForConditionalGeneration,
- Qwen2VLProcessor,
AutoProcessor,
- Qwen2VLConfig
+ Qwen2VLConfig,
+ Qwen2VLProcessor,
+ Qwen2VLForConditionalGeneration,
+ Qwen2_5_VLConfig, # type: ignore[reportAttributeAccessIssue]
+ Qwen2_5_VLForConditionalGeneration, # type: ignore[reportAttributeAccessIssue]
)
return raw_key.format(arch=arch)
-def to_gguf_name(name: str) -> str:
- og = name
- name = name.replace("text_model", "t").replace("vision_model", "v")
- name = name.replace("blocks", "blk").replace("embeddings.", "")
- name = name.replace("attn.", "attn_")
- name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.")
- # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln")
- name = name.replace("norm1", "ln1").replace("norm2", "ln2")
- name = name.replace("merger.mlp", 'mm')
- print(f"[to_gguf_name] {og} --> {name}")
- return name
-
-
-def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
- vision_model = qwen2vl.visual
- tensor_map = {}
- for name, ten in vision_model.state_dict().items():
- ten = ten.numpy()
- if 'qkv' in name:
- if ten.ndim == 2: # weight
- c3, _ = ten.shape
- else: # bias
- c3 = ten.shape[0]
- assert c3 % 3 == 0
- c = c3 // 3
- wq = ten[:c]
- wk = ten[c: c * 2]
- wv = ten[c * 2:]
- tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq
- tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk
- tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv
- elif 'merger' in name:
- if name.endswith("ln_q.weight"):
- tensor_map['v.post_ln.weight'] = ten
- elif name.endswith("ln_q.bias"):
- tensor_map['v.post_ln.bias'] = ten
+def get_n_wa_pattern(fullatt_block_indexes: Optional[List[int]]):
+ if fullatt_block_indexes is None:
+ return 0
+ n_wa = fullatt_block_indexes[0]
+ for a, b in zip(fullatt_block_indexes, fullatt_block_indexes[1:]):
+ if b - a - 1 != n_wa:
+ raise ValueError(
+ f"window/full attention layers should follow a fixed pattern: "
+ f"{n_wa} window-attention layers before each full-attention layer"
+ )
+ return n_wa + 1
+
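+ # Example: Qwen2.5-VL checkpoints ship fullatt_block_indexes = [7, 15, 23, 31],
+ # i.e. 7 window-attention layers before every full-attention layer, so
+ # get_n_wa_pattern returns 8 (every 8th block uses full attention).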
+
+class VL2:
+
+ @staticmethod
+ def to_gguf_name(name: str) -> str:
+ og = name
+ name = name.replace("text_model", "t").replace("vision_model", "v")
+ name = name.replace("blocks", "blk").replace("embeddings.", "")
+ name = name.replace("attn.", "attn_")
+ name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.")
+ # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln")
+ name = name.replace("norm1", "ln1").replace("norm2", "ln2")
+ name = name.replace("merger.mlp", 'mm')
+ print(f"[to_gguf_name] {og} --> {name}")
+ return name
+
+ @classmethod
+ def find_vision_tensors(cls, qwen2vl, dtype) -> Dict[str, np.ndarray]:
+ vision_model = qwen2vl.visual
+ tensor_map = {}
+ for name, ten in vision_model.state_dict().items():
+ ten = ten.numpy()
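+ # the checkpoint stores attention q/k/v fused into one tensor; split it for GGUF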
+ if 'qkv' in name:
+ if ten.ndim == 2: # weight
+ c3, _ = ten.shape
+ else: # bias
+ c3 = ten.shape[0]
+ assert c3 % 3 == 0
+ c = c3 // 3
+ wq = ten[:c]
+ wk = ten[c: c * 2]
+ wv = ten[c * 2:]
+ tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq
+ tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk
+ tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv
+ elif 'merger' in name:
+ if name.endswith("ln_q.weight"):
+ tensor_map['v.post_ln.weight'] = ten
+ elif name.endswith("ln_q.bias"):
+ tensor_map['v.post_ln.bias'] = ten
+ else:
+ # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias"
+ tensor_map[cls.to_gguf_name(name)] = ten
+ elif 'patch_embed.proj.weight' in name:
+ # NOTE: split Conv3D into Conv2Ds
+ c1, c2, kt, kh, kw = ten.shape
+ assert kt == 2, "Current implementation only supports temporal_patch_size of 2"
+ tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...]
+ tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
else:
- # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias"
- tensor_map[to_gguf_name(name)] = ten
- elif 'patch_embed.proj.weight' in name:
- # NOTE: split Conv3D into Conv2Ds
- c1, c2, kt, kh, kw = ten.shape
- assert kt == 2, "Current implmentation only support temporal_patch_size of 2"
- tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...]
- tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
- else:
- tensor_map[to_gguf_name(f"vision_model.{name}")] = ten
-
- for new_name, ten in tensor_map.items():
- if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
- tensor_map[new_name] = ten.astype(np.float32)
- else:
- tensor_map[new_name] = ten.astype(dtype)
- tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder
- return tensor_map
+ tensor_map[cls.to_gguf_name(f"vision_model.{name}")] = ten
+
+ for new_name, ten in tensor_map.items():
+ if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
+ tensor_map[new_name] = ten.astype(np.float32)
+ else:
+ tensor_map[new_name] = ten.astype(dtype)
+ tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder
+ return tensor_map
+
+
+class VL25(VL2):
+
+ @staticmethod
+ def to_gguf_name(name: str) -> str:
+ og = name
+ name = name.replace("text_model", "t").replace("vision_model", "v")
+ name = name.replace("blocks", "blk").replace("embeddings.", "")
+ name = name.replace("attn.", "attn_")
+ name = name.replace("mlp.down_proj", "ffn_down").replace("mlp.up_proj", "ffn_up")
+ name = name.replace("mlp.gate_proj", "ffn_gate").replace("proj.", "out.")
+ name = name.replace("norm1", "ln1").replace("norm2", "ln2")
+ name = name.replace("merger.mlp", 'mm')
+ print(f"[vl25][to_gguf_name] {og} --> {name}")
+ return name
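+
+ # e.g. VL25.to_gguf_name("vision_model.blocks.0.mlp.gate_proj.weight")
+ # yields "v.blk.0.ffn_gate.weight"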
def main(args):
np_dtype = np.float32
ftype = 0
elif args.data_type == 'fp16':
- dtype = torch.float32
+ dtype = torch.float16
np_dtype = np.float16
ftype = 1
else:
model_path = ""
model_name = args.model_name
print("model_name: ", model_name)
- qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
- model_name, torch_dtype=dtype, device_map="cpu"
- )
- cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType]
- vcfg = cfg.vision_config
+ if args.model_type == "qwen2vl":
+ qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
+ model_name, torch_dtype=dtype, device_map="cpu"
+ )
+ cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType]
+ vcfg = cfg.vision_config
+ else:
+ qwen2vl = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ model_name, torch_dtype=dtype, device_map="cpu"
+ )
+ cfg: Qwen2_5_VLConfig = qwen2vl.config # type: ignore[reportAssignmentType]
+ vcfg = cfg.vision_config
if os.path.isdir(model_name):
local_model = True
fout.add_bool("clip.has_text_encoder", False)
fout.add_bool("clip.has_vision_encoder", True)
fout.add_bool("clip.has_qwen2vl_merger", True)
- fout.add_string("clip.projector_type", "qwen2vl_merger")
print(cfg.vision_config)
if 'silu' in cfg.vision_config.hidden_act.lower():
else:
raise ValueError()
- tensor_map = find_vision_tensors(qwen2vl, np_dtype)
+ if args.model_type == "qwen2.5vl":
+ fout.add_uint32("clip.vision.n_wa_pattern", get_n_wa_pattern(vcfg.fullatt_block_indexes))
+ fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size)
+ fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size)
+ fout.add_string("clip.projector_type", "qwen2.5vl_merger")
+ else:
+ fout.add_string("clip.projector_type", "qwen2vl_merger")
+ fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
+ fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
+
+ if args.model_type == "qwen2.5vl":
+ tensor_map = VL25.find_vision_tensors(qwen2vl, np_dtype)
+ else:
+ tensor_map = VL2.find_vision_tensors(qwen2vl, np_dtype)
for name, data in tensor_map.items():
fout.add_tensor(name, data)
fout.add_uint32("clip.vision.patch_size", vcfg.patch_size)
fout.add_uint32("clip.vision.image_size", 14 * 40) # some reasonable size that is divable by (14*2)
- fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
- fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads)
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
+ parser.add_argument("--model_type", nargs='?', choices=['qwen2vl', 'qwen2.5vl'], default="qwen2vl")
parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32")
args = parser.parse_args()
main(args)