git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
clip : Add Qwen2.5VL support (#12402)
author HimariO <redacted>
Sun, 27 Apr 2025 08:10:34 +0000 (16:10 +0800)
committer GitHub <redacted>
Sun, 27 Apr 2025 08:10:34 +0000 (10:10 +0200)
* implement vision model architecture, gguf converter

* handle window attention inputs

* add debug utils

* fix a few incorrect tensor memory layouts

* move position id remap out of ggml to avoid int32 CUDA operations

* cleaning up

* ignore transformers Qwen2_5_xxx type check

* remove rarely used `qwen2vl-cli` debug functions

* remove commented-out code blocks

* fix attn weight scaling after rebase

* add `PROJECTOR_TYPE_QWEN2_5_VL`

* remove `KEY_USE_GLU_MLP`, `KEY_USE_RMS_NORM`

* replace `KEY_FULLATTN_BLK_IDX` with `KEY_WIN_ATTN_PATTERN`

* remove `attn_window_size` from gguf

* fix model conversion

* clean up

* fix merging problem

* add test

---------

Co-authored-by: Xuan Son Nguyen <redacted>
convert_hf_to_gguf.py
examples/llava/clip-impl.h
examples/llava/clip.cpp
examples/llava/qwen2_vl_surgery.py
examples/llava/qwen2vl-cli.cpp
examples/llava/tests.sh

index cf35fb86ecfec4df3d154e97e26623946e1de129..ea3a951b93753d2372f9b15ef22af2eb2074beda 100755 (executable)
@@ -2554,11 +2554,12 @@ class Qwen2VLModel(TextModel):
         except FileNotFoundError:
             self._set_vocab_gpt2()
 
-    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
-        for name, data in super().get_tensors():
-            if name.startswith("visual."):
-                continue
-            yield name, data
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+        if name.startswith("visual."):
+            # skip visual tensors
+            return []
+        return [(self.map_tensor_name(name), data_torch)]
 
 
 @ModelBase.register("WavTokenizerDec")
index 16d0a8efc56ae0c0cf88ca0bae89c6521c09b90d..04bfcbb5e575f745bdba2957356008101ab215f0 100644 (file)
 #define KEY_PROJ_SCALE_FACTOR   "clip.vision.projector.scale_factor"
 #define KEY_PROJ_TYPE           "clip.projector_type"
 
+#define KEY_USE_GLU_MLP         "clip.use_glu_mlp"  // for qwen2.5vl
+#define KEY_USE_RMS_NORM        "clip.use_rms_norm" // for qwen2.5vl
+
 #define KEY_MM_PATCH_MERGE_TYPE   "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS  "clip.vision.image_grid_pinpoints"
 #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
+#define KEY_WIN_ATTN_PATTERN      "clip.vision.n_wa_pattern"
+#define KEY_ATTN_WINDOW_SIZE      "clip.vision.window_size"
 
 
 //
@@ -55,6 +60,7 @@
 #define TN_FFN_DOWN        "%s.blk.%d.ffn_down.%s"
 #define TN_FFN_GATE        "%s.blk.%d.ffn_gate.%s"
 #define TN_FFN_UP          "%s.blk.%d.ffn_up.%s"
 #define TN_LN_1            "%s.blk.%d.ln1.%s"
 #define TN_LN_2            "%s.blk.%d.ln2.%s"
 #define TN_LN_PRE          "%s.pre_ln.%s"
@@ -95,6 +101,7 @@ enum projector_type {
     PROJECTOR_TYPE_GEMMA3,
     PROJECTOR_TYPE_IDEFICS3,
     PROJECTOR_TYPE_PIXTRAL,
+    PROJECTOR_TYPE_QWEN25VL,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -105,6 +112,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_MINICPMV,  "resampler"},
     { PROJECTOR_TYPE_GLM_EDGE,  "adapter"},
     { PROJECTOR_TYPE_QWEN2VL,   "qwen2vl_merger"},
+    { PROJECTOR_TYPE_QWEN25VL,  "qwen2.5vl_merger"},
     { PROJECTOR_TYPE_GEMMA3,    "gemma3"},
     { PROJECTOR_TYPE_IDEFICS3,  "idefics3"},
     { PROJECTOR_TYPE_PIXTRAL,   "pixtral"},
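
A note on the new keys: `KEY_WIN_ATTN_PATTERN` replaces the dropped per-layer `KEY_FULLATTN_BLK_IDX` list. Instead of storing which blocks use full attention, the GGUF stores the period of the repeating window/full schedule. A minimal sketch of what that value means, mirroring the `(il + 1) % n_wa_pattern == 0` test used in clip.cpp below (the helper itself is hypothetical, not part of the patch; the example value of 8 assumes Qwen2.5-VL's reported fullatt_block_indexes of [7, 15, 23, 31]):

    // Hypothetical helper: true when 0-based layer il runs full attention.
    static bool is_full_attn_layer(int il, int n_wa_pattern) {
        if (n_wa_pattern <= 0) {
            return true; // window attention disabled: every layer is full attention
        }
        // e.g. n_wa_pattern = 8 -> full attention at il = 7, 15, 23, 31, ...
        return (il + 1) % n_wa_pattern == 0;
    }
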
index e8c01c68a9779e8a6ac85420b8be5a54d6d7246d..b6a1f40e8a580fba18b4e3c5bcbbac00f8b9b3d0 100644 (file)
@@ -28,6 +28,7 @@
 #include <cinttypes>
 #include <limits>
 #include <array>
+#include <numeric>
 
 struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
 
@@ -169,6 +170,8 @@ struct clip_hparams {
     std::vector<int32_t> image_grid_pinpoints;
     int32_t image_crop_resolution;
     std::unordered_set<int32_t> vision_feature_layer;
+    int32_t attn_window_size;
+    int32_t n_wa_pattern;
 };
 
 struct clip_layer {
@@ -200,6 +203,9 @@ struct clip_layer {
     struct ggml_tensor * ff_down_w = nullptr;
     struct ggml_tensor * ff_down_b = nullptr;
 
+    struct ggml_tensor * ff_g_w = nullptr;
+    struct ggml_tensor * ff_g_b = nullptr;
+
     // layernorm 2
     struct ggml_tensor * ln_2_w = nullptr;
     struct ggml_tensor * ln_2_b = nullptr;
@@ -319,6 +325,7 @@ struct clip_ctx {
     float image_std[3];
     bool use_gelu = false;
     bool use_silu = false;
+    int32_t ftype = 1;
 
     gguf_context_ptr ctx_gguf;
     ggml_context_ptr ctx_data;
@@ -762,6 +769,236 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
     return gf;
 }
 
+static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
+    const auto & model = ctx->vision_model;
+    const auto & hparams = model.hparams;
+
+    const int image_size_width  = imgs.entries[0]->nx;
+    const int image_size_height = imgs.entries[0]->ny;
+
+    const bool use_mrope       = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL;
+    const bool use_window_attn = hparams.n_wa_pattern > 0;
+
+    const int n_wa_pattern         = hparams.n_wa_pattern;
+    const int patch_size           = hparams.patch_size;
+    const int num_patches          = ((image_size_width / patch_size) * (image_size_height / patch_size));
+    const int patches_w            = image_size_width / patch_size;
+    const int patches_h            = image_size_height / patch_size;
+    const int num_positions        = num_patches + (model.class_embedding ? 1 : 0);
+    const int num_position_ids     = use_mrope ? num_positions * 4 : num_positions;
+    const int hidden_size          = hparams.hidden_size;
+    const int n_head               = hparams.n_head;
+    const int d_head               = hidden_size / n_head;
+    const float eps                = hparams.eps;
+
+    int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
+
+    const int batch_size = imgs.entries.size();
+    GGML_ASSERT(batch_size == 1);
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ ctx->buf_compute_meta.size(),
+        /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    ggml_context_ptr ctx0_ptr(ggml_init(params));
+    auto ctx0 = ctx0_ptr.get();
+
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size);
+    ggml_set_name(inp_raw, "inp_raw");
+    ggml_set_input(inp_raw);
+
+    struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+
+    GGML_ASSERT(image_size_width  % (patch_size * 2) == 0);
+    GGML_ASSERT(image_size_height % (patch_size * 2) == 0);
+
+    auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+    inp = ggml_add(ctx0, inp, inp_1);
+
+    inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3));  // [w, h, c, b] -> [c, w, h, b]
+    inp = ggml_reshape_4d(
+        ctx0, inp,
+        hidden_size * 2, patches_w / 2, patches_h, batch_size);
+    inp = ggml_reshape_4d(
+        ctx0, inp,
+        hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
+    inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
+    inp = ggml_reshape_3d(
+        ctx0, inp,
+        hidden_size, patches_w * patches_h, batch_size);
+
+    if (model.patch_bias) {
+        // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
+        inp = ggml_add(ctx0, inp, model.patch_bias);
+    }
+    struct ggml_tensor * embeddings     = inp;
+    struct ggml_tensor * window_mask    = nullptr;
+    struct ggml_tensor * window_idx     = nullptr;
+    struct ggml_tensor * inv_window_idx = nullptr;
+
+    struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
+    ggml_set_name(positions, "positions");
+    ggml_set_input(positions);
+
+    // pre-layernorm
+    if (model.pre_ln_w) {
+        embeddings = ggml_rms_norm(ctx0, embeddings, eps);
+        ggml_set_name(embeddings, "pre_ln");
+
+        embeddings = ggml_mul(ctx0, embeddings, model.pre_ln_w);
+    }
+
+    if (use_window_attn) {
+        // handle window attention inputs
+        inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
+        ggml_set_name(inv_window_idx, "inv_window_idx");
+        ggml_set_input(inv_window_idx);
+        // mask for window attention
+        window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, num_positions, num_positions);
+        ggml_set_name(window_mask, "window_mask");
+        ggml_set_input(window_mask);
+
+        // embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
+        GGML_ASSERT(batch_size == 1);
+        embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4);
+        embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx);
+        embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size);
+    }
+
+    // loop over layers
+    for (int il = 0; il < ctx->max_feature_layer; il++) {
+        struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
+
+        // rmsnorm1
+        cur = ggml_rms_norm(ctx0, cur, eps);
+        cur = ggml_mul(ctx0, cur, model.layers[il].ln_1_w);
+
+        // self-attention
+        {
+
+            struct ggml_tensor * Q =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
+
+            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
+            Q = ggml_rope_multi(
+                ctx0, Q, positions, nullptr,
+                d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+            Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
+
+            struct ggml_tensor * K =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
+
+            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+            K = ggml_rope_multi(
+                ctx0, K, positions, nullptr,
+                d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
+
+            struct ggml_tensor * V =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
+
+            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
+            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true;
+            if (full_attn) {
+                KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
+            } else {
+                KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f / sqrtf((float)d_head), 0.0f);
+            }
+
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+            KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
+            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
+        }
+
+        // attention output
+        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
+
+        // re-add the layer input, i.e. the residual
+        cur = ggml_add(ctx0, cur, embeddings);
+
+        embeddings = cur; // embeddings = residual, cur = hidden_states
+
+        // rms norm2
+        cur = ggml_rms_norm(ctx0, cur, eps);
+        cur = ggml_mul(ctx0, cur, model.layers[il].ln_2_w);
+
+        // mlp
+        // ffn_up
+        auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
+        cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b);
+
+        auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur);
+        cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b);
+        // TODO : only 2 of these 3 are actually used, should we remove one of them?
+        if (ctx->use_gelu) {
+            cur_gate = ggml_gelu_inplace(ctx0, cur_gate);
+        } else if (ctx->use_silu) {
+            cur_gate = ggml_silu_inplace(ctx0, cur_gate);
+        } else {
+            cur_gate = ggml_gelu_quick_inplace(ctx0, cur_gate);
+        }
+        cur = ggml_mul(ctx0, cur_gate, cur_up);
+
+        // ffn_down
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+
+        // residual 2
+        cur = ggml_add(ctx0, embeddings, cur);
+
+        embeddings = cur;
+    }
+
+    // post-layernorm
+    if (model.post_ln_w) {
+        embeddings = ggml_rms_norm(ctx0, embeddings, eps);
+        ggml_set_name(embeddings, "post_ln");
+
+        embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w);
+    }
+
+    embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
+
+    embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+    embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+
+    // GELU activation
+    embeddings = ggml_gelu(ctx0, embeddings);
+
+    // Second linear layer
+    embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
+    embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+
+    if (use_window_attn) {
+        window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
+        ggml_set_name(window_idx, "window_idx");
+        ggml_set_input(window_idx);
+
+        // embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
+        GGML_ASSERT(batch_size == 1);
+        embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4);
+        embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
+        embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4, batch_size);
+    }
+
+    // build the graph
+    ggml_build_forward_expand(gf, embeddings);
+
+    return gf;
+}
+
 static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
     const auto & model = ctx->vision_model;
     const auto & hparams = model.hparams;
@@ -1331,6 +1568,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
                 GGML_ASSERT(imgs.entries.size() == 1);
                 res = clip_image_build_graph_pixtral(ctx, *imgs.entries[0]);
             } break;
+        case PROJECTOR_TYPE_QWEN25VL:
+            {
+                res = clip_image_build_graph_qwen25vl(ctx, imgs);
+            } break;
         default:
             {
                 // TODO: we should have one build_* function per model
@@ -1507,6 +1748,10 @@ struct clip_model_loader {
                     {
                         hparams.rope_theta = 10000.0f;
                     } break;
+                case PROJECTOR_TYPE_QWEN25VL:
+                    {
+                        get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
+                    } break;
                 default:
                     break;
             }
@@ -1600,8 +1845,10 @@ struct clip_model_loader {
             // legacy naming (the in and out is reversed! don't ask me why)
             layer.ff_i_w = layer.ff_down_w;
             layer.ff_o_w = layer.ff_up_w;
+            layer.ff_g_w = layer.ff_gate_w;
             layer.ff_i_b = layer.ff_down_b;
             layer.ff_o_b = layer.ff_up_b;
+            layer.ff_g_b = layer.ff_gate_b;
         }
 
         switch (ctx_clip.proj_type) {
@@ -1700,6 +1947,7 @@ struct clip_model_loader {
                     vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight"));
                 } break;
             case PROJECTOR_TYPE_QWEN2VL:
+            case PROJECTOR_TYPE_QWEN25VL:
                 {
                     vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
                     vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
@@ -2651,7 +2899,7 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
         else {
             GGML_ABORT("Unknown minicpmv version");
         }
-    } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
+    } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
         int patch_size = params.patch_size * 2;
         int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
         int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
@@ -2792,6 +3040,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     const int pos_w = ctx->load_image_size.width / patch_size;
     const int pos_h = ctx->load_image_size.height / patch_size;
 
+    const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl
+
     {
         struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
         std::vector<float> inp_data(ggml_nelements(inp_raw));
@@ -2890,31 +3140,93 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         // non-minicpmv models
 
         if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
-            struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+            // pw * ph   = number of tokens output by the ViT after the patch merger is applied
+            // ipw * iph = number of vision tokens processed inside the ViT
+            const int merge_ratio = 2;
+            const int pw  = image_size_width  / patch_size / merge_ratio;
+            const int ph  = image_size_height / patch_size / merge_ratio;
+            const int ipw = image_size_width  / patch_size;
+            const int iph = image_size_height / patch_size;
+
+            std::vector<int> idx    (ph * pw);
+            std::vector<int> inv_idx(ph * pw);
+
+            if (use_window_attn) {
+                const int attn_window_size = 112;
+                struct ggml_tensor * window_idx     = ggml_graph_get_tensor(gf, "window_idx");
+                struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
+                struct ggml_tensor * window_mask    = ggml_graph_get_tensor(gf, "window_mask");
+
+                const int grid_window = attn_window_size / patch_size / merge_ratio;
+                int dst = 0;
+                // [num_vision_tokens, num_vision_tokens] attention mask tensor
+                std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
+                int mask_row = 0;
+
+                for (int y = 0; y < ph; y += grid_window)
+                {
+                    for (int x = 0; x < pw; x += grid_window)
+                    {
+                        const int win_h = std::min(grid_window, ph - y);
+                        const int win_w = std::min(grid_window, pw - x);
+                        const int dst_0 = dst;
+                        // group all tokens belonging to the same window together (into a contiguous range)
+                        for (int dy = 0; dy < win_h; dy++) {
+                            for (int dx = 0; dx < win_w; dx++) {
+                                const int src = (y + dy) * pw + (x + dx);
+                                assert(src < (int)idx.size());
+                                assert(dst < (int)inv_idx.size());
+                                idx    [src] = dst;
+                                inv_idx[dst] = src;
+                                dst++;
+                            }
+                        }
 
-            const int pw = image_size_width / patch_size;
-            const int ph = image_size_height / patch_size;
-            int* positions_data = (int*)malloc(ggml_nbytes(positions));
+                        for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
+                            int row_offset = mask_row * (ipw * iph);
+                            std::fill(
+                                mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
+                                mask.begin() + row_offset + (dst   * merge_ratio * merge_ratio),
+                                0.0);
+                            mask_row++;
+                        }
+                    }
+                }
+
+                ggml_backend_tensor_set(window_idx,     idx.data(),     0, ggml_nbytes(window_idx));
+                ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
+                ggml_backend_tensor_set(window_mask,    mask.data(),    0, ggml_nbytes(window_mask));
+            } else {
+                std::iota(idx.begin(), idx.end(), 0);
+                std::iota(inv_idx.begin(), inv_idx.end(), 0);
+            }
+
+            struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+            const int mpow = merge_ratio * merge_ratio;
+            std::vector<int> positions_data(ggml_nelements(positions));
+            int * data = positions_data.data();
 
             int ptr = 0;
-            for (int y = 0; y < ph; y+=2)
+            for (int y = 0; y < iph; y += merge_ratio)
             {
-                for (int x = 0; x < pw; x+=2)
+                for (int x = 0; x < ipw; x += merge_ratio)
                 {
                     for (int dy = 0; dy < 2; dy++) {
                         for (int dx = 0; dx < 2; dx++) {
-                            positions_data[ptr]                 = y + dy;
-                            positions_data[num_patches + ptr]     = x + dx;
-                            positions_data[num_patches * 2 + ptr] = y + dy;
-                            positions_data[num_patches * 3 + ptr] = x + dx;
+                            auto remap = idx[ptr / mpow];
+                            remap = remap * mpow + (ptr % mpow);
+
+                            data[                  remap] = y + dy;
+                            data[    num_patches + remap] = x + dx;
+                            data[2 * num_patches + remap] = y + dy;
+                            data[3 * num_patches + remap] = x + dx;
                             ptr++;
                         }
                     }
                 }
             }
 
-            ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
-            free(positions_data);
+            ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions));
         }
         else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
             // do nothing
@@ -2967,6 +3279,65 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         }
     }
 
+    if (use_window_attn && ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
+        struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
+        struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
+        struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
+
+        const int merge_ratio = 2;
+        const int attn_window_size = 112;
+        const int pw = image_size_width / patch_size / merge_ratio;
+        const int ph = image_size_height / patch_size / merge_ratio;
+        const int grid_window = attn_window_size / patch_size / merge_ratio;
+        const int ipw = image_size_width / patch_size;
+        const int iph = image_size_height / patch_size;
+        /*
+        pw * ph   = number of tokens output by the ViT after the patch merger is applied
+        ipw * iph = number of vision tokens processed inside the ViT
+        */
+
+        std::vector<int> idx(ph * pw);
+        std::vector<int> inv_idx(ph * pw);
+        int dst = 0;
+        // [num_vision_tokens, num_vision_tokens] attention mask tensor
+        std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
+        int mask_row = 0;
+
+        for (int y = 0; y < ph; y+=grid_window)
+        {
+            for (int x = 0; x < pw; x+=grid_window)
+            {
+                const int win_h = std::min(grid_window, ph - y);
+                const int win_w = std::min(grid_window, pw - x);
+                const int dst_0 = dst;
+                // group all tokens belonging to the same window together (into a contiguous range)
+                for (int dy = 0; dy < win_h; dy++) {
+                    for (int dx = 0; dx < win_w; dx++) {
+                        const int src = (y + dy) * pw + (x + dx);
+                        assert(src < (int)idx.size());
+                        assert(dst < (int)inv_idx.size());
+                        idx[src] = dst;
+                        inv_idx[dst] = src;
+                        dst++;
+                    }
+                }
+
+                for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
+                    int row_offset = mask_row * (ipw * iph);
+                    std::fill(
+                        mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
+                        mask.begin() + row_offset + (dst   * merge_ratio * merge_ratio),
+                        0.0);
+                    mask_row++;
+                }
+            }
+        }
+
+        ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
+        ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
+        ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
+    }
+
     ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);
 
     auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf);
@@ -3142,6 +3513,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         case PROJECTOR_TYPE_GLM_EDGE:
             return ctx->vision_model.mm_model_mlp_3_w->ne[1];
         case PROJECTOR_TYPE_QWEN2VL:
+        case PROJECTOR_TYPE_QWEN25VL:
             return ctx->vision_model.mm_1_b->ne[0];
         case PROJECTOR_TYPE_GEMMA3:
             return ctx->vision_model.mm_input_proj_w->ne[0];
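
To make the window scatter/gather above concrete, here is a standalone sketch of the `idx`/`inv_idx` construction using the same loops as the patch. The hard-coded attn_window_size of 112 and merge_ratio of 2 come from the diff; the 224x224 input (patch_size 14, hence an 8x8 merged grid and 4x4 windows) is an assumed example:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        // assumed geometry: 224x224 image, patch_size 14, merge_ratio 2
        const int pw = 8, ph = 8;   // merged token grid: 224 / 14 / 2
        const int grid_window = 4;  // attn_window_size / patch_size / merge_ratio = 112 / 14 / 2

        std::vector<int> idx(ph * pw), inv_idx(ph * pw);
        int dst = 0;
        for (int y = 0; y < ph; y += grid_window) {
            for (int x = 0; x < pw; x += grid_window) {
                const int win_h = std::min(grid_window, ph - y);
                const int win_w = std::min(grid_window, pw - x);
                for (int dy = 0; dy < win_h; dy++) {
                    for (int dx = 0; dx < win_w; dx++) {
                        const int src = (y + dy) * pw + (x + dx);
                        idx[src]     = dst; // row-major position -> window-grouped slot
                        inv_idx[dst] = src; // window-grouped slot -> row-major position
                        dst++;
                    }
                }
            }
        }
        // token (row 0, col 4) is the first entry of the second 4x4 window:
        std::printf("idx[4] = %d, inv_idx[16] = %d\n", idx[4], inv_idx[16]); // prints 16 and 4
        return 0;
    }

In the graph, `inv_window_idx` gathers token rows into window-grouped order before the transformer layers, and `window_idx` restores row-major order after the merger; that is why both permutations are built.
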
index c87606b4fdf4fc0a60088ffd2afc1931dedab524..7951a6fa8951e58cda32731f2606745c31ff210b 100644 (file)
@@ -1,14 +1,16 @@
 import argparse
-from typing import Dict
+from typing import Dict, List, Optional
 
 import torch
 import numpy as np
 from gguf import *
 from transformers import (
-    Qwen2VLForConditionalGeneration,
-    Qwen2VLProcessor,
     AutoProcessor,
-    Qwen2VLConfig
+    Qwen2VLConfig,
+    Qwen2VLProcessor,
+    Qwen2VLForConditionalGeneration,
+    Qwen2_5_VLConfig, # type: ignore[reportAttributeAccessIssue]
+    Qwen2_5_VLForConditionalGeneration, # type: ignore[reportAttributeAccessIssue]
 )
 
 
@@ -19,61 +21,93 @@ def k(raw_key: str, arch: str) -> str:
     return raw_key.format(arch=arch)
 
 
-def to_gguf_name(name: str) -> str:
-    og = name
-    name = name.replace("text_model", "t").replace("vision_model", "v")
-    name = name.replace("blocks", "blk").replace("embeddings.", "")
-    name = name.replace("attn.", "attn_")
-    name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.")
-    # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln")
-    name = name.replace("norm1", "ln1").replace("norm2", "ln2")
-    name = name.replace("merger.mlp", 'mm')
-    print(f"[to_gguf_name] {og} --> {name}")
-    return name
-
-
-def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
-    vision_model = qwen2vl.visual
-    tensor_map = {}
-    for name, ten in vision_model.state_dict().items():
-        ten = ten.numpy()
-        if 'qkv' in name:
-            if ten.ndim == 2: # weight
-                c3, _ = ten.shape
-            else:             # bias
-                c3 = ten.shape[0]
-            assert c3 % 3 == 0
-            c = c3 // 3
-            wq = ten[:c]
-            wk = ten[c: c * 2]
-            wv = ten[c * 2:]
-            tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq
-            tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk
-            tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv
-        elif 'merger' in name:
-            if name.endswith("ln_q.weight"):
-                tensor_map['v.post_ln.weight'] = ten
-            elif name.endswith("ln_q.bias"):
-                tensor_map['v.post_ln.bias'] = ten
+def get_n_wa_pattern(fullatt_block_indexes: Optional[List[int]]):
+    if fullatt_block_indexes is None:
+        return 0
+    n_wa = fullatt_block_indexes[0]
+    for a, b in zip(fullatt_block_indexes, fullatt_block_indexes[1:]):
+        if b - a - 1 != n_wa:
+            raise ValueError(
+                f"window/full attention layer should have fix pattern of "
+                f"for each full-attention layer followed by {n_wa} window-attention layers"
+            )
+    return n_wa + 1
+
+
+class VL2:
+
+    @staticmethod
+    def to_gguf_name(name: str) -> str:
+        og = name
+        name = name.replace("text_model", "t").replace("vision_model", "v")
+        name = name.replace("blocks", "blk").replace("embeddings.", "")
+        name = name.replace("attn.", "attn_")
+        name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.")
+        # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln")
+        name = name.replace("norm1", "ln1").replace("norm2", "ln2")
+        name = name.replace("merger.mlp", 'mm')
+        print(f"[to_gguf_name] {og} --> {name}")
+        return name
+
+    @classmethod
+    def find_vision_tensors(cls, qwen2vl, dtype) -> Dict[str, np.ndarray]:
+        vision_model = qwen2vl.visual
+        tensor_map = {}
+        for name, ten in vision_model.state_dict().items():
+            ten = ten.numpy()
+            if 'qkv' in name:
+                if ten.ndim == 2: # weight
+                    c3, _ = ten.shape
+                else:             # bias
+                    c3 = ten.shape[0]
+                assert c3 % 3 == 0
+                c = c3 // 3
+                wq = ten[:c]
+                wk = ten[c: c * 2]
+                wv = ten[c * 2:]
+                tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq
+                tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk
+                tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv
+            elif 'merger' in name:
+                if name.endswith("ln_q.weight"):
+                    tensor_map['v.post_ln.weight'] = ten
+                elif name.endswith("ln_q.bias"):
+                    tensor_map['v.post_ln.bias'] = ten
+                else:
+                    # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias"
+                    tensor_map[cls.to_gguf_name(name)] = ten
+            elif 'patch_embed.proj.weight' in name:
+                # NOTE: split Conv3D into Conv2Ds
+                c1, c2, kt, kh, kw = ten.shape
+                assert kt == 2, "Current implementation only supports temporal_patch_size of 2"
+                tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...]
+                tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
             else:
-                # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias"
-                tensor_map[to_gguf_name(name)] = ten
-        elif 'patch_embed.proj.weight' in name:
-            # NOTE: split Conv3D into Conv2Ds
-            c1, c2, kt, kh, kw = ten.shape
-            assert kt == 2, "Current implmentation only support temporal_patch_size of 2"
-            tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...]
-            tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
-        else:
-            tensor_map[to_gguf_name(f"vision_model.{name}")] = ten
-
-    for new_name, ten in tensor_map.items():
-        if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
-            tensor_map[new_name] = ten.astype(np.float32)
-        else:
-            tensor_map[new_name] = ten.astype(dtype)
-    tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32)  # dummy tensor, just here as a placeholder
-    return tensor_map
+                tensor_map[cls.to_gguf_name(f"vision_model.{name}")] = ten
+
+        for new_name, ten in tensor_map.items():
+            if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
+                tensor_map[new_name] = ten.astype(np.float32)
+            else:
+                tensor_map[new_name] = ten.astype(dtype)
+        tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32)  # dummy tensor, just here as a placeholder
+        return tensor_map
+
+
+class VL25(VL2):
+
+    @staticmethod
+    def to_gguf_name(name: str) -> str:
+        og = name
+        name = name.replace("text_model", "t").replace("vision_model", "v")
+        name = name.replace("blocks", "blk").replace("embeddings.", "")
+        name = name.replace("attn.", "attn_")
+        name = name.replace("mlp.down_proj", "ffn_down").replace("mlp.up_proj", "ffn_up")
+        name = name.replace("mlp.gate_proj", "ffn_gate").replace("proj.", "out.")
+        name = name.replace("norm1", "ln1").replace("norm2", "ln2")
+        name = name.replace("merger.mlp", 'mm')
+        print(f"[vl25][to_gguf_name] {og} --> {name}")
+        return name
 
 
 def main(args):
@@ -82,7 +116,7 @@ def main(args):
         np_dtype = np.float32
         ftype = 0
     elif args.data_type == 'fp16':
-        dtype = torch.float32
+        dtype = torch.float16
         np_dtype = np.float16
         ftype = 1
     else:
@@ -92,11 +126,18 @@ def main(args):
     model_path = ""
     model_name = args.model_name
     print("model_name: ", model_name)
-    qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
-        model_name, torch_dtype=dtype, device_map="cpu"
-    )
-    cfg: Qwen2VLConfig = qwen2vl.config  # type: ignore[reportAssignmentType]
-    vcfg = cfg.vision_config
+    if args.model_type == "qwen2vl":
+        qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
+            model_name, torch_dtype=dtype, device_map="cpu"
+        )
+        cfg: Qwen2VLConfig = qwen2vl.config  # type: ignore[reportAssignmentType]
+        vcfg = cfg.vision_config
+    else:
+        qwen2vl = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            model_name, torch_dtype=dtype, device_map="cpu"
+        )
+        cfg: Qwen2_5_VLConfig = qwen2vl.config  # type: ignore[reportAssignmentType]
+        vcfg = cfg.vision_config
 
     if os.path.isdir(model_name):
         local_model = True
@@ -113,7 +154,6 @@ def main(args):
     fout.add_bool("clip.has_text_encoder", False)
     fout.add_bool("clip.has_vision_encoder", True)
     fout.add_bool("clip.has_qwen2vl_merger", True)
-    fout.add_string("clip.projector_type", "qwen2vl_merger")
 
     print(cfg.vision_config)
     if 'silu' in cfg.vision_config.hidden_act.lower():
@@ -125,14 +165,25 @@ def main(args):
     else:
         raise ValueError()
 
-    tensor_map = find_vision_tensors(qwen2vl, np_dtype)
+    if args.model_type == "qwen2.5vl":
+        fout.add_uint32("clip.vision.n_wa_pattern", get_n_wa_pattern(vcfg.fullatt_block_indexes))
+        fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size)
+        fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size)
+        fout.add_string("clip.projector_type", "qwen2.5vl_merger")
+    else:
+        fout.add_string("clip.projector_type", "qwen2vl_merger")
+        fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
+        fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
+
+    if args.model_type == "qwen2.5vl":
+        tensor_map = VL25.find_vision_tensors(qwen2vl, np_dtype)
+    else:
+        tensor_map = VL2.find_vision_tensors(qwen2vl, np_dtype)
     for name, data in tensor_map.items():
         fout.add_tensor(name, data)
 
     fout.add_uint32("clip.vision.patch_size", vcfg.patch_size)
     fout.add_uint32("clip.vision.image_size", 14 * 40)  # some reasonable size that is divable by (14*2)
-    fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
-    fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
     fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads)
     fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
     fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth)
@@ -160,6 +211,7 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
+    parser.add_argument("--model_type", nargs='?', choices=['qwen2vl', 'qwen2.5vl'], default="qwen2vl")
     parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32")
     args = parser.parse_args()
     main(args)
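
Two notes on the converter changes above. First, Qwen2.5-VL configs reportedly use fullatt_block_indexes = [7, 15, 23, 31]; every gap is 7, so get_n_wa_pattern returns 8, which becomes the `clip.vision.n_wa_pattern` value. Second, the Conv3D patch embedding is split into two Conv2D kernels (`v.patch_embd.weight` and `v.patch_embd.weight.1`), which clip.cpp applies as two ggml_conv_2d calls and sums. A plausible invocation, based on the argparse options added here (the model id is illustrative):

    python examples/llava/qwen2_vl_surgery.py "Qwen/Qwen2.5-VL-3B-Instruct" --model_type qwen2.5vl --data_type fp16
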
index eca7b7f10b9e35d32371f6f656fe3acc763c96d5..cf4271086919100b1614e22ed03574fb966b7120 100644 (file)
@@ -23,6 +23,9 @@
 #include <algorithm>
 #include <iostream>
 #include <fstream>
+#include <limits>
+#include <cassert>
+#include <cmath>
 
 
 static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed,
@@ -367,14 +370,14 @@ static void debug_test_mrope_2d() {
     // 1. Initialize backend
     ggml_backend_t backend = NULL;
     std::string backend_name = "";
-#ifdef GGML_USE_CUDA
-    fprintf(stderr, "%s: using CUDA backend\n", __func__);
-    backend = ggml_backend_cuda_init(0); // init device 0
-    backend_name = "cuda";
-    if (!backend) {
-        fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
-    }
-#endif
+// #ifdef GGML_USE_CUDA
+//     fprintf(stderr, "%s: using CUDA backend\n", __func__);
+//     backend = ggml_backend_cuda_init(0); // init device 0
+//     backend_name = "cuda";
+//     if (!backend) {
+//         fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+//     }
+// #endif
     // if there aren't GPU Backends fallback to CPU backend
     if (!backend) {
         backend = ggml_backend_cpu_init();
@@ -483,28 +486,82 @@ static void debug_test_mrope_2d() {
     ggml_backend_free(backend);
 }
 
-static void debug_dump_img_embed(struct llava_context * ctx_llava) {
-    int n_embd  = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama));
-    int ne = n_embd * 4;
-    float vals[56 * 56 * 3];
+enum model_output_type {
+    conv3d,
+    patch_embed,
+    patch_win_attn_scatter,
+    first_attn_layer,
+    last_attn_layer,
+    attn_softmax,
+    final_layer,
+};
+
+static void debug_dump_img_embed(struct llava_context * ctx_llava, model_output_type output_type) {
+    constexpr int ih = 140;
+    constexpr int iw = 196;
+    // constexpr int ih = 56;
+    // constexpr int iw = 56;
+    // int n_embd  = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama));
+    int n_embd  = 1280;
+    int merge = 1;
+    if (output_type == model_output_type::final_layer) {
+        n_embd  = 2048;
+        merge = 2;
+    }
+    else if (output_type == model_output_type::attn_softmax) {
+        merge = 1;
+        n_embd = (ih/14/merge) * (iw/14/merge) * 16;
+    }
+
+    int ne = (ih/14/merge) * (iw/14/merge) * n_embd;
+    float vals[iw * ih * 3];
     // float embd[ne];
     std::vector<float> embd;
     embd.resize(ne);
 
-    for (int i = 0; i < 56*56; i++)
+    for (int i = 0; i < iw*ih; i++)
     {
         for (int c = 0; c < 3; c++)
-            vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56);
+            vals[i * 3 + c] = (float)i / (iw*ih);
     }
 
-    clip_encode_float_image(ctx_llava->ctx_clip, 16, vals, 56, 56, embd.data());
+    clip_encode_float_image(ctx_llava->ctx_clip, 8, vals, ih, iw, embd.data());
+
+    std::string file_postfix = "";
+    switch (output_type)
+    {
+    case model_output_type::conv3d:
+        file_postfix = "conv3d";
+        break;
+    case model_output_type::patch_embed:
+        file_postfix = "patch_embed";
+        break;
+    case model_output_type::patch_win_attn_scatter:
+        file_postfix = "scatter";
+        break;
+    case model_output_type::first_attn_layer:
+        file_postfix = "first_attn";
+        break;
+    case model_output_type::last_attn_layer:
+        file_postfix = "last_attn";
+        break;
+    case model_output_type::attn_softmax:
+        file_postfix = "attn_softmax";
+        break;
+    case model_output_type::final_layer:
+        file_postfix = "final";
+        break;
+    default:
+        break;
+    }
+    auto output_path = "img_embed_" + file_postfix + ".bin";
 
-    std::ofstream outFile("img_embed.bin", std::ios::binary);
+    std::ofstream outFile(output_path, std::ios::binary);
     if (outFile.is_open()) {
         outFile.write(reinterpret_cast<const char*>(embd.data()), ne * sizeof(float));
 
         outFile.close();
-        std::cout << "Data successfully written to mrope.bin" << std::endl;
+        std::cout << "Data successfully written to ::[ " << output_path << std::endl;
     } else {
         std::cerr << "Error opening file!" << std::endl;
     }
@@ -551,8 +608,9 @@ int main(int argc, char ** argv) {
     } else if (params.image[0].empty()) {
         auto ctx_llava = llava_init_context(&params, model);
 
-        debug_test_mrope_2d();
-        debug_dump_img_embed(ctx_llava);
+        // debug_test_mrope_2d();
+        debug_dump_img_embed(ctx_llava, model_output_type::final_layer);
+        // debug_dump_img_embed(ctx_llava, model_output_type::last_attn_layer);
 
         llama_perf_context_print(ctx_llava->ctx_llama);
         ctx_llava->model = NULL;
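
For reference, with the hard-coded 140x196 debug image and final_layer output, the dump works out to ne = (140/14/2) * (196/14/2) * 2048 = 5 * 7 * 2048 = 71680 floats, i.e. a 280 KiB img_embed_final.bin.
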
index e612857edc55dca780d81b3ceb0dfd66c747b065..4002f9d531bd257bded601c2c4bf9ba64363100e 100755 (executable)
@@ -55,6 +55,7 @@ add_test "llama-mtmd-cli"  "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K"  # mode
 add_test "llama-mtmd-cli"  "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
 add_test "llama-mtmd-cli"  "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
 add_test "llama-qwen2vl-cli"  "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
+add_test "llama-qwen2vl-cli"  "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
 
 # to test the big models, run: ./tests.sh big
 add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"