]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llava : introduce libmtmd (#12849)
authorXuan-Son Nguyen <redacted>
Thu, 10 Apr 2025 20:57:16 +0000 (22:57 +0200)
committerGitHub <redacted>
Thu, 10 Apr 2025 20:57:16 +0000 (22:57 +0200)
* wip llava2

* migrated gemma3 to llava2

* add timings

* correct pre/postfix

* fix missing include

* fix compilation unused var warn

* update llava2_tokenize

* change name llava2 --> mtmd

* improve api

* refine helpers

* Update examples/llava/mtmd.cpp

Co-authored-by: Georgi Gerganov <redacted>
---------

Co-authored-by: Georgi Gerganov <redacted>
examples/llava/CMakeLists.txt
examples/llava/clip-impl.h
examples/llava/clip.cpp
examples/llava/clip.h
examples/llava/gemma3-cli.cpp
examples/llava/mtmd.cpp [new file with mode: 0644]
examples/llava/mtmd.h [new file with mode: 0644]

index f275ce1ccd0037c84d7eeb729d3dbd48d2e62f95..2d5061de460c08986f7ed99763b78c7492edf4e5 100644 (file)
@@ -1,3 +1,5 @@
+# llava (legacy)
+
 add_library(llava OBJECT
             llava.cpp
             llava.h
@@ -22,12 +24,41 @@ if (BUILD_SHARED_LIBS)
     install(TARGETS llava_shared LIBRARY)
 endif()
 
+# mtmd
+
+add_library(mtmd OBJECT
+            mtmd.cpp
+            mtmd.h
+            clip.cpp
+            clip.h
+            clip-impl.h
+            )
+
+target_link_libraries(mtmd PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
+
+target_include_directories(mtmd PUBLIC .)
+target_include_directories(mtmd PRIVATE ../..)
+target_include_directories(mtmd PRIVATE ../../common) # for stb_image.h
+
+target_compile_features(mtmd PRIVATE cxx_std_17)
+
+add_library(mtmd_static STATIC $<TARGET_OBJECTS:mtmd>)
+if (BUILD_SHARED_LIBS)
+    set_target_properties(mtmd PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    target_compile_definitions(mtmd PRIVATE LLAMA_SHARED LLAMA_BUILD)
+    add_library(mtmd_shared SHARED $<TARGET_OBJECTS:mtmd>)
+    target_link_libraries(mtmd_shared PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
+    install(TARGETS mtmd_shared LIBRARY)
+endif()
+
 if (NOT MSVC)
     target_compile_options(llava PRIVATE -Wno-cast-qual) # stb_image.h
+    target_compile_options(mtmd PRIVATE -Wno-cast-qual) # stb_image.h
 endif()
 
 if(TARGET BUILD_INFO)
     add_dependencies(llava BUILD_INFO)
+    add_dependencies(mtmd BUILD_INFO)
 endif()
 
 set(TARGET llama-llava-cli)
@@ -55,7 +86,7 @@ set(TARGET llama-gemma3-cli)
 add_executable(${TARGET} gemma3-cli.cpp)
 set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-gemma3-cli)
 install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
 
 set(TARGET llama-llava-clip-quantize-cli)
index 685d6e7e09ad1f8dc15b9630562209bfdb42e067..4c035298749245e8a04aa87b31c06dfdb22a7eb6 100644 (file)
@@ -1,12 +1,15 @@
 #include "ggml.h"
 #include "gguf.h"
 
+#include "clip.h"
+
 #include <climits>
 #include <cstdarg>
 #include <string>
 #include <map>
 #include <sstream>
 #include <vector>
+#include <memory>
 
 // Internal header for clip.cpp
 
@@ -120,6 +123,23 @@ static projector_type clip_projector_type_from_string(const std::string & str) {
     return PROJECTOR_TYPE_UNKNOWN;
 }
 
+// RGB uint8 image
+struct clip_image_u8 {
+    int nx;
+    int ny;
+
+    std::vector<uint8_t> buf;
+};
+
+// RGB float32 image (NHWC)
+// Memory layout: RGBRGBRGB...
+struct clip_image_f32 {
+    int nx;
+    int ny;
+
+    std::vector<float> buf;
+};
+
 //
 // logging
 //
@@ -178,6 +198,28 @@ static void clip_log_internal(enum ggml_log_level level, const char * format, ..
 #define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
 #define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT,  __VA_ARGS__)
 
+//
+// cpp wrappers
+//
+
+struct clip_image_u8_deleter {
+    void operator()(clip_image_u8 * val) { clip_image_u8_free(val); }
+};
+
+struct clip_image_f32_deleter {
+    void operator()(clip_image_f32 * val) { clip_image_f32_free(val); }
+};
+
+struct clip_image_f32_batch_deleter {
+    void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); }
+};
+
+typedef std::unique_ptr<clip_image_u8, clip_image_u8_deleter> clip_image_u8_ptr;
+typedef std::unique_ptr<clip_image_f32, clip_image_f32_deleter> clip_image_f32_ptr;
+typedef std::unique_ptr<clip_image_f32_batch, clip_image_f32_batch_deleter> clip_image_f32_batch_ptr;
+
+// TODO @ngxson : we're currently having a naming clash between struct clip_image_size and function clip_image_size()
+
 //
 // common utils
 //
@@ -214,6 +256,20 @@ static void string_replace_all(std::string & s, const std::string & search, cons
     s = std::move(builder);
 }
 
+// split string by a `std::string delim` instead of `char delim`
+static std::vector<std::string> string_split_str(std::string s, const std::string & delimiter) {
+    std::vector<std::string> tokens;
+    size_t pos = 0;
+    std::string token;
+    while ((pos = s.find(delimiter)) != std::string::npos) {
+        token = s.substr(0, pos);
+        tokens.push_back(token);
+        s.erase(0, pos + delimiter.length());
+    }
+    tokens.push_back(s);
+    return tokens;
+}
+
 //
 // gguf utils
 //
@@ -271,3 +327,9 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
             return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0);
     }
 }
