]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
mtmd : add qwen2vl and qwen2.5vl (#13141)
authorXuan-Son Nguyen <redacted>
Tue, 29 Apr 2025 09:47:04 +0000 (11:47 +0200)
committerGitHub <redacted>
Tue, 29 Apr 2025 09:47:04 +0000 (11:47 +0200)
* llava : add clip_n_output_tokens, deprecate clip_n_patches

* mtmd : add qwen2vl and qwen2.5vl

* decode_embd_batch::set_position_...

* working version

* deprecate llama-qwen2vl-cli

* correct order W, H of clip_embd_nbytes_by_img

* edit existing line in hot topics

README.md
examples/llava/CMakeLists.txt
examples/llava/clip.cpp
examples/llava/clip.h
examples/llava/llava.cpp
examples/llava/mtmd-cli.cpp
examples/llava/mtmd.cpp
examples/llava/mtmd.h
examples/llava/qwen2vl-cli.cpp [deleted file]
examples/llava/qwen2vl-test.cpp [new file with mode: 0644]
examples/llava/tests.sh

index 1785493c3e2b0e3081aa0735151b9699e93c5452..42c0eb633ef5d839d5f6c09815d83ea0021fdbb6 100644 (file)
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
 ## Hot topics
 
 - **GGML developer experience survey (organized and reviewed by NVIDIA):** [link](https://forms.gle/Gasw3cRgyhNEnrwK9)
-- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli` and `gemma3-cli` https://github.com/ggml-org/llama.cpp/pull/13012, `libllava` will be deprecated
+- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli`, `gemma3-cli` ([#13012](https://github.com/ggml-org/llama.cpp/pull/13012)) and `qwen2vl-cli` ([#13141]((https://github.com/ggml-org/llama.cpp/pull/13141))), `libllava` will be deprecated
 - VS Code extension for FIM completions: https://github.com/ggml-org/llama.vscode
 - Universal [tool call support](./docs/function-calling.md) in `llama-server` https://github.com/ggml-org/llama.cpp/pull/9639
 - Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim
index 6409b4f5e6cd02fb2d4d780b7a6bb78431d7bdf1..27b6d27e5cac3fe27a6a0ff34f7b0a87802c3921 100644 (file)
@@ -64,13 +64,7 @@ endif()
 add_executable(llama-llava-cli    deprecation-warning.cpp)
 add_executable(llama-gemma3-cli   deprecation-warning.cpp)
 add_executable(llama-minicpmv-cli deprecation-warning.cpp)
-
-set(TARGET llama-qwen2vl-cli)
-add_executable(${TARGET} qwen2vl-cli.cpp)
-set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-qwen2vl-cli)
-install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PRIVATE cxx_std_17)
+add_executable(llama-qwen2vl-cli  deprecation-warning.cpp)
 
 set(TARGET llama-mtmd-cli)
 add_executable(${TARGET} mtmd-cli.cpp)
index a5eb55f4d412d67d16ccfb1e63fc158f8792569d..ad3e7df1d8a3a8c241f673df5601a68ede08619f 100644 (file)
@@ -2825,15 +2825,18 @@ void clip_free(clip_ctx * ctx) {
     delete ctx;
 }
 
+// deprecated
 size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
-    return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
+    const int32_t nx = ctx->vision_model.hparams.image_size;
+    const int32_t ny = ctx->vision_model.hparams.image_size;
+    return clip_embd_nbytes_by_img(ctx, nx, ny);
 }
 
-size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) {
+size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h) {
     clip_image_f32 img;
     img.nx = img_w;
     img.ny = img_h;
-    return clip_n_patches_by_img(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float);
+    return clip_n_output_tokens(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float);
 }
 
 int32_t clip_get_image_size(const struct clip_ctx * ctx) {
@@ -2863,14 +2866,37 @@ size_t get_clip_image_grid_size(const struct clip_ctx * ctx) {
     return ctx->vision_model.hparams.image_grid_pinpoints.size();
 }
 
+// deprecated
 int clip_n_patches(const struct clip_ctx * ctx) {
     clip_image_f32 img;
     img.nx = ctx->vision_model.hparams.image_size;
     img.ny = ctx->vision_model.hparams.image_size;
-    return clip_n_patches_by_img(ctx, &img);
+    return clip_n_output_tokens(ctx, &img);
 }
 
+// deprecated
 int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
+    return clip_n_output_tokens(ctx, img);
+}
+
+int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
+    const auto & params = ctx->vision_model.hparams;
+    const int n_total = clip_n_output_tokens(ctx, img);
+    if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
+        return img->nx / (params.patch_size * 2) + (int)(img->nx % params.patch_size > 0);
+    }
+    return n_total;
+}
+
+int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
+    const auto & params = ctx->vision_model.hparams;
+    if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
+        return img->ny / (params.patch_size * 2) + (int)(img->ny % params.patch_size > 0);
+    }
+    return 1;
+}
+
+int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
     const auto & params = ctx->vision_model.hparams;
 
     int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
index 6ba42ad8921468843f21e30b6359e95550cb2ce6..0a53bd8eb78e1b5755afd80e1635ad810464d11c 100644 (file)
@@ -47,7 +47,7 @@ CLIP_API struct clip_ctx * clip_init(const char * fname, struct clip_context_par
 CLIP_API void clip_free(struct clip_ctx * ctx);
 
 CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
-CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w);
+CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h);
 
 CLIP_API int32_t clip_get_image_size (const struct clip_ctx * ctx);
 CLIP_API int32_t clip_get_patch_size (const struct clip_ctx * ctx);
@@ -59,9 +59,20 @@ CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
 CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
 CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx);
 
-CLIP_API int clip_n_patches        (const struct clip_ctx * ctx);
-CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img);
-CLIP_API int clip_n_mmproj_embd    (const struct clip_ctx * ctx);
+GGML_DEPRECATED(CLIP_API int clip_n_patches(const struct clip_ctx * ctx),
+    "use clip_n_output_tokens instead");
+GGML_DEPRECATED(CLIP_API int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img),
+    "use clip_n_output_tokens instead");
+
+CLIP_API int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img);
+
+// for M-RoPE, this will be the number of token positions in X and Y directions
+// for other models, X will be the total number of tokens and Y will be 1
+CLIP_API int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img);
+CLIP_API int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img);
+
+// this should be equal to the embedding dimension of the text model
+CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx);
 
 CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
 CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
index 03a22cbb4c20541cbf5e0bc562759cc7adfead4c..c00d16aefff10eaba7fefabe06e48691c499c992 100644 (file)
@@ -112,7 +112,7 @@ static struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair<
 }
 
 // Take the image segments in a grid configuration and return the embeddings and the number of embeddings into preallocated memory (image_embd_out)
-static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *> & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out) {
+static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *> & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out, clip_image_f32 * img_input) {
     struct {
         struct ggml_context * ctx;
     } model;
@@ -175,7 +175,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
 
     model.ctx = ggml_init(params);
 
-    struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_patches(ctx_clip), num_images - 1); // example: 4096 x 576 x 4
+    struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_output_tokens(ctx_clip, img_input), num_images - 1); // example: 4096 x 576 x 4
     // ggml_tensor_printf(image_features,"image_features",__LINE__,false,false);
     // fill it with the image embeddings, ignoring the base
     for (size_t i = 1; i < num_images; i++) {
@@ -214,8 +214,8 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
 
     memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as global context
     // append without newline tokens (default behavior in llava_arch when not using unpad ):
-    memcpy(image_embd_out + clip_n_patches(ctx_clip) * clip_n_mmproj_embd(ctx_clip), (float*)result->data, clip_embd_nbytes(ctx_clip) * (num_images-1)); // grid patches
-    *n_img_pos_out = static_cast<int>(result->ne[1]+clip_n_patches(ctx_clip));
+    memcpy(image_embd_out + clip_n_output_tokens(ctx_clip, img_input) * clip_n_mmproj_embd(ctx_clip), (float*)result->data, clip_embd_nbytes(ctx_clip) * (num_images-1)); // grid patches
+    *n_img_pos_out = static_cast<int>(result->ne[1]+clip_n_output_tokens(ctx_clip, img_input));
 
     // Debug: Test single segments
     // Current findings: sending base image, sending a segment embedding all works similar to python
@@ -313,7 +313,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
                 image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip),
                 image_embd_v[i],
                 clip_embd_nbytes_by_img(ctx_clip, nx, ny));
-            n_img_pos_out += clip_n_patches_by_img(ctx_clip, img_res);
+            n_img_pos_out += clip_n_output_tokens(ctx_clip, img_res);
         }
         *n_img_pos = n_img_pos_out;
         for (size_t i = 0; i < image_embd_v.size(); i++) {
@@ -342,8 +342,8 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
     }
     else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
         // flat / default llava-1.5 type embedding
-        *n_img_pos = clip_n_patches(ctx_clip);
         clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0);
+        *n_img_pos = clip_n_output_tokens(ctx_clip, img_res);
         bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); // image_embd shape is 576 x 4096
         if (!encoded) {
             LOG_ERR("Unable to encode image\n");
@@ -381,7 +381,8 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
         struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, image_size);
 
         int n_img_pos_out;
-        clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out);
+        clip_image_f32 * img_input = clip_image_f32_get_img(img_res_v.get(), 0);
+        clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out, img_input);
         *n_img_pos = n_img_pos_out;
 
         for (size_t i = 0; i < image_embd_v.size(); i++) {
index 250e8c9a9e8714cabbcc3e101773ca8c55da9e51..4d857ca64e0b497c41ef671389107625f79e5da8 100644 (file)
@@ -136,39 +136,6 @@ struct mtmd_cli_context {
     }
 };
 
-struct decode_embd_batch {
-    std::vector<llama_pos>      pos;
-    std::vector<int32_t>        n_seq_id;
-    std::vector<llama_seq_id>   seq_id_0;
-    std::vector<llama_seq_id *> seq_ids;
-    std::vector<int8_t>         logits;
-    llama_batch batch;
-    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
-        pos     .resize(n_tokens);
-        n_seq_id.resize(n_tokens);
-        seq_ids .resize(n_tokens + 1);
-        logits  .resize(n_tokens);
-        seq_id_0.resize(1);
-        seq_id_0[0] = seq_id;
-        seq_ids [n_tokens] = nullptr;
-        batch = {
-            /*n_tokens       =*/ n_tokens,
-            /*tokens         =*/ nullptr,
-            /*embd           =*/ embd,
-            /*pos            =*/ pos.data(),
-            /*n_seq_id       =*/ n_seq_id.data(),
-            /*seq_id         =*/ seq_ids.data(),
-            /*logits         =*/ logits.data(),
-        };
-        for (int i = 0; i < n_tokens; i++) {
-            batch.pos     [i] = pos_0 + i;
-            batch.n_seq_id[i] = 1;
-            batch.seq_id  [i] = seq_id_0.data();
-            batch.logits  [i] = false;
-        }
-    }
-};
-
 static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
     llama_tokens generated_tokens;
     for (int i = 0; i < n_predict; i++) {
@@ -243,7 +210,7 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
         return 1;
     }
 
-    ctx.n_past += mtmd_helper_get_n_tokens(chunks);
+    ctx.n_past += mtmd_helper_get_n_pos(chunks);
 
     return 0;
 }
@@ -371,6 +338,7 @@ int main(int argc, char ** argv) {
         }
     }
     if (g_is_interrupted) LOG("\nInterrupted by user\n");
+    LOG("\n\n");
     llama_perf_context_print(ctx.lctx);
     return g_is_interrupted ? 130 : 0;
 }
index f95f0503569f97ca636f0589415788fc30548530..7081fd7352bb772ce86e71575a46dd557f8f4b51 100644 (file)
@@ -40,11 +40,14 @@ struct mtmd_context {
     llama_token tok_sli_img_end   = LLAMA_TOKEN_NULL; // single slice
     llama_token tok_row_end       = LLAMA_TOKEN_NULL; // end of row
 
+    bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE
+
     // TODO @ngxson : add timings
 
     mtmd_context(const char * mmproj_fname,
                    const llama_model * text_model,
                    const mtmd_context_params & ctx_params) :
+        text_model   (text_model),
         print_timings(ctx_params.print_timings),
         n_threads    (ctx_params.n_threads),
         image_marker (ctx_params.image_marker)
@@ -56,9 +59,8 @@ struct mtmd_context {
         if (!ctx_clip) {
             throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
         }
-        this->text_model = text_model;
 
-        GGML_ASSERT(!clip_is_qwen2vl(ctx_clip) && "Qwen2VL model is not supported yet, use llama-qwen2vl-cli instead");
+        use_mrope = clip_is_qwen2vl(ctx_clip);
 
         int minicpmv_version = clip_is_minicpmv(ctx_clip);
         if (minicpmv_version == 2) {
@@ -126,6 +128,7 @@ struct mtmd_image_tokens_data {
 struct mtmd_image_tokens {
     uint32_t nx; // number of tokens in x direction
     uint32_t ny; // number of tokens in y direction
+    bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position)
     uint32_t n_tokens() const { return nx * ny; }
     clip_image_f32_batch batch_f32; // preprocessed image patches
     std::string id; // optional user-defined ID, useful for KV cache tracking
@@ -202,6 +205,13 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
         string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
     }
 
+    else if (proj_type == PROJECTOR_TYPE_QWEN2VL || proj_type == PROJECTOR_TYPE_QWEN25VL) {
+        // <|vision_start|> ... (image embeddings) ... <|vision_end|>
+        marker_modified = "<|vision_start|>" + ctx->image_marker + "<|vision_end|>";
+        string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
+
+    }
+
     // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
 
     std::vector<std::string> parts = string_split_str(prompt_modified, ctx->image_marker);
@@ -226,7 +236,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
 
         for (auto & entry : batch_f32.entries) {
             mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
-            image_tokens->nx = clip_n_patches_by_img(ctx->ctx_clip, entry.get());
+            image_tokens->nx = clip_n_output_tokens(ctx->ctx_clip, entry.get());
             image_tokens->ny = 1;
             image_tokens->batch_f32.entries.push_back(std::move(entry));
             image_tokens->id = id;
@@ -322,12 +332,20 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
             } else {
                 size_t n_tokens = 0;
                 for (const auto & entry : batch_f32.entries) {
-                    n_tokens += clip_n_patches_by_img(ctx->ctx_clip, entry.get());
+                    n_tokens += clip_n_output_tokens(ctx->ctx_clip, entry.get());
                 }
 
                 mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
-                image_tokens->nx = n_tokens;
-                image_tokens->ny = 1; // TODO
+                if (ctx->use_mrope) {
+                    // for Qwen2VL, we need this information for M-RoPE decoding positions
+                    image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_clip, batch_f32.entries[0].get());
+                    image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_clip, batch_f32.entries[0].get());
+                    image_tokens->use_mrope_pos = true;
+                } else {
+                    // other models, we only need the total number of tokens
+                    image_tokens->nx = n_tokens;
+                    image_tokens->ny = 1;
+                }
                 image_tokens->batch_f32 = std::move(batch_f32);
                 image_tokens->id = bitmaps[i_img].id; // optional
 
@@ -372,6 +390,13 @@ std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
     return image_tokens->id;
 }
 
+llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
+    if (image_tokens->use_mrope_pos) {
+        return 1; // for M-RoPE, the whole image is 1 in temporal dimension
+    }
+    return image_tokens->n_tokens();
+}
+
 int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
     int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
     ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
@@ -389,7 +414,7 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens)
         // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
         const auto & entries = image_tokens->batch_f32.entries;
         for (size_t i = 0; i < entries.size(); i++) {
-            int n_tokens_per_image = clip_n_patches_by_img(ctx->ctx_clip, entries[i].get());
+            int n_tokens_per_image = clip_n_output_tokens(ctx->ctx_clip, entries[i].get());
             ok = clip_image_encode(
                 ctx->ctx_clip,
                 ctx->n_threads,
@@ -417,7 +442,7 @@ size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) {
         if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             n_tokens += chunk.tokens_text.size();
         } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
-            n_tokens += chunk.tokens_image->n_tokens();
+            n_tokens += mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get());
         } else {
             GGML_ASSERT(false && "chunk type not supported");
         }
@@ -425,22 +450,38 @@ size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) {
     return n_tokens;
 }
 