+
+//
+// API used internally with mtmd
+//
+
+projector_type clip_get_projector_type(const struct clip_ctx * ctx);
index 4f21e836a324d6df79dcf04c563eb5bcd5a78f59..710309edaecd6a31cb9608444a13939b49ea6aae 100644 (file)
@@ -32,23 +32,6 @@ struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callbac
 
 //#define CLIP_DEBUG_FUNCTIONS
 
-// RGB uint8 image
-struct clip_image_u8 {
-    int nx;
-    int ny;
-
-    std::vector<uint8_t> buf;
-};
-
-// RGB float32 image (NHWC)
-// Memory layout: RGBRGBRGB...
-struct clip_image_f32 {
-    int nx;
-    int ny;
-
-    std::vector<float> buf;
-};
-
 #ifdef CLIP_DEBUG_FUNCTIONS
 static void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::string& filename) {
     std::ofstream file(filename, std::ios::binary);
@@ -1614,6 +1597,12 @@ struct clip_image_f32 * clip_image_f32_init() {
     return new clip_image_f32();
 }
 
+unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny) {
+    if (nx) *nx = img->nx;
+    if (ny) *ny = img->ny;
+    return img->buf.data();
+}
+
 void clip_image_size_free(struct clip_image_size * load_image_size) {
     if (load_image_size == nullptr) {
         return;
@@ -2346,6 +2335,8 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
         int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
         int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
         n_patches = x_patch * y_patch;
+    } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+        n_patches = 256;
     }
 
     return n_patches;
@@ -2893,3 +2884,11 @@ bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img,
     clip_image_encode(ctx, n_threads, &clip_img, vec);
     return true;
 }
+
+//
+// API used internally with mtmd
+//
+
+projector_type clip_get_projector_type(const struct clip_ctx * ctx) {
+    return ctx->proj_type;
+}
index 87aa61574b1eb86c1241346d3f0bea78aecc3d6b..f61e0c0b2b3a73d26b3e5d66249c95567000bdc2 100644 (file)
@@ -77,6 +77,9 @@ CLIP_API struct clip_image_size * clip_image_size_init();
 CLIP_API struct clip_image_u8  * clip_image_u8_init ();
 CLIP_API struct clip_image_f32 * clip_image_f32_init();
 
+// nx, ny are the output image dimensions
+CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
+
 CLIP_API void clip_image_size_free (struct clip_image_size * img_size);
 CLIP_API void clip_image_u8_free (struct clip_image_u8  * img);
 CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
index 4f89c0e15b4e9db871c678bcb441effce8f5b79b..91a07e2a8f40d570e0be8131a7bc1690dc76cb30 100644 (file)
@@ -2,11 +2,11 @@
 #include "log.h"
 #include "common.h"
 #include "sampling.h"
-#include "clip.h"
-#include "stb_image.h"
 #include "llama.h"
 #include "ggml.h"
 #include "console.h"
+#include "chat.h"
+#include "mtmd.h"
 
 #include <vector>
 #include <limits.h>
@@ -57,13 +57,18 @@ static void sigint_handler(int signo) {
 #endif
 
 struct gemma3_context {
-    struct clip_ctx    * ctx_clip = NULL;
-    common_init_result   llama_init;
+    mtmd_context_ptr ctx_vision;
+    common_init_result llama_init;
 
     llama_model       * model;
     llama_context     * lctx;
     const llama_vocab * vocab;
     llama_batch         batch;
+    int                 n_batch;
+
+    // note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
+    // so here we don't need to keep track of chat history
+    common_chat_templates_ptr tmpls;
 
     int n_threads    = 1;
     llama_pos n_past = 0;
@@ -74,21 +79,24 @@ struct gemma3_context {
         vocab = llama_model_get_vocab(model);
         n_threads = params.cpuparams.n_threads;
         batch = llama_batch_init(params.n_batch, 0, 1);
-        init_clip_model(params);
+        n_batch = params.n_batch;
+        tmpls = common_chat_templates_init(model, params.chat_template);
+        init_vision_context(params);
     }
 
-    void init_clip_model(common_params & params) {
+    void init_vision_context(common_params & params) {
         const char * clip_path = params.mmproj.path.c_str();
-        ctx_clip = clip_model_load(clip_path, GGML_LOG_LEVEL_INFO);
-        if (!ctx_clip) {
-            LOG_ERR("Failed to load CLIP model from %s\n", clip_path);
+        ctx_vision.reset(mtmd_init_from_file(clip_path, model, mtmd_context_params{
+            /* use_gpu */   true,
+            /* timings */   true,
+            /* n_threads */ params.cpuparams.n_threads,
+            /* verbosity */ GGML_LOG_LEVEL_INFO,
+        }));
+        if (!ctx_vision.get()) {
+            LOG_ERR("Failed to load vision model from %s\n", clip_path);
             exit(1);
         }
     }
-
-    ~gemma3_context() {
-        clip_free(ctx_clip);
-    }
 };
 
 struct decode_embd_batch {
@@ -124,77 +132,6 @@ struct decode_embd_batch {
     }
 };
 
-static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
-    llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
-    common_batch_clear(ctx.batch);
-    for (llama_token & t : tokens) {
-        common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false);
-    }
-    if (logits_last) {
-        ctx.batch.logits[ctx.batch.n_tokens - 1] = true;
-    }
-    // LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
-    if (llama_decode(ctx.lctx, ctx.batch)) {
-        LOG_ERR("Failed to decode text\n");
-        return 1;
-    }
-    return 0;
-}
-
-static int eval_image(gemma3_context & ctx, std::string & fname) {
-    std::vector<float> image_embd_v;
-    int n_embd = llama_model_n_embd(ctx.model);
-    int n_tokens = 256;
-    image_embd_v.resize(n_tokens * n_embd);
-
-    bool ok;
-    struct clip_image_u8 * img_u8 = clip_image_u8_init();
-    ok = clip_image_load_from_file(fname.c_str(), img_u8);
-    if (!ok) {
-        LOG_ERR("Unable to load image %s\n", fname.c_str());
-        clip_image_u8_free(img_u8);
-        return 2; // non-fatal error
-    }
-
-    clip_image_f32_batch batch_f32;
-    ok = clip_image_preprocess(ctx.ctx_clip, img_u8, &batch_f32);
-    if (!ok) {
-        LOG_ERR("Unable to preprocess image\n");
-        clip_image_f32_batch_free(&batch_f32);
-        clip_image_u8_free(img_u8);
-        return 1;
-    }
-
-    int64_t t0 = ggml_time_ms();
-    LOG("Encoding image %s\n", fname.c_str());
-    ok = clip_image_batch_encode(ctx.ctx_clip, ctx.n_threads, &batch_f32, image_embd_v.data());
-    if (!ok) {
-        LOG_ERR("Unable to encode image\n");
-        clip_image_f32_batch_free(&batch_f32);
-        clip_image_u8_free(img_u8);
-        return 1;
-    }
-    LOG("Image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
-
-    clip_image_f32_batch_free(&batch_f32);
-    clip_image_u8_free(img_u8);
-
-    // decode image embeddings
-    int64_t t1 = ggml_time_ms();
-    eval_text(ctx, "<start_of_image>");
-    llama_set_causal_attn(ctx.lctx, false);
-    decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0);
-    if (llama_decode(ctx.lctx, batch_img.batch)) {
-        LOG_ERR("failed to decode image\n");
-        return 1;
-    }
-    ctx.n_past += n_tokens;
-    llama_set_causal_attn(ctx.lctx, true);
-    eval_text(ctx, "<end_of_image>");
-    LOG("Image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
-    return 0;
-}
-
 static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_predict) {
     for (int i = 0; i < n_predict; i++) {
         if (i > n_predict || !g_is_generating) {
@@ -224,6 +161,45 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
     return 0;
 }
 
+static int eval_message(gemma3_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
+    std::vector<mtmd_bitmap> bitmaps;
+
+    common_chat_templates_inputs tmpl_inputs;
+    tmpl_inputs.messages = {msg};
+    tmpl_inputs.add_generation_prompt = true;
+    tmpl_inputs.use_jinja = false; // jinja is buggy here
+    auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
+    LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
+
+    for (auto & fname : images_fname) {
+        mtmd_bitmap bitmap;
+        if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
+            LOG_ERR("Unable to load image %s\n", fname.c_str());
+            return 2; // image not found
+        }
+        bitmaps.push_back(std::move(bitmap));
+    }
+
+    mtmd_input_text text;
+    text.text          = formatted_chat.prompt;
+    text.add_special   = add_bos;
+    text.parse_special = true;
+    mtmd_input_chunks_ptr chunks(mtmd_tokenize(ctx.ctx_vision.get(), text, bitmaps));
+    if (chunks == nullptr) {
+        LOG_ERR("Unable to tokenize prompt\n");
+        return 1;
+    }
+
+    if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks.get(), ctx.n_past, 0, ctx.n_batch)) {
+        LOG_ERR("Unable to eval prompt\n");
+        return 1;
+    }
+
+    ctx.n_past += mtmd_helper_get_n_tokens(chunks.get());
+
+    return 0;
+}
+
 int main(int argc, char ** argv) {
     ggml_time_init();
 
@@ -265,21 +241,15 @@ int main(int argc, char ** argv) {
 #endif
     }
 
-    if (eval_text(ctx, "<bos>")) {
-        return 1;
-    }
-
     if (is_single_turn) {
         g_is_generating = true;
-        if (eval_text(ctx, "<start_of_turn>user\n")) {
-            return 1;
-        }
-        for (auto & fname : params.image) {
-            if (eval_image(ctx, fname)) {
-                return 1;
-            }
+        if (params.prompt.find("<__image__>") == std::string::npos) {
+            params.prompt += " <__image__>";
         }
-        if (eval_text(ctx, params.prompt + "<end_of_turn><start_of_turn>model\n", true)) {
+        common_chat_msg msg;
+        msg.role = "user";
+        msg.content = params.prompt;
+        if (eval_message(ctx, msg, params.image, true)) {
             return 1;
         }
         if (generate_response(ctx, smpl, n_predict)) {
@@ -293,9 +263,9 @@ int main(int argc, char ** argv) {
         LOG("\n   /quit or /exit   exit the program");
         LOG("\n");
 
-        if (eval_text(ctx, "<start_of_turn>user\n")) {
-            return 1;
-        }
+        bool is_first_msg = true;
+        std::vector<std::string> images_fname;
+        std::string content;
 
         while (true) {
             g_is_generating = false;
@@ -320,24 +290,31 @@ int main(int argc, char ** argv) {
             g_is_generating = true;
             if (line.find("/image") == 0) {
                 std::string image = line.substr(7);
-                int res = eval_image(ctx, image);
-                if (res == 2) {
-                    continue; // image not found
-                }
-                if (res) {
-                    return 1;
-                }
+                images_fname.push_back(string_strip(image));
+                content += "<__image__>";
                 continue;
+            } else {
+                content += line;
             }
-            if (eval_text(ctx, line + "<end_of_turn><start_of_turn>model\n", true)) {
-                return 1;
+            common_chat_msg msg;
+            msg.role = "user";
+            msg.content = content;
+            int ret = eval_message(ctx, msg, images_fname, is_first_msg);
+            if (ret == 2) {
+                // non-fatal error
+                images_fname.clear();
+                content.clear();
+                continue;
             }
-            if (generate_response(ctx, smpl, n_predict)) {
+            if (ret) {
                 return 1;
             }
-            if (eval_text(ctx, "<end_of_turn><start_of_turn>user\n")) {
+            if (generate_response(ctx, smpl, n_predict)) {
                 return 1;
             }
+            images_fname.clear();
+            content.clear();
+            is_first_msg = false;
         }
     }
 
diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp
new file mode 100644 (file)
index 0000000..58503d0
--- /dev/null
@@ -0,0 +1,341 @@
+#include "clip.h"
+#include "clip-impl.h"
+#include "mtmd.h"
+
+#include "llama.h"
+
+#include <algorithm>
+#include <cerrno>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <limits>
+#include <vector>
+
+struct mtmd_context {
+    struct clip_ctx * ctx_clip;
+    const struct llama_model * text_model;
+    std::vector<float> image_embd_v; // image embedding vector
+    bool print_timings;
+    int n_threads;
+    std::string image_marker;
+
+    // TODO @ngxson : add timings
+
+    mtmd_context(const char * mmproj_fname,
+                   const llama_model * text_model,
+                   const mtmd_context_params & ctx_params) : print_timings(ctx_params.print_timings), n_threads(ctx_params.n_threads), image_marker(ctx_params.image_marker) {
+        clip_context_params ctx_clip_params;
+        ctx_clip_params.use_gpu   = ctx_params.use_gpu;
+        ctx_clip_params.verbosity = ctx_params.verbosity;
+        ctx_clip = clip_init(mmproj_fname, ctx_clip_params);
+        if (!ctx_clip) {
+            throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
+        }
+        this->text_model = text_model;
+    }
+
+    ~mtmd_context() {
+        clip_free(ctx_clip);
+    }
+};
+
+struct mtmd_image_tokens_data {
+    clip_image_f32_batch_ptr batch_f32; // preprocessed image patches
+};
+
+struct mtmd_image_tokens {
+    uint32_t nx; // number of tokens in x direction
+    uint32_t ny; // number of tokens in y direction
+    uint32_t n_tokens() const { return nx * ny; }
+    clip_image_f32_batch_ptr batch_f32; // preprocessed image patches
+};
+
+mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
+        const struct llama_model * text_model,
+        const struct mtmd_context_params ctx_params) {
+    try {
+        return new mtmd_context(mmproj_fname, text_model, ctx_params);
+    } catch (const std::exception & e) {
+        LOG_ERR("%s: error: %s\n", __func__, e.what());
+        return nullptr;
+    }
+}
+
+void mtmd_free(mtmd_context * ctx) {
+    if (ctx) {
+        delete ctx;
+    }
+}
+
+// copied from common_tokenize
+static std::vector<llama_token> mtmd_tokenize_text_internal(
+    const struct llama_vocab * vocab,
+           const std::string & text,
+                        bool   add_special,
+                        bool   parse_special) {
+    // upper limit for the number of tokens
+    int n_tokens = text.length() + 2 * add_special;
+    std::vector<llama_token> result(n_tokens);
+    n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+    if (n_tokens < 0) {
+        result.resize(-n_tokens);
+        int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+        GGML_ASSERT(check == -n_tokens);
+    } else {
+        result.resize(n_tokens);
+    }
+    return result;
+}
+
+mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
+                                const mtmd_input_text & text,
+                                const std::vector<mtmd_bitmap> & bitmaps) {
+    mtmd_input_chunks * output = new mtmd_input_chunks;
+    auto vocab = llama_model_get_vocab(ctx->text_model);
+
+    std::string prompt_modified(text.text);
+    std::string marker_modified(ctx->image_marker);
+    projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
+    // a bit hacky here, but works for now
+    // for some models, we need to add prefix and suffix to the image embeddings
+    if (proj_type == PROJECTOR_TYPE_GEMMA3) {
+        // <start_of_image> ... (image embeddings) ... <end_of_image>
+        marker_modified = "<start_of_image>" + ctx->image_marker + "<end_of_image>";
+        string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
+    }
+
+    std::vector<std::string> parts = string_split_str(text.text, ctx->image_marker);
+    output->clear();
+    output->reserve(parts.size());
+
+    size_t i_img = 0;
+
+    for (const auto & part : parts) {
+        //printf("tokenizing part: %s\n", part.c_str());
+        bool add_bos = &parts.front() == &part;
+        auto tokens = mtmd_tokenize_text_internal(vocab, part, text.add_special && add_bos, text.parse_special);
+        if (tokens.empty()) {
+            continue;
+        }
+        mtmd_input_chunk chunk{
+            MTMD_INPUT_CHUNK_TYPE_TEXT,
+            std::move(tokens),
+            {},
+        };
+        output->emplace_back(std::move(chunk));
+
+        if (&parts.back() != &part) {
+            // add image token to middle of 2 parts
+
+            if (i_img >= bitmaps.size()) {
+                LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
+                return nullptr;
+            }
+
+            // shim layer
+            clip_image_u8_ptr img_u8(clip_image_u8_init());
+            img_u8->nx = bitmaps[i_img].nx;
+            img_u8->ny = bitmaps[i_img].ny;
+            img_u8->buf.resize(bitmaps[i_img].data.size());
+            std::memcpy(img_u8->buf.data(), bitmaps[i_img].data.data(), img_u8->nx * img_u8->ny * 3);
+
+            // preprocess image
+            clip_image_f32_batch_ptr batch_f32(new clip_image_f32_batch);
+            bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), batch_f32.get());
+            if (!ok) {
+                LOG_ERR("Unable to preprocess image\n");
+                return nullptr;
+            }
+
+            mtmd_image_tokens * image_tokens = new mtmd_image_tokens;
+            image_tokens->nx = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image
+            image_tokens->ny = 1; // TODO
+            image_tokens->batch_f32 = std::move(batch_f32);
+
+            mtmd_input_chunk chunk{
+                MTMD_INPUT_CHUNK_TYPE_IMAGE,
+                {},
+                image_tokens,
+            };
+            output->emplace_back(std::move(chunk));
+            i_img++;
+        }
+    }
+
+    return output;
+}
+
+void mtmd_input_chunks_free(mtmd_input_chunks * chunks) {
+    for (auto & chunk : *chunks) {
+        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) {
+            delete chunk.tokens_image;
+        }
+    }
+    delete chunks;
+}
+
+int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
+    int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
+    ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
+    bool ok = clip_image_batch_encode(
+        ctx->ctx_clip,
+        ctx->n_threads,
+        image_tokens->batch_f32.get(),
+        ctx->image_embd_v.data());
+    return ok ? 0 : 1;
+}
+
+float * mtmd_get_output_embd(mtmd_context * ctx) {
+    return ctx->image_embd_v.data();
+}
+
+size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks) {
+    size_t n_tokens = 0;
+    for (auto & chunk : *chunks) {
+        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+            n_tokens += chunk.tokens_text.size();
+        } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+            n_tokens += chunk.tokens_image->n_tokens();
+        } else {
+            GGML_ASSERT(false && "chunk type not supported");
+        }
+    }
+    return n_tokens;
+}
+
+// helper struct to make working with embd batch easier
+// note: this will be removed after llama_batch_ext refactoring
+struct decode_embd_batch {
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id>   seq_id_0;
+    std::vector<llama_seq_id *> seq_ids;
+    std::vector<int8_t>         logits;
+    llama_batch batch;
+    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
+        pos     .resize(n_tokens);
+        n_seq_id.resize(n_tokens);
+        seq_ids .resize(n_tokens + 1);
+        logits  .resize(n_tokens);
+        seq_id_0.resize(1);
+        seq_id_0[0] = seq_id;
+        seq_ids [n_tokens] = nullptr;
+        batch = {
+            /*n_tokens       =*/ n_tokens,
+            /*tokens         =*/ nullptr,
+            /*embd           =*/ embd,
+            /*pos            =*/ pos.data(),
+            /*n_seq_id       =*/ n_seq_id.data(),
+            /*seq_id         =*/ seq_ids.data(),
+            /*logits         =*/ logits.data(),
+        };
+        for (int i = 0; i < n_tokens; i++) {
+            batch.pos     [i] = pos_0 + i;
+            batch.n_seq_id[i] = 1;
+            batch.seq_id  [i] = seq_id_0.data();
+            batch.logits  [i] = false;
+        }
+    }
+};
+
+int32_t mtmd_helper_eval(mtmd_context * ctx,
+        llama_context * lctx,
+        mtmd_input_chunks * chunks,
+        llama_pos pos0,
+        llama_seq_id seq_id,
+        int32_t n_batch) {
+    int32_t ret;
+    llama_pos n_past = pos0;
+    llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
+
+    for (auto & chunk : *chunks) {
+        bool is_last = &chunk == &chunks->back();
+        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+            // TODO @ngxson : may need to split into smaller batches
+            text_batch.n_tokens = chunk.tokens_text.size();
+            for (size_t i = 0; i < chunk.tokens_text.size(); i++) {
+                text_batch.token   [i]    = chunk.tokens_text[i];
+                text_batch.pos     [i]    = n_past++;
+                text_batch.n_seq_id[i]    = 1;
+                text_batch.seq_id  [i][0] = seq_id;
+                text_batch.logits  [i]    = false;
+            }
+            if (is_last) {
+                // always get logits for last input chunk
+                text_batch.logits[text_batch.n_tokens - 1] = true;
+            }
+            ret = llama_decode(lctx, text_batch);
+            if (ret != 0) {
+                LOG_ERR("failed to decode text\n");
+                llama_batch_free(text_batch);
+                return ret;
+            }
+
+        } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+            GGML_ASSERT(!is_last && "logits for last image chunk is not yet support");
+            GGML_ASSERT(chunk.tokens_image != nullptr);
+            int64_t t0 = ggml_time_ms();
+            if (ctx->print_timings) {
+                LOG_INF("encoding image...\n");
+            }
+            ret = mtmd_encode(ctx, chunk.tokens_image);
+            if (ret != 0) {
+                LOG_ERR("failed to encode image\n");
+                llama_batch_free(text_batch);
+                return ret;
+            }
+            if (ctx->print_timings) {
+                LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
+            }
+
+            int32_t n_tokens = chunk.tokens_image->n_tokens();
+            float * embd = mtmd_get_output_embd(ctx);
+            decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
+            int64_t t1 = ggml_time_ms();
+            ret = llama_decode(lctx, batch_img.batch);
+            if (ret != 0) {
+                LOG_ERR("failed to decode image\n");
+                llama_batch_free(text_batch);
+                return ret;
+            }
+            if (ctx->print_timings) {
+                LOG_INF("image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
+            }
+
+            n_past += n_tokens;
+
+        } else {
+            GGML_ASSERT(false && "chunk type not supported");
+        }
+    }
+
+    llama_batch_free(text_batch);
+    return 0;
+}
+
+int32_t mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len, mtmd_bitmap & output) {
+    clip_image_u8_ptr img_u8(clip_image_u8_init());
+    bool ok = clip_image_load_from_bytes(buf, len, img_u8.get());
+    if (!ok) {
+        LOG_ERR("Unable to load image from buffer\n");
+        return 1;
+    }
+    unsigned char * data = clip_image_u8_get_data(img_u8.get(), &output.nx, &output.ny);
+    output.data.resize(output.nx * output.ny * 3);
+    std::memcpy(output.data.data(), data, output.nx * output.ny * 3);
+    return 0;
+}
+
+int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & output) {
+    clip_image_u8_ptr img_u8(clip_image_u8_init());
+    bool ok = clip_image_load_from_file(fname, img_u8.get());
+    if (!ok) {
+        LOG_ERR("Unable to load image %s\n", fname);
+        return 1;
+    }
+    unsigned char * data = clip_image_u8_get_data(img_u8.get(), &output.nx, &output.ny);
+    output.data.resize(output.nx * output.ny * 3);
+    std::memcpy(output.data.data(), data, output.nx * output.ny * 3);
+    return 0;
+}
diff --git a/examples/llava/mtmd.h b/examples/llava/mtmd.h
new file mode 100644 (file)
index 0000000..598f694
--- /dev/null
@@ -0,0 +1,146 @@
+#ifndef MTMD_H
+#define MTMD_H
+
+#include "ggml.h"
+#include "llama.h"
+#include "clip.h"
+
+#include <vector>
+#include <cinttypes>
+#include <memory>
+
+#ifdef LLAMA_SHARED
+#    if defined(_WIN32) && !defined(__MINGW32__)
+#        ifdef LLAMA_BUILD
+#            define MTMD_API __declspec(dllexport)
+#        else
+#            define MTMD_API __declspec(dllimport)
+#        endif
+#    else
+#        define MTMD_API __attribute__ ((visibility ("default")))
+#    endif
+#else
+#    define MTMD_API
+#endif
+
+#ifdef __cplusplus
+
+enum mtmd_input_chunk_type {
+    MTMD_INPUT_CHUNK_TYPE_TEXT,
+    MTMD_INPUT_CHUNK_TYPE_IMAGE,
+};
+
+struct mtmd_context;
+struct mtmd_image_tokens;
+
+// represents raw image data, layout is RGBRGBRGB...
+// length of data must be nx * ny * 3
+struct mtmd_bitmap {
+    uint32_t nx;
+    uint32_t ny;
+    std::vector<unsigned char> data;
+};
+
+struct mtmd_input_chunk {
+    mtmd_input_chunk_type type;
+    std::vector<llama_token> tokens_text;
+    mtmd_image_tokens * tokens_image = nullptr;
+};
+
+using mtmd_input_chunks = std::vector<mtmd_input_chunk>;
+
+struct mtmd_context_params {
+    bool use_gpu = true;
+    bool print_timings = true;
+    int n_threads = 4;
+    enum ggml_log_level verbosity = GGML_LOG_LEVEL_INFO;
+    const char * image_marker = "<__image__>";
+};
+
+struct mtmd_input_text {
+    std::string text;
+    bool add_special;
+    bool parse_special;
+};
+
+// initialize the mtmd context
+// return nullptr on failure
+MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
+                                                const llama_model * text_model,
+                                                const mtmd_context_params ctx_params);
+
+MTMD_API void mtmd_free(mtmd_context * ctx);
+
+// tokenize an input text prompt and an image
+// the prompt must have the input image marker (default: "<__image__>") in it
+// the marker will be replaced with the image tokens
+// for example:
+//   "here is an image: <__image__>\ndescribe it in detail."
+//   this will gives 3 chunks:
+//   1. "here is an image: <start_of_image>"
+//   2. (image tokens)
+//   3. "<end_of_image>\ndescribe it in detail."
+// number of bitmaps must be equal to the number of image markers in the prompt
+// this function is thread-safe (shared ctx)
+MTMD_API mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
+                                const mtmd_input_text & text,
+                                const std::vector<mtmd_bitmap> & bitmaps);
+
+// free image chunk data
+MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks);
+
+// returns 0 on success
+MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
+                            const mtmd_image_tokens * image_tokens);
+
+// get output embeddings from the last encode pass
+MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
+
+//
+// helper functions (can be implemented based on other functions)
+//
+
+// helper to count the total number of tokens from a list of chunks, useful to keep track of n_past
+MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks);
+
+// helper function that automatically:
+// 1. run llama_decode() on text chunks
+// 2. run mtmd_encode() on image chunks, then mtmd_get_output_embd() and then llama_decode()
+// if any of the mtmd_encode() or llama_decode() calls return non-zero, stop and forward the error
+// otherwise, returns 0 on success
+MTMD_API int32_t mtmd_helper_eval(mtmd_context * ctx,
+                                llama_context * lctx,
+                                mtmd_input_chunks * chunks,
+                                llama_pos pos0,
+                                llama_seq_id seq_id,
+                                int32_t n_batch);
+
+// helper function to construct a mtmd_bitmap from a file
+// returns 0 on success
+// this function is thread-safe
+MTMD_API int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & output);
+
+// helper function to construct a mtmd_bitmap from a buffer
+// the buffer must be an image in format supported by stb_image (jpg, png, bmp, gif, etc.)
+// returns 0 on success
+// this function is thread-safe
+MTMD_API int32_t mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len, mtmd_bitmap & output);
+
+// convenient unique_ptr wrappers
+struct mtmd_context_deleter {
+    void operator()(mtmd_context * val) { mtmd_free(val); }
+};
+using mtmd_context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>;
+
+struct mtmd_input_chunks_deleter {
+    void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); }
+};
+using mtmd_input_chunks_ptr = std::unique_ptr<mtmd_input_chunks, mtmd_input_chunks_deleter>;
+
+#else
+
+static_assert(false && "C header is not yet supported by this library");
+
+#endif
+
+#endif