+llama_pos mtmd_helper_get_n_pos(mtmd_input_chunks & chunks) {
+    llama_pos n_pos = 0;
+    for (auto & chunk : chunks) {
+        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+            n_pos += chunk.tokens_text.size();
+        } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+            n_pos += mtmd_image_tokens_get_n_pos(chunk.tokens_image.get());
+        } else {
+            GGML_ASSERT(false && "chunk type not supported");
+        }
+    }
+    return n_pos;
+}
+
 // helper struct to make working with embd batch easier
 // note: this will be removed after llama_batch_ext refactoring
 struct decode_embd_batch {
+    int n_pos_per_embd;
+    int n_mmproj_embd;
     std::vector<llama_pos>      pos;
+    std::vector<llama_pos>      pos_view; // used by mrope
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id>   seq_id_0;
     std::vector<llama_seq_id *> seq_ids;
     std::vector<int8_t>         logits;
     llama_batch batch;
-    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
-        pos     .resize(n_tokens);
+    decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
+        pos     .resize(n_tokens * n_pos_per_embd);
         n_seq_id.resize(n_tokens);
         seq_ids .resize(n_tokens + 1);
         logits  .resize(n_tokens);
         seq_id_0.resize(1);
-        seq_id_0[0] = seq_id;
         seq_ids [n_tokens] = nullptr;
         batch = {
             /*n_tokens       =*/ n_tokens,
@@ -451,13 +492,64 @@ struct decode_embd_batch {
             /*seq_id         =*/ seq_ids.data(),
             /*logits         =*/ logits.data(),
         };
-        for (int i = 0; i < n_tokens; i++) {
+    }
+
+    void set_position_normal(llama_pos pos_0, llama_seq_id seq_id) {
+        seq_id_0[0] = seq_id;
+        for (int i = 0; i < batch.n_tokens; i++) {
             batch.pos     [i] = pos_0 + i;
             batch.n_seq_id[i] = 1;
             batch.seq_id  [i] = seq_id_0.data();
             batch.logits  [i] = false;
         }
     }
+
+    void set_position_mrope(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
+        GGML_ASSERT(n_pos_per_embd == 4);
+        seq_id_0[0] = seq_id;
+        for (int y = 0; y < ny; y++) {
+            for (int x = 0; x < nx; x++) {
+                int i = y * nx + x;
+                pos[i                     ] = pos_0;
+                pos[i + batch.n_tokens    ] = pos_0 + y;
+                pos[i + batch.n_tokens * 2] = pos_0 + x;
+                pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
+            }
+        }
+        for (int i = 0; i < batch.n_tokens; i++) {
+            batch.n_seq_id[i] = 1;
+            batch.seq_id  [i] = seq_id_0.data();
+            batch.logits  [i] = false;
+        }
+    }
+
+    llama_batch get_view(int offset, int n_tokens) {
+        llama_pos * pos_ptr;
+        pos_view.clear();
+        pos_view.resize(n_tokens * n_pos_per_embd);
+        if (n_pos_per_embd > 1) {
+            // mrope
+            // for example, with layout of src: 1234...1234...1234...1234...
+            //       offset 2 will give us dst: 34...34...34...34...
+            for (int i = 0; i < n_pos_per_embd; i++) {
+                auto src = pos.begin() + i * batch.n_tokens + offset;
+                pos_view.insert(pos_view.end(), src, src + n_tokens);
+            }
+            pos_ptr = pos_view.data();
+        } else {
+            // normal
+            pos_ptr = pos.data() + offset;
+        }
+        return {
+            /*n_tokens       =*/ n_tokens,
+            /*tokens         =*/ nullptr,
+            /*embd           =*/ batch.embd     + offset * n_mmproj_embd,
+            /*pos            =*/ pos_ptr,
+            /*n_seq_id       =*/ batch.n_seq_id + offset,
+            /*seq_id         =*/ batch.seq_id   + offset,
+            /*logits         =*/ batch.logits   + offset,
+        };
+    }
 };
 
 int32_t mtmd_helper_eval(mtmd_context * ctx,
@@ -470,6 +562,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
     llama_pos n_past = pos0;
     llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
     int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
+    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
 
     for (auto & chunk : chunks) {
         bool is_last = &chunk == &chunks.back();
@@ -517,6 +610,16 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
             int32_t i_batch = 0;
             int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
             float * embd = mtmd_get_output_embd(ctx);
+            decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
+
+            const int nx = mtmd_image_tokens_get_nx(chunk.tokens_image.get());
+            const int ny = mtmd_image_tokens_get_ny(chunk.tokens_image.get());
+
+            if (mtmd_decode_use_mrope(ctx)) {
+                batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
+            } else {
+                batch_embd.set_position_normal(n_past, seq_id);
+            }
 
             if (mtmd_decode_use_non_causal(ctx)) {
                 llama_set_causal_attn(lctx, false);
@@ -524,15 +627,14 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
             }
 
             while (i_batch < n_img_batches) { // split into batches
-                int32_t pos_offset = i_batch*n_batch;
-                int32_t n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
-                float * embd_batch = embd + pos_offset*n_mmproj_embd;
-                decode_embd_batch batch_img(embd_batch, n_tokens_batch, n_past, 0);
+                int pos_offset = i_batch*n_batch;
+                int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
+                llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
 
-                printf("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
+                LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
 
                 int64_t t1 = ggml_time_ms();
-                ret = llama_decode(lctx, batch_img.batch);
+                ret = llama_decode(lctx, batch_embd_view);
                 if (ret != 0) {
                     LOG_ERR("failed to decode image\n");
                     llama_set_causal_attn(lctx, true); // restore causal attn
@@ -545,9 +647,11 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
                 }
 
                 i_batch++;
-                n_past += n_tokens_batch;
             }
 
+            // for mrope, one image is one single **temporal** position
+            n_past += mtmd_decode_use_mrope(ctx) ? 1 : n_tokens;
+
             if (mtmd_decode_use_non_causal(ctx)) {
                 llama_set_causal_attn(lctx, true);
             }
@@ -595,6 +699,10 @@ bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
     return false;
 }
 
+bool mtmd_decode_use_mrope(mtmd_context * ctx) {
+    return ctx->use_mrope;
+}
+
 void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
     mtmd_image_tokens_free(val);
 }
index 78be192dd6eb65dc13dfe4d68661d19119863ceb..6805e5e4816c325f0566df3eb360c95020e6c5f9 100644 (file)
@@ -102,6 +102,7 @@ MTMD_API size_t      mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * im
 MTMD_API size_t      mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens);
 MTMD_API size_t      mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens);
 MTMD_API std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens);
+MTMD_API llama_pos   mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens); // number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
 MTMD_API void        mtmd_image_tokens_free(mtmd_image_tokens * image_tokens);
 
 // returns 0 on success
@@ -114,15 +115,21 @@ MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
 // whether we need to set non-causal mask before llama_decode
 MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
 
+// whether the current model use M-RoPE for llama_decode
+MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);
+
 
 
 //
 // helper functions (can be implemented based on other functions)
 //
 
-// helper to count the total number of tokens from a list of chunks, useful to keep track of n_past
+// helper to count the total number of tokens from a list of chunks, useful to keep track of KV cache
 MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks);
 
+// helper to count the total position of tokens from a list of chunks, useful to keep track of n_past
+MTMD_API llama_pos mtmd_helper_get_n_pos(mtmd_input_chunks & chunks);
+
 // helper function that automatically:
 // 1. run llama_decode() on text chunks
 // 2. run mtmd_encode() on image chunks, then mtmd_get_output_embd() and then llama_decode()
diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp
deleted file mode 100644 (file)
index 1e54851..0000000
+++ /dev/null
@@ -1,634 +0,0 @@
-#include "arg.h"
-#include "base64.hpp"
-#include "log.h"
-#include "common.h"
-#include "sampling.h"
-#include "clip.h"
-#include "llava.h"
-#include "llama.h"
-#include "ggml.h"
-
-#ifdef GGML_USE_CUDA
-#include "ggml-cuda.h"
-#endif
-#ifdef NDEBUG
-#include "ggml-alloc.h"
-#include "ggml-backend.h"
-#endif
-
-#include <cstdio>
-#include <cstdlib>
-#include <cstring>
-#include <vector>
-#include <algorithm>
-#include <iostream>
-#include <fstream>
-#include <limits>
-#include <cassert>
-#include <cmath>
-
-
-static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed,
-                                     int n_batch, int * n_past, int * st_pos_id, struct clip_image_size * image_size) {
-    int n_embd  = llama_model_n_embd(llama_get_model(ctx_llama));
-    const int patch_size = 14 * 2;
-    const int ph = image_size->height / patch_size + (image_size->height % patch_size > 0);
-    const int pw = image_size->width / patch_size + (image_size->width % patch_size > 0);
-    auto img_tokens = image_embed->n_image_pos;
-    // llama_pos mrope_pos[img_tokens * 4];
-    std::vector<llama_pos> mrope_pos;
-    mrope_pos.resize(img_tokens * 4);
-
-    for (int y = 0; y < ph; y++)
-    {
-        for (int x = 0; x < pw; x++)
-        {
-            int i = y * pw + x;
-            mrope_pos[i] = *st_pos_id;
-            mrope_pos[i + img_tokens] = *st_pos_id + y;
-            mrope_pos[i + img_tokens * 2] = *st_pos_id + x;
-            mrope_pos[i + img_tokens * 3] = 0;
-        }
-    }
-    *st_pos_id += std::max(pw, ph);
-
-    int processed = 0;
-    std::vector<llama_pos> batch_mrope_pos;
-    batch_mrope_pos.resize(img_tokens * 4);
-
-    for (int i = 0; i < img_tokens; i += n_batch) {
-        int n_eval = img_tokens - i;
-        if (n_eval > n_batch) {
-            n_eval = n_batch;
-        }
-
-        // llama_pos batch_mrope_pos[n_eval * 4];
-        std::fill(batch_mrope_pos.begin(), batch_mrope_pos.end(), 0);
-        memcpy(batch_mrope_pos.data(), &mrope_pos[processed], n_eval * sizeof(llama_pos));
-        memcpy(&batch_mrope_pos[n_eval * 1], &mrope_pos[img_tokens * 1 + processed], n_eval * sizeof(llama_pos));
-        memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
-        memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
-
-        llama_batch batch = {
-            int32_t(n_eval),                // n_tokens
-            nullptr,                        // token
-            (image_embed->embed+i*n_embd),  // embed
-            batch_mrope_pos.data(),         // pos
-            nullptr,  // n_seq_id
-            nullptr,  // seq_id
-            nullptr,  // logits
-        };
-
-        if (llama_decode(ctx_llama, batch)) {
-            LOG_ERR("%s : failed to eval\n", __func__);
-            return false;
-        }
-        *n_past += n_eval;
-        processed += n_eval;
-    }
-    return true;
-}
-
-
-static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past, int * st_pos_id) {
-    int N = (int) tokens.size();
-    for (int i = 0; i < N; i += n_batch) {
-        int n_eval = (int) tokens.size() - i;
-        if (n_eval > n_batch) {
-            n_eval = n_batch;
-        }
-        auto batch = llama_batch_get_one(&tokens[i], n_eval);
-
-        if (llama_decode(ctx_llama, batch)) {
-            LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
-            return false;
-        }
-        *n_past += n_eval;
-        *st_pos_id += n_eval;
-    }
-    return true;
-}
-
-static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past, int * st_pos_id) {
-    std::vector<llama_token> tokens;
-    tokens.push_back(id);
-    return eval_tokens(ctx_llama, tokens, 1, n_past, st_pos_id);
-}
-
-static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, int * st_pos_id, bool add_bos){
-    std::string              str2     = str;
-    std::vector<llama_token> embd_inp = common_tokenize(ctx_llama, str2, add_bos, true);
-    eval_tokens(ctx_llama, embd_inp, n_batch, n_past, st_pos_id);
-    return true;
-}
-
-static const char * sample(struct common_sampler * smpl,
-                           struct llama_context * ctx_llama,
-                           int * n_past, int * st_pos_id) {
-    const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
-    common_sampler_accept(smpl, id, true);
-
-    const llama_model * model = llama_get_model(ctx_llama);
-    const llama_vocab * vocab = llama_model_get_vocab(model);
-
-    static std::string ret;
-    if (llama_vocab_is_eog(vocab, id)) {
-        ret = "</s>";
-    } else {
-        ret = common_token_to_piece(ctx_llama, id);
-    }
-    eval_id(ctx_llama, id, n_past, st_pos_id);
-    return ret.c_str();
-}
-
-static const char* IMG_BASE64_TAG_BEGIN = "<img src=\"data:image/jpeg;base64,";
-static const char* IMG_BASE64_TAG_END = "\">";
-
-static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) {
-    begin_out = prompt.find(IMG_BASE64_TAG_BEGIN);
-    end_out = prompt.find(IMG_BASE64_TAG_END, (begin_out == std::string::npos) ? 0UL : begin_out);
-}
-
-static bool prompt_contains_image(const std::string& prompt) {
-    size_t begin, end;
-    find_image_tag_in_prompt(prompt, begin, end);
-    return (begin != std::string::npos);
-}
-
-// replaces the base64 image tag in the prompt with `replacement`
-static llava_image_embed * llava_image_embed_make_with_prompt_base64(struct clip_ctx * ctx_clip, int n_threads, const std::string& prompt) {
-    size_t img_base64_str_start, img_base64_str_end;
-    find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
-    if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
-        LOG_ERR("%s: invalid base64 image tag. must be %s<base64 byte string>%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
-        return NULL;
-    }
-
-    auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
-    auto base64_bytes_count = img_base64_str_end - base64_bytes_start;
-    auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count );
-
-    auto required_bytes = base64::required_encode_size(base64_str.size());
-    auto img_bytes = std::vector<unsigned char>(required_bytes);
-    base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
-
-    auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, img_bytes.data(), img_bytes.size());
-    if (!embed) {
-        LOG_ERR("%s: could not load image from base64 string.\n", __func__);
-        return NULL;
-    }
-
-    return embed;
-}
-
-static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") {
-    size_t begin, end;
-    find_image_tag_in_prompt(prompt, begin, end);
-    if (begin == std::string::npos || end == std::string::npos) {
-        return prompt;
-    }
-    auto pre = prompt.substr(0, begin);
-    auto post = prompt.substr(end + strlen(IMG_BASE64_TAG_END));
-    return pre + replacement + post;
-}
-
-struct llava_context {
-    struct clip_ctx * ctx_clip = NULL;
-    struct llama_context * ctx_llama = NULL;
-    struct llama_model * model = NULL;
-};
-
-static void print_usage(int, char ** argv) {
-    LOG("\n example usage:\n");
-    LOG("\n     %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> --image <path/to/another/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
-    LOG("\n note: a lower temperature value like 0.1 is recommended for better quality.\n");
-}
-
-static struct llava_image_embed * load_image(llava_context * ctx_llava, common_params * params, const std::string & fname) {
-
-    // load and preprocess the image
-    llava_image_embed * embed = NULL;
-    auto prompt = params->prompt;
-    if (prompt_contains_image(prompt)) {
-        if (!params->image.empty()) {
-            LOG_INF("using base64 encoded image instead of command line image path\n");
-        }
-        embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->cpuparams.n_threads, prompt);
-        if (!embed) {
-            LOG_ERR("%s: can't load image from prompt\n", __func__);
-            return NULL;
-        }
-        params->prompt = remove_image_from_prompt(prompt);
-    } else {
-        embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->cpuparams.n_threads, fname.c_str());
-        if (!embed) {
-            fprintf(stderr, "%s: is %s really an image file?\n", __func__, fname.c_str());
-            return NULL;
-        }
-    }
-
-    return embed;
-}
-
-static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, common_params * params, const std::string & prompt) {
-    int n_past = 0;
-    int cur_pos_id = 0;
-
-    const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
-
-    std::string system_prompt, user_prompt;
-    size_t image_pos = prompt.find("<|vision_start|>");
-    if (image_pos != std::string::npos) {
-        // new templating mode: Provide the full prompt including system message and use <image> as a placeholder for the image
-        system_prompt = prompt.substr(0, image_pos);
-        user_prompt = prompt.substr(image_pos + std::string("<|vision_pad|>").length());
-        LOG_INF("system_prompt: %s\n", system_prompt.c_str());
-        if (params->verbose_prompt) {
-            auto tmp = common_tokenize(ctx_llava->ctx_llama, system_prompt, true, true);
-            for (int i = 0; i < (int) tmp.size(); i++) {
-                LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
-            }
-        }
-        LOG_INF("user_prompt: %s\n", user_prompt.c_str());
-        if (params->verbose_prompt) {
-            auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
-            for (int i = 0; i < (int) tmp.size(); i++) {
-                LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
-            }
-        }
-    } else {
-        // llava-1.5 native mode
-        system_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|>";
-        user_prompt = "<|vision_end|>" + prompt + "<|im_end|>\n<|im_start|>assistant\n";
-        if (params->verbose_prompt) {
-            auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
-            for (int i = 0; i < (int) tmp.size(); i++) {
-                LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
-            }
-        }
-    }
-
-    eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, &cur_pos_id, true);
-    if (image_embed != nullptr) {
-        auto image_size = clip_get_load_image_size(ctx_llava->ctx_clip);
-        qwen2vl_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past, &cur_pos_id, image_size);
-    }
-    eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, &cur_pos_id, false);
-
-    // generate the response
-
-    LOG("\n");
-
-    struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
-    if (!smpl) {
-        LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
-        exit(1);
-    }
-
-    std::string response = "";
-    for (int i = 0; i < max_tgt_len; i++) {
-        const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past, &cur_pos_id);
-        response += tmp;
-        if (strcmp(tmp, "</s>") == 0) break;
-        if (strstr(tmp, "###")) break; // Yi-VL behavior
-        LOG("%s", tmp);
-        if (strstr(response.c_str(), "<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works)
-        if (strstr(response.c_str(), "<|im_start|>")) break; // Yi-34B llava-1.6
-        if (strstr(response.c_str(), "USER:")) break; // mistral llava-1.6
-
-        fflush(stdout);
-    }
-
-    common_sampler_free(smpl);
-    LOG("\n");
-}
-
-static struct llama_model * llava_init(common_params * params) {
-    llama_backend_init();
-    llama_numa_init(params->numa);
-
-    llama_model_params model_params = common_model_params_to_llama(*params);
-
-    llama_model * model = llama_model_load_from_file(params->model.path.c_str(), model_params);
-    if (model == NULL) {
-        LOG_ERR("%s: unable to load model\n" , __func__);
-        return NULL;
-    }
-    return model;
-}
-
-static struct llava_context * llava_init_context(common_params * params, llama_model * model) {
-    const char * clip_path = params->mmproj.path.c_str();
-
-    auto prompt = params->prompt;
-    if (prompt.empty()) {
-        prompt = "describe the image in detail.";
-    }
-
-    auto ctx_clip = clip_model_load(clip_path, GGML_LOG_LEVEL_INFO);
-
-    llama_context_params ctx_params = common_context_params_to_llama(*params);
-    ctx_params.n_ctx           = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
-
-    llama_context * ctx_llama = llama_init_from_model(model, ctx_params);
-
-    if (ctx_llama == NULL) {
-        LOG_ERR("%s: failed to create the llama_context\n" , __func__);
-        return NULL;
-    }
-
-    auto * ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
-
-    ctx_llava->ctx_llama = ctx_llama;
-    ctx_llava->ctx_clip = ctx_clip;
-    ctx_llava->model = model;
-    return ctx_llava;
-}
-
-static void llava_free(struct llava_context * ctx_llava) {
-    if (ctx_llava->ctx_clip) {
-        clip_free(ctx_llava->ctx_clip);
-        ctx_llava->ctx_clip = NULL;
-    }
-
-    llama_free(ctx_llava->ctx_llama);
-    llama_model_free(ctx_llava->model);
-    llama_backend_free();
-}
-
-#ifndef NDEBUG
-
-static void debug_test_mrope_2d() {
-    // 1. Initialize backend
-    ggml_backend_t backend = NULL;
-    std::string backend_name = "";
-// #ifdef GGML_USE_CUDA
-//     fprintf(stderr, "%s: using CUDA backend\n", __func__);
-//     backend = ggml_backend_cuda_init(0); // init device 0
-//     backend_name = "cuda";
-//     if (!backend) {
-//         fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
-//     }
-// #endif
-    // if there aren't GPU Backends fallback to CPU backend
-    if (!backend) {
-        backend = ggml_backend_cpu_init();
-        backend_name = "cpu";
-    }
-
-    // Calculate the size needed to allocate
-    size_t ctx_size = 0;
-    ctx_size += 2 * ggml_tensor_overhead(); // tensors
-    // no need to allocate anything else!
-
-    // 2. Allocate `ggml_context` to store tensor data
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ ctx_size,
-        /*.mem_buffer =*/ NULL,
-        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors()
-    };
-    struct ggml_context * ctx = ggml_init(params);
-
-    struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 12, 30);
-    ggml_set_name(inp_raw, "inp_raw");
-    ggml_set_input(inp_raw);
-
-    struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 30 * 4);
-    ggml_set_name(pos, "pos");
-    ggml_set_input(pos);
-
-    std::vector<float> dummy_q;
-    dummy_q.resize(128 * 12 * 30);
-    std::fill(dummy_q.begin(), dummy_q.end(), 0.1);
-    // memcpy(inp_raw->data, dummy_q.data(), 128 * 12 * 30 * ggml_element_size(inp_raw));
-
-    std::vector<int> pos_id;
-    pos_id.resize(30 * 4);
-    for (int i = 0; i < 30; i ++) {
-        pos_id[i] = i;
-        pos_id[i + 30] = i + 10;
-        pos_id[i + 60] = i + 20;
-        pos_id[i + 90] = i + 30;
-    }
-    int sections[4] = {32, 32, 0, 0};
-
-    // 4. Allocate a `ggml_backend_buffer` to store all tensors
-    ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
-
-    // 5. Copy tensor data from main memory (RAM) to backend buffer
-    ggml_backend_tensor_set(inp_raw, dummy_q.data(), 0, ggml_nbytes(inp_raw));
-    ggml_backend_tensor_set(pos, pos_id.data(), 0, ggml_nbytes(pos));
-
-    // 6. Create a `ggml_cgraph` for mul_mat operation
-    struct ggml_cgraph * gf = NULL;
-    struct ggml_context * ctx_cgraph = NULL;
-
-    // create a temporally context to build the graph
-    struct ggml_init_params params0 = {
-        /*.mem_size   =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
-        /*.mem_buffer =*/ NULL,
-        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
-    };
-    ctx_cgraph = ggml_init(params0);
-    gf = ggml_new_graph(ctx_cgraph);
-
-    struct ggml_tensor * result0 = ggml_rope_multi(
-        ctx_cgraph, inp_raw, pos, nullptr,
-        128/2, sections, LLAMA_ROPE_TYPE_VISION, 32768, 1000000, 1,
-        0, 1, 32, 1);
-
-    // Add "result" tensor and all of its dependencies to the cgraph
-    ggml_build_forward_expand(gf, result0);
-
-    // 7. Create a `ggml_gallocr` for cgraph computation
-    ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
-    ggml_gallocr_alloc_graph(allocr, gf);
-
-    // 9. Run the computation
-    int n_threads = 1; // Optional: number of threads to perform some operations with multi-threading
-    if (ggml_backend_is_cpu(backend)) {
-        ggml_backend_cpu_set_n_threads(backend, n_threads);
-    }
-    ggml_backend_graph_compute(backend, gf);
-
-    // 10. Retrieve results (output tensors)
-    // in this example, output tensor is always the last tensor in the graph
-    struct ggml_tensor * result = result0;
-    // struct ggml_tensor * result = gf->nodes[gf->n_nodes - 1];
-    float * result_data = (float *)malloc(ggml_nbytes(result));
-    // because the tensor data is stored in device buffer, we need to copy it back to RAM
-    ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result));
-    const std::string bin_file = "mrope_2d_" + backend_name +".bin";
-    std::ofstream outFile(bin_file, std::ios::binary);
-
-    if (outFile.is_open()) {
-        outFile.write(reinterpret_cast<const char*>(result_data), ggml_nbytes(result));
-        outFile.close();
-        std::cout << "Data successfully written to " + bin_file << std::endl;
-    } else {
-        std::cerr << "Error opening file!" << std::endl;
-    }
-
-    free(result_data);
-    // 11. Free memory and exit
-    ggml_free(ctx_cgraph);
-    ggml_gallocr_free(allocr);
-    ggml_free(ctx);
-    ggml_backend_buffer_free(buffer);
-    ggml_backend_free(backend);
-}
-
-enum model_output_type {
-    conv3d,
-    patch_embed,
-    patch_win_attn_scatter,
-    first_attn_layer,
-    last_attn_layer,
-    attn_softmax,
-    final_layer,
-};
-
-static void debug_dump_img_embed(struct llava_context * ctx_llava, model_output_type output_type) {
-    constexpr int ih = 140;
-    constexpr int iw = 196;
-    // constexpr int ih = 56;
-    // constexpr int iw = 56;
-    // int n_embd  = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama));
-    int n_embd  = 1280;
-    int merge = 1;
-    if (output_type == model_output_type::final_layer) {
-        n_embd  = 2048;
-        merge = 2;
-    }
-    else if (output_type == model_output_type::attn_softmax) {
-        merge = 1;
-        n_embd = (ih/14/merge) * (iw/14/merge) * 16;
-    }
-
-    int ne = (ih/14/merge) * (iw/14/merge) * n_embd;
-    float vals[iw * ih * 3];
-    // float embd[ne];
-    std::vector<float> embd;
-    embd.resize(ne);
-
-    for (int i = 0; i < iw*ih; i++)
-    {
-        for (int c = 0; c < 3; c++)
-            vals[i * 3 + c] = (float)i / (iw*ih);
-    }
-
-    clip_encode_float_image(ctx_llava->ctx_clip, 8, vals, ih, iw, embd.data());
-
-    std::string file_postfix = "";
-    switch (output_type)
-    {
-    case model_output_type::conv3d:
-        file_postfix = "conv3d";
-        break;
-    case model_output_type::patch_embed:
-        file_postfix = "patch_embed";
-        break;
-    case model_output_type::patch_win_attn_scatter:
-        file_postfix = "scatter";
-        break;
-    case model_output_type::first_attn_layer:
-        file_postfix = "first_attn";
-        break;
-    case model_output_type::last_attn_layer:
-        file_postfix = "last_attn";
-        break;
-    case model_output_type::attn_softmax:
-        file_postfix = "attn_softmax";
-        break;
-    case model_output_type::final_layer:
-        file_postfix = "final";
-        break;
-    default:
-        break;
-    }
-    auto output_path = "img_embed_" + file_postfix + ".bin";
-
-    std::ofstream outFile(output_path, std::ios::binary);
-    if (outFile.is_open()) {
-        outFile.write(reinterpret_cast<const char*>(embd.data()), ne * sizeof(float));
-
-        outFile.close();
-        std::cout << "Data successfully written to ::[ " << output_path << std::endl;
-    } else {
-        std::cerr << "Error opening file!" << std::endl;
-    }
-}
-
-#endif
-
-
-int main(int argc, char ** argv) {
-    ggml_time_init();
-
-    common_params params;
-
-    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, print_usage)) {
-        return 1;
-    }
-
-    common_init();
-
-    if (params.mmproj.path.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) {
-        print_usage(argc, argv);
-        return 1;
-    }
-
-    auto * model = llava_init(&params);
-    if (model == NULL) {
-        fprintf(stderr, "%s: error: failed to init llava model\n", __func__);
-        return 1;
-    }
-
-    if (prompt_contains_image(params.prompt)) {
-        auto * ctx_llava = llava_init_context(&params, model);
-
-        auto * image_embed = load_image(ctx_llava, &params, "");
-
-        // process the prompt
-        process_prompt(ctx_llava, image_embed, &params, params.prompt);
-
-        llama_perf_context_print(ctx_llava->ctx_llama);
-        llava_image_embed_free(image_embed);
-        ctx_llava->model = NULL;
-        llava_free(ctx_llava);
-#ifndef NDEBUG
-    } else if (params.image[0].empty()) {
-        auto ctx_llava = llava_init_context(&params, model);
-
-        // debug_test_mrope_2d();
-        debug_dump_img_embed(ctx_llava, model_output_type::final_layer);
-        // debug_dump_img_embed(ctx_llava, model_output_type::last_attn_layer);
-
-        llama_perf_context_print(ctx_llava->ctx_llama);
-        ctx_llava->model = NULL;
-        llava_free(ctx_llava);
-#endif
-    } else {
-        for (auto & image : params.image) {
-            auto * ctx_llava = llava_init_context(&params, model);
-
-            auto * image_embed = load_image(ctx_llava, &params, image);
-            if (!image_embed) {
-                LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, image.c_str());
-                return 1;
-            }
-
-            // process the prompt
-            process_prompt(ctx_llava, image_embed, &params, params.prompt);
-
-            llama_perf_context_print(ctx_llava->ctx_llama);
-            llava_image_embed_free(image_embed);
-            ctx_llava->model = NULL;
-            llava_free(ctx_llava);
-        }
-    }
-
-    llama_model_free(model);
-
-    return 0;
-}
diff --git a/examples/llava/qwen2vl-test.cpp b/examples/llava/qwen2vl-test.cpp
new file mode 100644 (file)
index 0000000..7f9e3dc
--- /dev/null
@@ -0,0 +1,636 @@
+#include "arg.h"
+#include "base64.hpp"
+#include "log.h"
+#include "common.h"
+#include "sampling.h"
+#include "clip.h"
+#include "llava.h"
+#include "llama.h"
+#include "ggml.h"
+
+#ifdef GGML_USE_CUDA
+#include "ggml-cuda.h"
+#endif
+#ifdef NDEBUG
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+#endif
+
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <vector>
+#include <algorithm>
+#include <iostream>
+#include <fstream>
+#include <limits>
+#include <cassert>
+#include <cmath>
+
+// THIS FILE IS ONLY USED FOR TESTING THE QWEN2VL MODEL
+// IT IS NOT A PRODUCTION CODE
+
+static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed,
+                                     int n_batch, int * n_past, int * st_pos_id, struct clip_image_size * image_size) {
+    int n_embd  = llama_model_n_embd(llama_get_model(ctx_llama));
+    const int patch_size = 14 * 2;
+    const int ph = image_size->height / patch_size + (image_size->height % patch_size > 0);
+    const int pw = image_size->width / patch_size + (image_size->width % patch_size > 0);
+    auto img_tokens = image_embed->n_image_pos;
+    // llama_pos mrope_pos[img_tokens * 4];
+    std::vector<llama_pos> mrope_pos;
+    mrope_pos.resize(img_tokens * 4);
+
+    for (int y = 0; y < ph; y++)
+    {
+        for (int x = 0; x < pw; x++)
+        {
+            int i = y * pw + x;
+            mrope_pos[i] = *st_pos_id;
+            mrope_pos[i + img_tokens] = *st_pos_id + y;
+            mrope_pos[i + img_tokens * 2] = *st_pos_id + x;
+            mrope_pos[i + img_tokens * 3] = 0;
+        }
+    }
+    *st_pos_id += std::max(pw, ph);
+
+    int processed = 0;
+    std::vector<llama_pos> batch_mrope_pos;
+    batch_mrope_pos.resize(img_tokens * 4);
+
+    for (int i = 0; i < img_tokens; i += n_batch) {
+        int n_eval = img_tokens - i;
+        if (n_eval > n_batch) {
+            n_eval = n_batch;
+        }
+
+        // llama_pos batch_mrope_pos[n_eval * 4];
+        std::fill(batch_mrope_pos.begin(), batch_mrope_pos.end(), 0);
+        memcpy(batch_mrope_pos.data(), &mrope_pos[processed], n_eval * sizeof(llama_pos));
+        memcpy(&batch_mrope_pos[n_eval * 1], &mrope_pos[img_tokens * 1 + processed], n_eval * sizeof(llama_pos));
+        memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
+        memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
+
+        llama_batch batch = {
+            int32_t(n_eval),                // n_tokens
+            nullptr,                        // token
+            (image_embed->embed+i*n_embd),  // embed
+            batch_mrope_pos.data(),         // pos
+            nullptr,  // n_seq_id
+            nullptr,  // seq_id
+            nullptr,  // logits
+        };
+
+        if (llama_decode(ctx_llama, batch)) {
+            LOG_ERR("%s : failed to eval\n", __func__);
+            return false;
+        }
+        *n_past += n_eval;
+        processed += n_eval;
+    }
+    return true;
+}
+
+
+static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past, int * st_pos_id) {
+    int N = (int) tokens.size();
+    for (int i = 0; i < N; i += n_batch) {
+        int n_eval = (int) tokens.size() - i;
+        if (n_eval > n_batch) {
+            n_eval = n_batch;
+        }
+        auto batch = llama_batch_get_one(&tokens[i], n_eval);
+
+        if (llama_decode(ctx_llama, batch)) {
+            LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
+            return false;
+        }
+        *n_past += n_eval;
+        *st_pos_id += n_eval;
+    }
+    return true;
+}
+
+static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past, int * st_pos_id) {
+    std::vector<llama_token> tokens;
+    tokens.push_back(id);
+    return eval_tokens(ctx_llama, tokens, 1, n_past, st_pos_id);
+}
+
+static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, int * st_pos_id, bool add_bos){
+    std::string              str2     = str;
+    std::vector<llama_token> embd_inp = common_tokenize(ctx_llama, str2, add_bos, true);
+    eval_tokens(ctx_llama, embd_inp, n_batch, n_past, st_pos_id);
+    return true;
+}
+
+static const char * sample(struct common_sampler * smpl,
+                           struct llama_context * ctx_llama,
+                           int * n_past, int * st_pos_id) {
+    const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
+    common_sampler_accept(smpl, id, true);
+
+    const llama_model * model = llama_get_model(ctx_llama);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    static std::string ret;
+    if (llama_vocab_is_eog(vocab, id)) {
+        ret = "</s>";
+    } else {
+        ret = common_token_to_piece(ctx_llama, id);
+    }
+    eval_id(ctx_llama, id, n_past, st_pos_id);
+    return ret.c_str();
+}
+
+static const char* IMG_BASE64_TAG_BEGIN = "<img src=\"data:image/jpeg;base64,";
+static const char* IMG_BASE64_TAG_END = "\">";
+
+static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) {
+    begin_out = prompt.find(IMG_BASE64_TAG_BEGIN);
+    end_out = prompt.find(IMG_BASE64_TAG_END, (begin_out == std::string::npos) ? 0UL : begin_out);
+}
+
+static bool prompt_contains_image(const std::string& prompt) {
+    size_t begin, end;
+    find_image_tag_in_prompt(prompt, begin, end);
+    return (begin != std::string::npos);
+}
+
+// replaces the base64 image tag in the prompt with `replacement`
+static llava_image_embed * llava_image_embed_make_with_prompt_base64(struct clip_ctx * ctx_clip, int n_threads, const std::string& prompt) {
+    size_t img_base64_str_start, img_base64_str_end;
+    find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
+    if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
+        LOG_ERR("%s: invalid base64 image tag. must be %s<base64 byte string>%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
+        return NULL;
+    }
+
+    auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
+    auto base64_bytes_count = img_base64_str_end - base64_bytes_start;
+    auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count );
+
+    auto required_bytes = base64::required_encode_size(base64_str.size());
+    auto img_bytes = std::vector<unsigned char>(required_bytes);
+    base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
+
+    auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, img_bytes.data(), img_bytes.size());
+    if (!embed) {
+        LOG_ERR("%s: could not load image from base64 string.\n", __func__);
+        return NULL;
+    }
+
+    return embed;
+}
+
+static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") {
+    size_t begin, end;
+    find_image_tag_in_prompt(prompt, begin, end);
+    if (begin == std::string::npos || end == std::string::npos) {
+        return prompt;
+    }
+    auto pre = prompt.substr(0, begin);
+    auto post = prompt.substr(end + strlen(IMG_BASE64_TAG_END));
+    return pre + replacement + post;
+}
+
+struct llava_context {
+    struct clip_ctx * ctx_clip = NULL;
+    struct llama_context * ctx_llama = NULL;
+    struct llama_model * model = NULL;
+};
+
+static void print_usage(int, char ** argv) {
+    LOG("\n example usage:\n");
+    LOG("\n     %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> --image <path/to/another/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
+    LOG("\n note: a lower temperature value like 0.1 is recommended for better quality.\n");
+}
+
+static struct llava_image_embed * load_image(llava_context * ctx_llava, common_params * params, const std::string & fname) {
+
+    // load and preprocess the image
+    llava_image_embed * embed = NULL;
+    auto prompt = params->prompt;
+    if (prompt_contains_image(prompt)) {
+        if (!params->image.empty()) {
+            LOG_INF("using base64 encoded image instead of command line image path\n");
+        }
+        embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->cpuparams.n_threads, prompt);
+        if (!embed) {
+            LOG_ERR("%s: can't load image from prompt\n", __func__);
+            return NULL;
+        }
+        params->prompt = remove_image_from_prompt(prompt);
+    } else {
+        embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->cpuparams.n_threads, fname.c_str());
+        if (!embed) {
+            fprintf(stderr, "%s: is %s really an image file?\n", __func__, fname.c_str());
+            return NULL;
+        }
+    }
+
+    return embed;
+}
+
+static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, common_params * params, const std::string & prompt) {
+    int n_past = 0;
+    int cur_pos_id = 0;
+
+    const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
+
+    std::string system_prompt, user_prompt;
+    size_t image_pos = prompt.find("<|vision_start|>");
+    if (image_pos != std::string::npos) {
+        // new templating mode: Provide the full prompt including system message and use <image> as a placeholder for the image
+        system_prompt = prompt.substr(0, image_pos);
+        user_prompt = prompt.substr(image_pos + std::string("<|vision_pad|>").length());
+        LOG_INF("system_prompt: %s\n", system_prompt.c_str());
+        if (params->verbose_prompt) {
+            auto tmp = common_tokenize(ctx_llava->ctx_llama, system_prompt, true, true);
+            for (int i = 0; i < (int) tmp.size(); i++) {
+                LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
+            }
+        }
+        LOG_INF("user_prompt: %s\n", user_prompt.c_str());
+        if (params->verbose_prompt) {
+            auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
+            for (int i = 0; i < (int) tmp.size(); i++) {
+                LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
+            }
+        }
+    } else {
+        // llava-1.5 native mode
+        system_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|>";
+        user_prompt = "<|vision_end|>" + prompt + "<|im_end|>\n<|im_start|>assistant\n";
+        if (params->verbose_prompt) {
+            auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
+            for (int i = 0; i < (int) tmp.size(); i++) {
+                LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
+            }
+        }
+    }
+
+    eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, &cur_pos_id, true);
+    if (image_embed != nullptr) {
+        auto image_size = clip_get_load_image_size(ctx_llava->ctx_clip);
+        qwen2vl_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past, &cur_pos_id, image_size);
+    }
+    eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, &cur_pos_id, false);
+
+    // generate the response
+
+    LOG("\n");
+
+    struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
+    if (!smpl) {
+        LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
+        exit(1);
+    }
+
+    std::string response = "";
+    for (int i = 0; i < max_tgt_len; i++) {
+        const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past, &cur_pos_id);
+        response += tmp;
+        if (strcmp(tmp, "</s>") == 0) break;
+        if (strstr(tmp, "###")) break; // Yi-VL behavior
+        LOG("%s", tmp);
+        if (strstr(response.c_str(), "<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works)
+        if (strstr(response.c_str(), "<|im_start|>")) break; // Yi-34B llava-1.6
+        if (strstr(response.c_str(), "USER:")) break; // mistral llava-1.6
+
+        fflush(stdout);
+    }
+
+    common_sampler_free(smpl);
+    LOG("\n");
+}
+
+static struct llama_model * llava_init(common_params * params) {
+    llama_backend_init();
+    llama_numa_init(params->numa);
+
+    llama_model_params model_params = common_model_params_to_llama(*params);
+
+    llama_model * model = llama_model_load_from_file(params->model.path.c_str(), model_params);
+    if (model == NULL) {
+        LOG_ERR("%s: unable to load model\n" , __func__);
+        return NULL;
+    }
+    return model;
+}
+
+static struct llava_context * llava_init_context(common_params * params, llama_model * model) {
+    const char * clip_path = params->mmproj.path.c_str();
+
+    auto prompt = params->prompt;
+    if (prompt.empty()) {
+        prompt = "describe the image in detail.";
+    }
+
+    auto ctx_clip = clip_model_load(clip_path, GGML_LOG_LEVEL_INFO);
+
+    llama_context_params ctx_params = common_context_params_to_llama(*params);
+    ctx_params.n_ctx           = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
+
+    llama_context * ctx_llama = llama_init_from_model(model, ctx_params);
+
+    if (ctx_llama == NULL) {
+        LOG_ERR("%s: failed to create the llama_context\n" , __func__);
+        return NULL;
+    }
+
+    auto * ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
+
+    ctx_llava->ctx_llama = ctx_llama;
+    ctx_llava->ctx_clip = ctx_clip;
+    ctx_llava->model = model;
+    return ctx_llava;
+}
+
+static void llava_free(struct llava_context * ctx_llava) {
+    if (ctx_llava->ctx_clip) {
+        clip_free(ctx_llava->ctx_clip);
+        ctx_llava->ctx_clip = NULL;
+    }
+
+    llama_free(ctx_llava->ctx_llama);
+    llama_model_free(ctx_llava->model);
+    llama_backend_free();
+}
+
+#ifndef NDEBUG
+
+static void debug_test_mrope_2d() {
+    // 1. Initialize backend
+    ggml_backend_t backend = NULL;
+    std::string backend_name = "";
+// #ifdef GGML_USE_CUDA
+//     fprintf(stderr, "%s: using CUDA backend\n", __func__);
+//     backend = ggml_backend_cuda_init(0); // init device 0
+//     backend_name = "cuda";
+//     if (!backend) {
+//         fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+//     }
+// #endif
+    // if there aren't GPU Backends fallback to CPU backend
+    if (!backend) {
+        backend = ggml_backend_cpu_init();
+        backend_name = "cpu";
+    }
+
+    // Calculate the size needed to allocate
+    size_t ctx_size = 0;
+    ctx_size += 2 * ggml_tensor_overhead(); // tensors
+    // no need to allocate anything else!
+
+    // 2. Allocate `ggml_context` to store tensor data
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ ctx_size,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors()
+    };
+    struct ggml_context * ctx = ggml_init(params);
+
+    struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 12, 30);
+    ggml_set_name(inp_raw, "inp_raw");
+    ggml_set_input(inp_raw);
+
+    struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 30 * 4);
+    ggml_set_name(pos, "pos");
+    ggml_set_input(pos);
+
+    std::vector<float> dummy_q;
+    dummy_q.resize(128 * 12 * 30);
+    std::fill(dummy_q.begin(), dummy_q.end(), 0.1);
+    // memcpy(inp_raw->data, dummy_q.data(), 128 * 12 * 30 * ggml_element_size(inp_raw));
+
+    std::vector<int> pos_id;
+    pos_id.resize(30 * 4);
+    for (int i = 0; i < 30; i ++) {
+        pos_id[i] = i;
+        pos_id[i + 30] = i + 10;
+        pos_id[i + 60] = i + 20;
+        pos_id[i + 90] = i + 30;
+    }
+    int sections[4] = {32, 32, 0, 0};
+
+    // 4. Allocate a `ggml_backend_buffer` to store all tensors
+    ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
+
+    // 5. Copy tensor data from main memory (RAM) to backend buffer
+    ggml_backend_tensor_set(inp_raw, dummy_q.data(), 0, ggml_nbytes(inp_raw));
+    ggml_backend_tensor_set(pos, pos_id.data(), 0, ggml_nbytes(pos));
+
+    // 6. Create a `ggml_cgraph` for mul_mat operation
+    struct ggml_cgraph * gf = NULL;
+    struct ggml_context * ctx_cgraph = NULL;
+
+    // create a temporally context to build the graph
+    struct ggml_init_params params0 = {
+        /*.mem_size   =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
+    };
+    ctx_cgraph = ggml_init(params0);
+    gf = ggml_new_graph(ctx_cgraph);
+
+    struct ggml_tensor * result0 = ggml_rope_multi(
+        ctx_cgraph, inp_raw, pos, nullptr,
+        128/2, sections, LLAMA_ROPE_TYPE_VISION, 32768, 1000000, 1,
+        0, 1, 32, 1);
+
+    // Add "result" tensor and all of its dependencies to the cgraph
+    ggml_build_forward_expand(gf, result0);
+
+    // 7. Create a `ggml_gallocr` for cgraph computation
+    ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
+    ggml_gallocr_alloc_graph(allocr, gf);
+
+    // 9. Run the computation
+    int n_threads = 1; // Optional: number of threads to perform some operations with multi-threading
+    if (ggml_backend_is_cpu(backend)) {
+        ggml_backend_cpu_set_n_threads(backend, n_threads);
+    }
+    ggml_backend_graph_compute(backend, gf);
+
+    // 10. Retrieve results (output tensors)
+    // in this example, output tensor is always the last tensor in the graph
+    struct ggml_tensor * result = result0;
+    // struct ggml_tensor * result = gf->nodes[gf->n_nodes - 1];
+    float * result_data = (float *)malloc(ggml_nbytes(result));
+    // because the tensor data is stored in device buffer, we need to copy it back to RAM
+    ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result));
+    const std::string bin_file = "mrope_2d_" + backend_name +".bin";
+    std::ofstream outFile(bin_file, std::ios::binary);
+
+    if (outFile.is_open()) {
+        outFile.write(reinterpret_cast<const char*>(result_data), ggml_nbytes(result));
+        outFile.close();
+        std::cout << "Data successfully written to " + bin_file << std::endl;
+    } else {
+        std::cerr << "Error opening file!" << std::endl;
+    }
+
+    free(result_data);
+    // 11. Free memory and exit
+    ggml_free(ctx_cgraph);
+    ggml_gallocr_free(allocr);
+    ggml_free(ctx);
+    ggml_backend_buffer_free(buffer);
+    ggml_backend_free(backend);
+}
+
+enum model_output_type {
+    conv3d,
+    patch_embed,
+    patch_win_attn_scatter,
+    first_attn_layer,
+    last_attn_layer,
+    attn_softmax,
+    final_layer,
+};
+
+static void debug_dump_img_embed(struct llava_context * ctx_llava, model_output_type output_type) {
+    constexpr int ih = 140;
+    constexpr int iw = 196;
+    // constexpr int ih = 56;
+    // constexpr int iw = 56;
+    // int n_embd  = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama));
+    int n_embd  = 1280;
+    int merge = 1;
+    if (output_type == model_output_type::final_layer) {
+        n_embd  = 2048;
+        merge = 2;
+    }
+    else if (output_type == model_output_type::attn_softmax) {
+        merge = 1;
+        n_embd = (ih/14/merge) * (iw/14/merge) * 16;
+    }
+
+    int ne = (ih/14/merge) * (iw/14/merge) * n_embd;
+    float vals[iw * ih * 3];
+    // float embd[ne];
+    std::vector<float> embd;
+    embd.resize(ne);
+
+    for (int i = 0; i < iw*ih; i++)
+    {
+        for (int c = 0; c < 3; c++)
+            vals[i * 3 + c] = (float)i / (iw*ih);
+    }
+
+    clip_encode_float_image(ctx_llava->ctx_clip, 8, vals, ih, iw, embd.data());
+
+    std::string file_postfix = "";
+    switch (output_type)
+    {
+    case model_output_type::conv3d:
+        file_postfix = "conv3d";
+        break;
+    case model_output_type::patch_embed:
+        file_postfix = "patch_embed";
+        break;
+    case model_output_type::patch_win_attn_scatter:
+        file_postfix = "scatter";
+        break;
+    case model_output_type::first_attn_layer:
+        file_postfix = "first_attn";
+        break;
+    case model_output_type::last_attn_layer:
+        file_postfix = "last_attn";
+        break;
+    case model_output_type::attn_softmax:
+        file_postfix = "attn_softmax";
+        break;
+    case model_output_type::final_layer:
+        file_postfix = "final";
+        break;
+    default:
+        break;
+    }
+    auto output_path = "img_embed_" + file_postfix + ".bin";
+
+    std::ofstream outFile(output_path, std::ios::binary);
+    if (outFile.is_open()) {
+        outFile.write(reinterpret_cast<const char*>(embd.data()), ne * sizeof(float));
+
+        outFile.close();
+        std::cout << "Data successfully written to ::[ " << output_path << std::endl;
+    } else {
+        std::cerr << "Error opening file!" << std::endl;
+    }
+}
+
+#endif
+
+
+int main(int argc, char ** argv) {
+    ggml_time_init();
+
+    common_params params;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, print_usage)) {
+        return 1;
+    }
+
+    common_init();
+
+    if (params.mmproj.path.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) {
+        print_usage(argc, argv);
+        return 1;
+    }
+
+    auto * model = llava_init(&params);
+    if (model == NULL) {
+        fprintf(stderr, "%s: error: failed to init llava model\n", __func__);
+        return 1;
+    }
+
+    if (prompt_contains_image(params.prompt)) {
+        auto * ctx_llava = llava_init_context(&params, model);
+
+        auto * image_embed = load_image(ctx_llava, &params, "");
+
+        // process the prompt
+        process_prompt(ctx_llava, image_embed, &params, params.prompt);
+
+        llama_perf_context_print(ctx_llava->ctx_llama);
+        llava_image_embed_free(image_embed);
+        ctx_llava->model = NULL;
+        llava_free(ctx_llava);
+#ifndef NDEBUG
+    } else if (params.image[0].empty()) {
+        auto ctx_llava = llava_init_context(&params, model);
+
+        // debug_test_mrope_2d();
+        debug_dump_img_embed(ctx_llava, model_output_type::final_layer);
+        // debug_dump_img_embed(ctx_llava, model_output_type::last_attn_layer);
+
+        llama_perf_context_print(ctx_llava->ctx_llama);
+        ctx_llava->model = NULL;
+        llava_free(ctx_llava);
+#endif
+    } else {
+        for (auto & image : params.image) {
+            auto * ctx_llava = llava_init_context(&params, model);
+
+            auto * image_embed = load_image(ctx_llava, &params, image);
+            if (!image_embed) {
+                LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, image.c_str());
+                return 1;
+            }
+
+            // process the prompt
+            process_prompt(ctx_llava, image_embed, &params, params.prompt);
+
+            llama_perf_context_print(ctx_llava->ctx_llama);
+            llava_image_embed_free(image_embed);
+            ctx_llava->model = NULL;
+            llava_free(ctx_llava);
+        }
+    }
+
+    llama_model_free(model);
+
+    return 0;
+}
index 4002f9d531bd257bded601c2c4bf9ba64363100e..75604315cfeba4e590c94c1ad34b9c61253c661f 100755 (executable)
@@ -54,8 +54,8 @@ add_test "llama-mtmd-cli"  "ibm-research/granite-vision-3.2-2b-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli"  "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K"  # model from openbmb is corrupted
 add_test "llama-mtmd-cli"  "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
 add_test "llama-mtmd-cli"  "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
-add_test "llama-qwen2vl-cli"  "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
-add_test "llama-qwen2vl-cli"  "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
+add_test "llama-mtmd-cli"  "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
+add_test "llama-mtmd-cli"  "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
 
 # to test the big models, run: ./tests.sh big
 add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"