]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
mtmd : add C public API (#13184)
authorXuan-Son Nguyen <redacted>
Sun, 4 May 2025 21:43:42 +0000 (23:43 +0200)
committerGitHub <redacted>
Sun, 4 May 2025 21:43:42 +0000 (23:43 +0200)
* init

* wip

* working version

* add mtmd::bitmaps

* add test target

* rm redundant define

* test: mtmd_input_chunks_free

* rm outdated comment

* fix merging issue

* explicitly create mtmd::input_chunks

* mtmd_input_chunk_copy

* add clone()

* add const to various places

* add warning about breaking changes

* helper: use mtmd_image_tokens_get_n_pos

tests/CMakeLists.txt
tests/test-mtmd-c-api.c [new file with mode: 0644]
tools/llava/clip-impl.h
tools/llava/clip.h
tools/llava/mtmd-cli.cpp
tools/llava/mtmd.cpp
tools/llava/mtmd.h

index 72b32df9fca404ba9ddb3aa8c24013d1b11bbd1d..709d5ad96afba7f7fc08bb980d671c7cb1ef4d2e 100644 (file)
@@ -165,6 +165,10 @@ if (NOT GGML_BACKEND_DL)
     llama_build_and_test(test-rope.cpp)
 endif()
 
+# libmtmd
+set(LLAMA_TEST_NAME test-mtmd-c-api)
+llama_build_and_test(test-mtmd-c-api.c)
+target_link_libraries(${LLAMA_TEST_NAME} PRIVATE mtmd)
 
 # dummy executable - not installed
 get_filename_component(TEST_TARGET test-c.c NAME_WE)
diff --git a/tests/test-mtmd-c-api.c b/tests/test-mtmd-c-api.c
new file mode 100644 (file)
index 0000000..02e762e
--- /dev/null
@@ -0,0 +1,63 @@
+#include <stdio.h>
+#include <assert.h>
+
+#include "mtmd.h"
+
+int main(void) {
+    printf("\n\nTesting libmtmd C API...\n");
+    printf("--------\n\n");
+
+    struct mtmd_context_params params = mtmd_context_params_default();
+    printf("Default image marker: %s\n", params.image_marker);
+
+    mtmd_input_chunks * chunks = mtmd_test_create_input_chunks();
+
+    if (!chunks) {
+        fprintf(stderr, "Failed to create input chunks\n");
+        return 1;
+    }
+
+    size_t n_chunks = mtmd_input_chunks_size(chunks);
+    printf("Number of chunks: %zu\n", n_chunks);
+    assert(n_chunks > 0);
+
+    for (size_t i = 0; i < n_chunks; i++) {
+        const mtmd_input_chunk * chunk = mtmd_input_chunks_get(chunks, i);
+        assert(chunk != NULL);
+        enum mtmd_input_chunk_type type = mtmd_input_chunk_get_type(chunk);
+        printf("Chunk %zu type: %d\n", i, type);
+
+        if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+            size_t n_tokens;
+            const llama_token * tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
+            printf("    Text chunk with %zu tokens\n", n_tokens);
+            assert(tokens != NULL);
+            assert(n_tokens > 0);
+            for (size_t j = 0; j < n_tokens; j++) {
+                assert(tokens[j] >= 0);
+                printf("    > Token %zu: %d\n", j, tokens[j]);
+            }
+
+        } else if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+            const mtmd_image_tokens * image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
+            size_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
+            size_t nx = mtmd_image_tokens_get_nx(image_tokens);
+            size_t ny = mtmd_image_tokens_get_ny(image_tokens);
+            const char * id = mtmd_image_tokens_get_id(image_tokens);
+            assert(n_tokens > 0);
+            assert(nx > 0);
+            assert(ny > 0);
+            assert(id != NULL);
+            printf("    Image chunk with %zu tokens\n", n_tokens);
+            printf("    Image size: %zu x %zu\n", nx, ny);
+            printf("    Image ID: %s\n", id);
+        }
+    }
+
+    // Free the chunks
+    mtmd_input_chunks_free(chunks);
+
+    printf("\n\nDONE: test libmtmd C API...\n");
+
+    return 0;
+}
index b78d930bce34cde8787043f03be8b326f22b3fea..fb780e9deac7ee679bf4069d0c8b6d821a45f464 100644 (file)
@@ -233,6 +233,15 @@ struct clip_image_u8_batch {
 
 struct clip_image_f32_batch {
     std::vector<clip_image_f32_ptr> entries;
+
+    clip_image_f32_batch clone() const {
+        clip_image_f32_batch new_batch;
+        new_batch.entries.reserve(entries.size());
+        for (const auto & entry : entries) {
+            new_batch.entries.emplace_back(new clip_image_f32(*entry));
+        }
+        return new_batch;
+    }
 };
 
 //
index 0a53bd8eb78e1b5755afd80e1635ad810464d11c..0b0eb02956a3237d3e043b90014cddb8cdd24077 100644 (file)
@@ -78,10 +78,10 @@ CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
 CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
 CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip);
 
-CLIP_API struct clip_image_size      * clip_image_size_init();
-CLIP_API struct clip_image_u8        * clip_image_u8_init ();
-CLIP_API struct clip_image_f32       * clip_image_f32_init();
-CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(); // only used by libllava
+CLIP_API struct clip_image_size      * clip_image_size_init(void);
+CLIP_API struct clip_image_u8        * clip_image_u8_init (void);
+CLIP_API struct clip_image_f32       * clip_image_f32_init(void);
+CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(void); // only used by libllava
 
 // nx, ny are the output image dimensions
 CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
index 474e7c4f8357e9753f22c8b7157284cf5fa48ae0..dd18e0fe6ed0dfceac705e321e53ef3f61f5e935 100644 (file)
@@ -63,7 +63,7 @@ static void sigint_handler(int signo) {
 #endif
 
 struct mtmd_cli_context {
-    mtmd_context_ptr ctx_vision;
+    mtmd::context_ptr ctx_vision;
     common_init_result llama_init;
 
     llama_model       * model;
@@ -72,7 +72,7 @@ struct mtmd_cli_context {
     llama_batch         batch;
     int                 n_batch;
 
-    std::vector<mtmd_bitmap> bitmaps;
+    mtmd::bitmaps bitmaps;
 
     // note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
     // so here we don't need to keep track of chat history
@@ -115,12 +115,12 @@ struct mtmd_cli_context {
 
     void init_vision_context(common_params & params) {
         const char * clip_path = params.mmproj.path.c_str();
-        ctx_vision.reset(mtmd_init_from_file(clip_path, model, mtmd_context_params{
-            /* use_gpu */   params.mmproj_use_gpu,
-            /* timings */   true,
-            /* n_threads */ params.cpuparams.n_threads,
-            /* verbosity */ params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO,
-        }));
+        mtmd_context_params mparams = mtmd_context_params_default();
+        mparams.use_gpu = params.mmproj_use_gpu;
+        mparams.print_timings = true;
+        mparams.n_threads = params.cpuparams.n_threads;
+        mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+        ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
         if (!ctx_vision.get()) {
             LOG_ERR("Failed to load vision model from %s\n", clip_path);
             exit(1);
@@ -139,11 +139,11 @@ struct mtmd_cli_context {
     }
 
     bool load_image(const std::string & fname) {
-        mtmd_bitmap bitmap;
-        if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
+        mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_file(fname.c_str()));
+        if (!bmp.ptr) {
             return false;
         }
-        bitmaps.push_back(std::move(bitmap));
+        bitmaps.entries.push_back(std::move(bmp));
         return true;
     }
 };
@@ -193,27 +193,40 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_
     LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
 
     mtmd_input_text text;
-    text.text          = formatted_chat.prompt;
+    text.text          = formatted_chat.prompt.c_str();
     text.add_special   = add_bos;
     text.parse_special = true;
-    mtmd_input_chunks chunks;
 
     if (g_is_interrupted) return 0;
 
-    int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, ctx.bitmaps);
+    mtmd::input_chunks chunks(mtmd_input_chunks_init());
+    auto bitmaps_c_ptr = ctx.bitmaps.c_ptr();
+    int32_t res = mtmd_tokenize(ctx.ctx_vision.get(),
+                        chunks.ptr.get(), // output
+                        &text, // text
+                        bitmaps_c_ptr.data(),
+                        bitmaps_c_ptr.size());
     if (res != 0) {
         LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
         return 1;
     }
 
-    ctx.bitmaps.clear();
-
-    if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
+    ctx.bitmaps.entries.clear();
+
+    llama_pos new_n_past;
+    if (mtmd_helper_eval_chunks(ctx.ctx_vision.get(),
+                ctx.lctx, // lctx
+                chunks.ptr.get(), // chunks
+                ctx.n_past, // n_past
+                0, // seq_id
+                ctx.n_batch, // n_batch
+                true, // logits_last
+                &new_n_past)) {
         LOG_ERR("Unable to eval prompt\n");
         return 1;
     }
 
-    ctx.n_past += mtmd_helper_get_n_pos(chunks);
+    ctx.n_past = new_n_past;
 
     LOG("\n");
 
@@ -246,7 +259,7 @@ int main(int argc, char ** argv) {
     struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling);
     int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict;
 
-    // ctrl+C handling
+    // Ctrl+C handling
     {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;
index 73abf2ad18e556226a24c0c283ab73f8b32d0512..b600e4341375fa30368779ee55081cba7b0ad985 100644 (file)
 #include <limits>
 #include <vector>
 
+// represents raw image data, layout is RGBRGBRGB...
+// length of data must be nx * ny * 3
+struct mtmd_bitmap {
+    uint32_t nx;
+    uint32_t ny;
+    std::vector<unsigned char> data;
+    std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking
+};
+
+struct mtmd_image_tokens_deleter {
+    void operator()(mtmd_image_tokens * val); // forward declaration
+};
+using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>;
+
+struct mtmd_input_chunk {
+    mtmd_input_chunk_type type;
+    std::vector<llama_token> tokens_text;
+    mtmd_image_tokens_ptr tokens_image;
+};
+
+struct mtmd_input_chunks {
+    std::vector<mtmd_input_chunk> entries;
+};
+
 // slice template, used by some llava-uhd models to correctly place the special tokens around image embeddings
 // models not having it (llava-1.6) will process embeddings without any special tokens in-between
 enum mtmd_slice_tmpl {
@@ -21,6 +45,16 @@ enum mtmd_slice_tmpl {
     // TODO @ngxson : add support for idefics (SmolVLM)
 };
 
+mtmd_context_params mtmd_context_params_default() {
+    mtmd_context_params params;
+    params.use_gpu = true;
+    params.print_timings = true;
+    params.n_threads = 4;
+    params.verbosity = GGML_LOG_LEVEL_INFO;
+    params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
+    return params;
+}
+
 struct mtmd_context {
     struct clip_ctx * ctx_clip;
     const struct llama_model * text_model;
@@ -132,6 +166,16 @@ struct mtmd_image_tokens {
     uint32_t n_tokens() const { return nx * ny; }
     clip_image_f32_batch batch_f32; // preprocessed image patches
     std::string id; // optional user-defined ID, useful for KV cache tracking
+
+    mtmd_image_tokens clone() {
+        return mtmd_image_tokens{
+            nx,
+            ny,
+            use_mrope_pos,
+            batch_f32.clone(),
+            id
+        };
+    }
 };
 
 mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
@@ -172,12 +216,13 @@ static std::vector<llama_token> mtmd_tokenize_text_internal(
 }
 
 int32_t mtmd_tokenize(mtmd_context * ctx,
-                        std::vector<mtmd_input_chunk> & output,
-                        const mtmd_input_text & text,
-                        const std::vector<mtmd_bitmap> & bitmaps) {
+            mtmd_input_chunks * output,
+            const mtmd_input_text * text,
+            const mtmd_bitmap ** bitmaps,
+            size_t n_bitmaps) {
     auto vocab = llama_model_get_vocab(ctx->text_model);
 
-    std::string prompt_modified(text.text);
+    std::string prompt_modified(text->text);
     std::string marker_modified(ctx->image_marker);
     projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
 
@@ -211,8 +256,8 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
     // for glm-edge, BOI and EOI token's embeddings are not present in the text model
 
     std::vector<std::string> parts = string_split_str(prompt_modified, ctx->image_marker);
-    output.clear();
-    output.reserve(parts.size());
+    output->entries.clear();
+    output->entries.reserve(parts.size());
 
     size_t i_img = 0;
 
@@ -223,7 +268,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
             std::move(tokens),
             {},
         };
-        output.emplace_back(std::move(chunk));
+        output->entries.emplace_back(std::move(chunk));
     };
 
     // utility for splitting batch of multiple images into chunks of batch having single images
@@ -251,7 +296,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
     for (const auto & part : parts) {
         // printf("tokenizing part: %s\n", part.c_str());
         bool add_bos = &parts.front() == &part;
-        auto tokens = mtmd_tokenize_text_internal(vocab, part, text.add_special && add_bos, text.parse_special);
+        auto tokens = mtmd_tokenize_text_internal(vocab, part, text->add_special && add_bos, text->parse_special);
         if (tokens.empty()) {
             continue;
         }
@@ -260,22 +305,22 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
             std::move(tokens),
             {},
         };
-        output.emplace_back(std::move(chunk));
+        output->entries.emplace_back(std::move(chunk));
 
         if (&parts.back() != &part) {
             // add image token to middle of 2 parts
 
-            if (i_img >= bitmaps.size()) {
+            if (i_img >= n_bitmaps) {
                 LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
                 return 1;
             }
 
             // convert mtmd_bitmap to clip_image_u8
             clip_image_u8_ptr img_u8(clip_image_u8_init());
-            img_u8->nx = bitmaps[i_img].nx;
-            img_u8->ny = bitmaps[i_img].ny;
-            img_u8->buf.resize(bitmaps[i_img].data.size());
-            std::memcpy(img_u8->buf.data(), bitmaps[i_img].data.data(), img_u8->nx * img_u8->ny * 3);
+            img_u8->nx = bitmaps[i_img]->nx;
+            img_u8->ny = bitmaps[i_img]->ny;
+            img_u8->buf.resize(bitmaps[i_img]->data.size());
+            std::memcpy(img_u8->buf.data(), bitmaps[i_img]->data.data(), img_u8->nx * img_u8->ny * 3);
             clip_image_size img_u8_size{img_u8->nx, img_u8->ny};
 
             // preprocess image
@@ -288,12 +333,12 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
 
             if (ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5 || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6) {
                 // split batch into chunks of single images
-                auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_img].id);
+                auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_img]->id);
                 GGML_ASSERT(chunks.size() > 0);
 
                 // add overview image
                 add_text_chunk({ctx->tok_ov_img_start});
-                output.emplace_back(std::move(chunks.front()));
+                output->entries.emplace_back(std::move(chunks.front()));
                 chunks.erase(chunks.begin());
                 add_text_chunk({ctx->tok_ov_img_end});
 
@@ -311,7 +356,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                             if (ctx->tok_sli_img_start != LLAMA_TOKEN_NULL) {
                                 add_text_chunk({ctx->tok_sli_img_start});
                             }
-                            output.emplace_back(std::move(chunks[y * n_col + x]));
+                            output->entries.emplace_back(std::move(chunks[y * n_col + x]));
                             if (ctx->tok_sli_img_end != LLAMA_TOKEN_NULL) {
                                 add_text_chunk({ctx->tok_sli_img_end});
                             }
@@ -343,7 +388,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                     image_tokens->ny = 1;
                 }
                 image_tokens->batch_f32 = std::move(batch_f32);
-                image_tokens->id = bitmaps[i_img].id; // optional
+                image_tokens->id = bitmaps[i_img]->id; // optional
 
                 LOG_DBG("image_tokens->nx = %d\n", image_tokens->nx);
                 LOG_DBG("image_tokens->ny = %d\n", image_tokens->ny);
@@ -354,7 +399,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                     {},
                     std::move(image_tokens),
                 };
-                output.emplace_back(std::move(chunk));
+                output->entries.emplace_back(std::move(chunk));
             }
 
             i_img++; // move to next image
@@ -364,35 +409,12 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
     return 0;
 }
 
-void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
+static void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
     if (image_tokens) {
         delete image_tokens;
     }
 }
 
-size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) {
-    return image_tokens->n_tokens();
-}
-
-size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) {
-    return image_tokens->nx;
-}
-
-size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) {
-    return image_tokens->ny;
-}
-
-std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
-    return image_tokens->id;
-}
-
-llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
-    if (image_tokens->use_mrope_pos) {
-        return 1; // for M-RoPE, the whole image is 1 in temporal dimension
-    }
-    return image_tokens->n_tokens();
-}
-
 int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
     int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
     ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
@@ -432,13 +454,18 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
     return ctx->image_embd_v.data();
 }
 
-size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) {
+size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
     size_t n_tokens = 0;
-    for (auto & chunk : chunks) {
-        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
-            n_tokens += chunk.tokens_text.size();
-        } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
-            n_tokens += mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get());
+    for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
+        auto chunk = mtmd_input_chunks_get(chunks, i);
+        auto chunk_type = mtmd_input_chunk_get_type(chunk);
+        if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+            size_t n_tokens_text;
+            mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
+            n_tokens += n_tokens_text;
+        } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+            auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
+            n_tokens += mtmd_image_tokens_get_n_tokens(tokens_image);
         } else {
             GGML_ASSERT(false && "chunk type not supported");
         }
@@ -446,13 +473,18 @@ size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) {
     return n_tokens;
 }
 
-llama_pos mtmd_helper_get_n_pos(mtmd_input_chunks & chunks) {
+llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
     llama_pos n_pos = 0;
-    for (auto & chunk : chunks) {
-        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
-            n_pos += chunk.tokens_text.size();
-        } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
-            n_pos += mtmd_image_tokens_get_n_pos(chunk.tokens_image.get());
+    for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
+        auto chunk = mtmd_input_chunks_get(chunks, i);
+        auto chunk_type = mtmd_input_chunk_get_type(chunk);
+        if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+            size_t n_tokens_text;
+            mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
+            n_pos += n_tokens_text;
+        } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+            auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
+            n_pos += mtmd_image_tokens_get_n_pos(tokens_image);
         } else {
             GGML_ASSERT(false && "chunk type not supported");
         }
@@ -548,143 +580,172 @@ struct decode_embd_batch {
     }
 };
 
-int32_t mtmd_helper_eval(mtmd_context * ctx,
-        llama_context * lctx,
-        mtmd_input_chunks & chunks,
-        llama_pos pos0,
+int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
+        struct llama_context * lctx,
+        const mtmd_input_chunk * chunk,
+        llama_pos n_past,
         llama_seq_id seq_id,
-        int32_t n_batch) {
+        int32_t n_batch,
+        bool logits_last,
+        llama_pos * new_n_past) {
     int32_t ret;
-    llama_pos n_past = pos0;
     llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
+    auto chunk_type = mtmd_input_chunk_get_type(chunk);
     int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
     int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
 
-    for (auto & chunk : chunks) {
-        bool is_last = &chunk == &chunks.back();
-        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
-            text_batch.n_tokens = chunk.tokens_text.size();
-            size_t i = 0;
-            while (i < chunk.tokens_text.size()) { // split into batches
-                for (; i < chunk.tokens_text.size() && text_batch.n_tokens < n_batch; i++) {
-                    text_batch.token   [i]    = chunk.tokens_text[i];
-                    text_batch.pos     [i]    = n_past++;
-                    text_batch.n_seq_id[i]    = 1;
-                    text_batch.seq_id  [i][0] = seq_id;
-                    text_batch.logits  [i]    = false;
-                }
-                if (is_last) {
-                    // always get logits for last input chunk
-                    text_batch.logits[text_batch.n_tokens - 1] = true;
-                }
-                ret = llama_decode(lctx, text_batch);
-                if (ret != 0) {
-                    LOG_ERR("failed to decode text\n");
-                    llama_batch_free(text_batch);
-                    return ret;
-                }
+    if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+        size_t n_tokens;
+        const auto tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
+        LOG_DBG("decoding text chunk, n_tokens = %zu\n", n_tokens);
+        size_t i = 0;
+        while (i < n_tokens) { // split into batches
+            text_batch.n_tokens = 0; // clear the batch
+            for (; i < n_tokens && text_batch.n_tokens < n_batch; i++) {
+                text_batch.n_tokens++;
+                text_batch.token   [i]    = tokens[i];
+                text_batch.pos     [i]    = n_past++;
+                text_batch.n_seq_id[i]    = 1;
+                text_batch.seq_id  [i][0] = seq_id;
+                text_batch.logits  [i]    = false;
             }
-
-        } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
-            GGML_ASSERT(!is_last && "logits for last image chunk is not yet supported");
-            GGML_ASSERT(chunk.tokens_image != nullptr);
-            int64_t t0 = ggml_time_ms();
-            if (ctx->print_timings) {
-                LOG_INF("encoding image or slice...\n");
+            bool is_last_token = (i == n_tokens);
+            if (logits_last && is_last_token) {
+                text_batch.logits[text_batch.n_tokens - 1] = true;
             }
-            ret = mtmd_encode(ctx, chunk.tokens_image.get());
+            ret = llama_decode(lctx, text_batch);
             if (ret != 0) {
-                LOG_ERR("failed to encode image\n");
+                LOG_ERR("failed to decode text\n");
                 llama_batch_free(text_batch);
                 return ret;
             }
-            if (ctx->print_timings) {
-                LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
-            }
+            *new_n_past += text_batch.n_tokens;
+        }
+
+    } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+        const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
+        int64_t t0 = ggml_time_ms();
+        if (ctx->print_timings) {
+            LOG_INF("encoding image or slice...\n");
+        }
+        ret = mtmd_encode(ctx, image_tokens);
+        if (ret != 0) {
+            LOG_ERR("failed to encode image\n");
+            llama_batch_free(text_batch);
+            return ret;
+        }
+        if (ctx->print_timings) {
+            LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
+        }
 
-            int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get());
-            int32_t i_batch = 0;
-            int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
-            float * embd = mtmd_get_output_embd(ctx);
-            decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
+        int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
+        int32_t i_batch = 0;
+        int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
+        float * embd = mtmd_get_output_embd(ctx);
+        decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
 
-            const int nx = mtmd_image_tokens_get_nx(chunk.tokens_image.get());
-            const int ny = mtmd_image_tokens_get_ny(chunk.tokens_image.get());
+        const int nx = mtmd_image_tokens_get_nx(image_tokens);
+        const int ny = mtmd_image_tokens_get_ny(image_tokens);
 
-            if (mtmd_decode_use_mrope(ctx)) {
-                batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
-            } else {
-                batch_embd.set_position_normal(n_past, seq_id);
+        if (mtmd_decode_use_mrope(ctx)) {
+            batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
+        } else {
+            batch_embd.set_position_normal(n_past, seq_id);
+        }
+
+        if (mtmd_decode_use_non_causal(ctx)) {
+            llama_set_causal_attn(lctx, false);
+            // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
+        }
+
+        while (i_batch < n_img_batches) { // split into batches
+            int pos_offset = i_batch*n_batch;
+            int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
+            llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
+
+            LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
+
+            int64_t t1 = ggml_time_ms();
+            ret = llama_decode(lctx, batch_embd_view);
+            if (ret != 0) {
+                LOG_ERR("failed to decode image\n");
+                llama_set_causal_attn(lctx, true); // restore causal attn
+                llama_batch_free(text_batch);
+                return ret;
             }
 
-            if (mtmd_decode_use_non_causal(ctx)) {
-                llama_set_causal_attn(lctx, false);
-                // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
+            if (ctx->print_timings) {
+                LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
             }
 
-            while (i_batch < n_img_batches) { // split into batches
-                int pos_offset = i_batch*n_batch;
-                int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
-                llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
+            i_batch++;
+        }
 
-                LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
+        n_past += mtmd_image_tokens_get_n_pos(image_tokens);
+        *new_n_past = n_past;
 
-                int64_t t1 = ggml_time_ms();
-                ret = llama_decode(lctx, batch_embd_view);
-                if (ret != 0) {
-                    LOG_ERR("failed to decode image\n");
-                    llama_set_causal_attn(lctx, true); // restore causal attn
-                    llama_batch_free(text_batch);
-                    return ret;
-                }
+        if (mtmd_decode_use_non_causal(ctx)) {
+            llama_set_causal_attn(lctx, true);
+        }
 
-                if (ctx->print_timings) {
-                    LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
-                }
+    } else {
+        GGML_ABORT("chunk type not supported");
+    }
 
-                i_batch++;
-            }
+    return 0;
+}
 
-            // for mrope, one image is one single **temporal** position
-            n_past += mtmd_decode_use_mrope(ctx) ? 1 : n_tokens;
+int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
+                                struct llama_context * lctx,
+                                const mtmd_input_chunks * chunks,
+                                llama_pos n_past,
+                                llama_seq_id seq_id,
+                                int32_t n_batch,
+                                bool logits_last,
+                                llama_pos * new_n_past) {
+    size_t n_chunks = mtmd_input_chunks_size(chunks);
+    if (n_chunks == 0) {
+        LOG_WRN("no chunks to eval\n");
+        return 0;
+    }
 
-            if (mtmd_decode_use_non_causal(ctx)) {
-                llama_set_causal_attn(lctx, true);
-            }
+    for (size_t i = 0; i < n_chunks; i++) {
+        bool chunk_logits_last = (i == n_chunks - 1) && logits_last;
+        auto chunk = mtmd_input_chunks_get(chunks, i);
 
-        } else {
-            GGML_ASSERT(false && "chunk type not supported");
+        int32_t res = mtmd_helper_eval_chunk_single(ctx, lctx, chunk, n_past, seq_id, n_batch, chunk_logits_last, &n_past);
+        if (res != 0) {
+            LOG_ERR("failed to eval chunk %zu\n", i);
+            return res;
         }
+        *new_n_past = n_past;
     }
 
-    llama_batch_free(text_batch);
     return 0;
 }
 
-int32_t mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len, mtmd_bitmap & output) {
+mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len) {
     clip_image_u8_ptr img_u8(clip_image_u8_init());
     bool ok = clip_image_load_from_bytes(buf, len, img_u8.get());
     if (!ok) {
         LOG_ERR("Unable to load image from buffer\n");
-        return 1;
+        return nullptr;
     }
-    unsigned char * data = clip_image_u8_get_data(img_u8.get(), &output.nx, &output.ny);
-    output.data.resize(output.nx * output.ny * 3);
-    std::memcpy(output.data.data(), data, output.nx * output.ny * 3);
-    return 0;
+    uint32_t nx, ny;
+    unsigned char * data = clip_image_u8_get_data(img_u8.get(), &nx, &ny);
+    return mtmd_bitmap_init(nx, ny, data);
 }
 
-int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & output) {
+mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname) {
     clip_image_u8_ptr img_u8(clip_image_u8_init());
     bool ok = clip_image_load_from_file(fname, img_u8.get());
     if (!ok) {
         LOG_ERR("Unable to load image %s\n", fname);
-        return 1;
+        return nullptr;
     }
-    unsigned char * data = clip_image_u8_get_data(img_u8.get(), &output.nx, &output.ny);
-    output.data.resize(output.nx * output.ny * 3);
-    std::memcpy(output.data.data(), data, output.nx * output.ny * 3);
-    return 0;
+    uint32_t nx, ny;
+    unsigned char * data = clip_image_u8_get_data(img_u8.get(), &nx, &ny);
+    return mtmd_bitmap_init(nx, ny, data);
 }
 
 bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
@@ -702,3 +763,175 @@ bool mtmd_decode_use_mrope(mtmd_context * ctx) {
 void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
     mtmd_image_tokens_free(val);
 }
+
+
+//
+// public API functions
+//
+
+// mtmd_bitmap
+
+mtmd_bitmap * mtmd_bitmap_init(uint32_t nx,
+                               uint32_t ny,
+                               const unsigned char * data) {
+    mtmd_bitmap * bitmap = new mtmd_bitmap;
+    bitmap->nx = nx;
+    bitmap->ny = ny;
+    size_t data_size = (size_t)nx * ny * 3;
+    bitmap->data.resize(data_size);
+    std::memcpy(bitmap->data.data(), data, data_size);
+    return bitmap;
+}
+
+uint32_t mtmd_bitmap_get_nx(const mtmd_bitmap * bitmap) {
+    return bitmap->nx;
+}
+
+uint32_t mtmd_bitmap_get_ny(const mtmd_bitmap * bitmap) {
+    return bitmap->ny;
+}
+
+const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap) {
+    return bitmap->data.data();
+}
+
+const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap) {
+    return bitmap->id.c_str();
+}
+
+void mtmd_bitmap_set_id(mtmd_bitmap * bitmap, const char * id) {
+    if (id) {
+        bitmap->id = std::string(id);
+    } else {
+        bitmap->id.clear();
+    }
+}
+
+void mtmd_bitmap_free(mtmd_bitmap * bitmap) {
+    if (bitmap) {
+        delete bitmap;
+    }
+}
+
+// mtmd_input_chunks
+
+mtmd_input_chunks * mtmd_input_chunks_init() {
+    return new mtmd_input_chunks;
+}
+
+size_t mtmd_input_chunks_size(const mtmd_input_chunks * chunks) {
+    return chunks->entries.size();
+}
+
+const mtmd_input_chunk * mtmd_input_chunks_get(const mtmd_input_chunks * chunks, size_t idx) {
+    if (idx >= chunks->entries.size()) {
+        return nullptr;
+    }
+    return &chunks->entries[idx];
+}
+
+void mtmd_input_chunks_free(mtmd_input_chunks * chunks) {
+    if (chunks) {
+        delete chunks;
+    }
+}
+
+// mtmd_input_chunk
+
+enum mtmd_input_chunk_type mtmd_input_chunk_get_type(const mtmd_input_chunk * chunk) {
+    return chunk->type;
+}
+
+const llama_token * mtmd_input_chunk_get_tokens_text(const mtmd_input_chunk * chunk, size_t * n_tokens_output) {
+    if (chunk->type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+        *n_tokens_output = chunk->tokens_text.size();
+        return chunk->tokens_text.data();
+    }
+    *n_tokens_output = 0;
+    return nullptr;
+}
+
+const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd_input_chunk * chunk) {
+    if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+        return chunk->tokens_image.get();
+    }
+    return nullptr;
+}
+
+mtmd_input_chunk * mtmd_input_chunk_copy(const mtmd_input_chunk * chunk) {
+    mtmd_input_chunk * copy = new mtmd_input_chunk{
+        chunk->type,
+        chunk->tokens_text,
+        mtmd_image_tokens_ptr(),
+    };
+    if (chunk->tokens_image) {
+        // copy the image tokens
+        copy->tokens_image = mtmd_image_tokens_ptr(new mtmd_image_tokens());
+        *copy->tokens_image = chunk->tokens_image->clone();
+    }
+    return copy;
+}
+
+void mtmd_input_chunk_free(mtmd_input_chunk * chunk) {
+    if (chunk) {
+        delete chunk;
+    }
+}
+
+// mtmd_image_tokens
+
+size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) {
+    return image_tokens->n_tokens();
+}
+
+size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) {
+    return image_tokens->nx;
+}
+
+size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) {
+    return image_tokens->ny;
+}
+
+const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
+    return image_tokens->id.c_str();
+}
+
+llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
+    if (image_tokens->use_mrope_pos) {
+        return 1; // for M-RoPE, the whole image is 1 in temporal dimension
+    }
+    return image_tokens->n_tokens();
+}
+
+// test function
+
+mtmd_input_chunks * mtmd_test_create_input_chunks() {
+    mtmd_input_chunks * chunks = mtmd_input_chunks_init();
+    if (!chunks) {
+        return nullptr;
+    }
+
+    // create a text chunk
+    std::vector<llama_token> tokens_text = { 1, 2, 3, 4, 5 };
+    mtmd_input_chunk chunk_text{
+        MTMD_INPUT_CHUNK_TYPE_TEXT,
+        std::move(tokens_text),
+        {},
+    };
+    chunks->entries.emplace_back(std::move(chunk_text));
+
+    // create an image chunk
+    mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
+    image_tokens->nx = 4;
+    image_tokens->ny = 4;
+    image_tokens->batch_f32.entries.resize(16);
+    image_tokens->id = "image_1";
+    mtmd_input_chunk chunk_image{
+        MTMD_INPUT_CHUNK_TYPE_IMAGE,
+        {},
+        std::move(image_tokens),
+    };
+    chunks->entries.emplace_back(std::move(chunk_image));
+
+    return chunks;
+}
index 6805e5e4816c325f0566df3eb360c95020e6c5f9..e2f76e2e8d346cd7b0970798d48d5b9bdbb34028 100644 (file)
@@ -5,9 +5,24 @@
 #include "llama.h"
 #include "clip.h"
 
+#include <stddef.h>
+#include <stdint.h>
+#include <stdbool.h>
+
+#ifdef __cplusplus
 #include <vector>
 #include <cinttypes>
 #include <memory>
+#endif
+
+/**
+ * libmtmd: A library for multimodal support in llama.cpp.
+ *
+ * WARNING: This API is experimental and subject to many BREAKING CHANGES.
+ *          Issues related to API usage may receive lower priority support.
+ *
+ * For the usage, see an example in mtmd-cli.cpp
+ */
 
 #ifdef LLAMA_SHARED
 #    if defined(_WIN32) && !defined(__MINGW32__)
 #    define MTMD_API
 #endif
 
+#define MTMD_DEFAULT_IMAGE_MARKER "<__image__>"
+
 #ifdef __cplusplus
+extern "C" {
+#endif
 
 enum mtmd_input_chunk_type {
     MTMD_INPUT_CHUNK_TYPE_TEXT,
     MTMD_INPUT_CHUNK_TYPE_IMAGE,
 };
 
+// opaque types
 struct mtmd_context;
+struct mtmd_bitmap;
 struct mtmd_image_tokens;
+struct mtmd_input_chunk;
+struct mtmd_input_chunks;
 
-// represents raw image data, layout is RGBRGBRGB...
-// length of data must be nx * ny * 3
-struct mtmd_bitmap {
-    uint32_t nx;
-    uint32_t ny;
-    std::vector<unsigned char> data;
-    std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking
-};
-
-struct mtmd_image_tokens_deleter {
-    void operator()(mtmd_image_tokens * val); // forward declaration
+struct mtmd_input_text {
+    const char * text;
+    bool add_special;
+    bool parse_special;
 };
-using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>;
 
-struct mtmd_input_chunk {
-    mtmd_input_chunk_type type;
-    std::vector<llama_token> tokens_text;
-    mtmd_image_tokens_ptr tokens_image;
-};
+//
+// C API
+//
 
-using mtmd_input_chunks = std::vector<mtmd_input_chunk>;
+typedef struct mtmd_context      mtmd_context;
+typedef struct mtmd_bitmap       mtmd_bitmap;
+typedef struct mtmd_image_tokens mtmd_image_tokens;
+typedef struct mtmd_input_chunk  mtmd_input_chunk;
+typedef struct mtmd_input_chunks mtmd_input_chunks;
+typedef struct mtmd_input_text   mtmd_input_text;
 
 struct mtmd_context_params {
-    bool use_gpu = true;
-    bool print_timings = true;
-    int n_threads = 4;
-    enum ggml_log_level verbosity = GGML_LOG_LEVEL_INFO;
-    const char * image_marker = "<__image__>";
+    bool use_gpu;
+    bool print_timings;
+    int n_threads;
+    enum ggml_log_level verbosity;
+    const char * image_marker;
 };
 
-struct mtmd_input_text {
-    std::string text;
-    bool add_special;
-    bool parse_special;
-};
+MTMD_API struct mtmd_context_params mtmd_context_params_default(void);
 
 // initialize the mtmd context
 // return nullptr on failure
 MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
-                                                const llama_model * text_model,
-                                                const mtmd_context_params ctx_params);
+                                            const struct llama_model * text_model,
+                                            const struct mtmd_context_params ctx_params);
 
 MTMD_API void mtmd_free(mtmd_context * ctx);
 
+// whether we need to set non-causal mask before llama_decode
+MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
+
+// whether the current model use M-RoPE for llama_decode
+MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);
+
+
+// mtmd_bitmap
+//
+// length of data must be nx * ny * 3
+// the data is in RGBRGBRGB... format
+MTMD_API mtmd_bitmap *         mtmd_bitmap_init    (uint32_t nx,
+                                                    uint32_t ny,
+                                                    const unsigned char * data);
+MTMD_API uint32_t              mtmd_bitmap_get_nx  (const mtmd_bitmap * bitmap);
+MTMD_API uint32_t              mtmd_bitmap_get_ny  (const mtmd_bitmap * bitmap);
+MTMD_API const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap);
+MTMD_API void                  mtmd_bitmap_free    (mtmd_bitmap * bitmap);
+// bitmap ID is optional, but useful for KV cache tracking
+// these getters/setters are dedicated functions, so you can for example calculate the hash of the image based on mtmd_bitmap_get_data()
+MTMD_API const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap);
+MTMD_API void         mtmd_bitmap_set_id(mtmd_bitmap * bitmap, const char * id);
+
+
+// mtmd_input_chunks
+//
+// this is simply a list of mtmd_input_chunk
+// the elements can only be populated via mtmd_tokenize()
+MTMD_API mtmd_input_chunks *      mtmd_input_chunks_init(void);
+MTMD_API size_t                   mtmd_input_chunks_size(const mtmd_input_chunks * chunks);
+MTMD_API const mtmd_input_chunk * mtmd_input_chunks_get (const mtmd_input_chunks * chunks, size_t idx);
+MTMD_API void                     mtmd_input_chunks_free(mtmd_input_chunks * chunks);
+
+// mtmd_input_chunk
+//
+// the instance will be constructed via mtmd_tokenize()
+// it will be freed along with mtmd_input_chunks
+MTMD_API enum mtmd_input_chunk_type mtmd_input_chunk_get_type        (const mtmd_input_chunk * chunk);
+MTMD_API const llama_token *        mtmd_input_chunk_get_tokens_text (const mtmd_input_chunk * chunk, size_t * n_tokens_output);
+MTMD_API const mtmd_image_tokens *  mtmd_input_chunk_get_tokens_image(const mtmd_input_chunk * chunk);
+
+// in case you want to use custom logic to handle the chunk (i.e. KV cache management)
+// you can move the chunk ownership to your own code by copying it
+// remember to free the chunk when you are done with it
+MTMD_API mtmd_input_chunk * mtmd_input_chunk_copy(const mtmd_input_chunk * chunk);
+MTMD_API void               mtmd_input_chunk_free(mtmd_input_chunk * chunk);
+
+
+// mtmd_image_tokens
+//
+// the instance will be constructed via mtmd_tokenize()
+// it will be freed along with mtmd_input_chunk
+MTMD_API size_t       mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens);
+MTMD_API size_t       mtmd_image_tokens_get_nx      (const mtmd_image_tokens * image_tokens);
+MTMD_API size_t       mtmd_image_tokens_get_ny      (const mtmd_image_tokens * image_tokens);
+MTMD_API const char * mtmd_image_tokens_get_id      (const mtmd_image_tokens * image_tokens);
+// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
+MTMD_API llama_pos    mtmd_image_tokens_get_n_pos   (const mtmd_image_tokens * image_tokens);
+
 // tokenize an input text prompt and an image
 // the prompt must have the input image marker (default: "<__image__>") in it
 // the marker will be replaced with the image tokens
@@ -93,75 +166,152 @@ MTMD_API void mtmd_free(mtmd_context * ctx);
 //   1 on number of images not matching the number of markers
 //   2 on image preprocessing error
 MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
-                                std::vector<mtmd_input_chunk> & output,
-                                const mtmd_input_text & text,
-                                const std::vector<mtmd_bitmap> & bitmaps);
-
-// access mtmd_image_tokens
-MTMD_API size_t      mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens);
-MTMD_API size_t      mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens);
-MTMD_API size_t      mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens);
-MTMD_API std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens);
-MTMD_API llama_pos   mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens); // number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
-MTMD_API void        mtmd_image_tokens_free(mtmd_image_tokens * image_tokens);
+                               mtmd_input_chunks * output,
+                               const mtmd_input_text * text,
+                               const mtmd_bitmap ** bitmaps,
+                               size_t n_bitmaps);
 
 // returns 0 on success
 MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
-                            const mtmd_image_tokens * image_tokens);
+                             const mtmd_image_tokens * image_tokens);
 
 // get output embeddings from the last encode pass
 MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
 
-// whether we need to set non-causal mask before llama_decode
-MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
-
-// whether the current model use M-RoPE for llama_decode
-MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);
-
-
+/////////////////////////////////////////
 
 //
-// helper functions (can be implemented based on other functions)
+// Helper functions (can be implemented based on other functions)
+//
+// Please note that these helpers are not guaranteed to be stable.
+// BREAKING CHANGES are expected.
 //
 
+// helper function to construct a mtmd_bitmap from a file
+// returns nullptr on failure
+// this function is thread-safe
+MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname);
+
+// helper function to construct a mtmd_bitmap from a buffer containing a file
+// the file content must be an image in format supported by stb_image (jpg, png, bmp, gif, etc.)
+// returns nullptr on failure
+// this function is thread-safe
+MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len);
+
 // helper to count the total number of tokens from a list of chunks, useful to keep track of KV cache
-MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks);
+MTMD_API size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks);
 
 // helper to count the total position of tokens from a list of chunks, useful to keep track of n_past
-MTMD_API llama_pos mtmd_helper_get_n_pos(mtmd_input_chunks & chunks);
+// normally, n_pos is equal to n_tokens, but for M-RoPE it is different
+MTMD_API llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks);
 
 // helper function that automatically:
 // 1. run llama_decode() on text chunks
 // 2. run mtmd_encode() on image chunks, then mtmd_get_output_embd() and then llama_decode()
 // if any of the mtmd_encode() or llama_decode() calls return non-zero, stop and forward the error
 // otherwise, returns 0 on success
-MTMD_API int32_t mtmd_helper_eval(mtmd_context * ctx,
-                                llama_context * lctx,
-                                mtmd_input_chunks & chunks,
-                                llama_pos pos0,
-                                llama_seq_id seq_id,
-                                int32_t n_batch);
+// this function is NOT thread-safe
+MTMD_API int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
+                                         struct llama_context * lctx,
+                                         const mtmd_input_chunks * chunks,
+                                         llama_pos n_past,
+                                         llama_seq_id seq_id,
+                                         int32_t n_batch,
+                                         bool logits_last,
+                                         llama_pos * new_n_past);
+
+// works like mtmd_helper_eval_chunks(), but only for a single chunk
+// this function is NOT thread-safe
+MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
+                                               struct llama_context * lctx,
+                                               const mtmd_input_chunk * chunk,
+                                               llama_pos n_past,
+                                               llama_seq_id seq_id,
+                                               int32_t n_batch,
+                                               bool logits_last,
+                                               llama_pos * new_n_past);
+
+/////////////////////////////////////////
+
+// test function, to be used in test-mtmd-c-api.c
+MTMD_API mtmd_input_chunks * mtmd_test_create_input_chunks(void);
 
-// helper function to construct a mtmd_bitmap from a file
-// returns 0 on success
-// this function is thread-safe
-MTMD_API int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & output);
+#ifdef __cplusplus
+} // extern "C"
+#endif
 
-// helper function to construct a mtmd_bitmap from a buffer
-// the buffer must be an image in format supported by stb_image (jpg, png, bmp, gif, etc.)
-// returns 0 on success
-// this function is thread-safe
-MTMD_API int32_t mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len, mtmd_bitmap & output);
+//
+// C++ wrappers
+//
+
+#ifdef __cplusplus
+
+namespace mtmd {
 
-// convenient unique_ptr wrappers
 struct mtmd_context_deleter {
     void operator()(mtmd_context * val) { mtmd_free(val); }
 };
-using mtmd_context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>;
+using context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>;
 
-#else
+struct mtmd_bitmap_deleter {
+    void operator()(mtmd_bitmap * val) { mtmd_bitmap_free(val); }
+};
+using bitmap_ptr = std::unique_ptr<mtmd_bitmap, mtmd_bitmap_deleter>;
+
+struct mtmd_input_chunks_deleter {
+    void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); }
+};
+using input_chunks_ptr = std::unique_ptr<mtmd_input_chunks, mtmd_input_chunks_deleter>;
+
+struct mtmd_input_chunk_deleter {
+    void operator()(mtmd_input_chunk * val) { mtmd_input_chunk_free(val); }
+};
+using input_chunk_ptr = std::unique_ptr<mtmd_input_chunk, mtmd_input_chunk_deleter>;
+
+struct bitmap {
+    bitmap_ptr ptr;
+    bitmap() : ptr(nullptr) {}
+    bitmap(mtmd_bitmap * bitmap) : ptr(bitmap) {}
+    bitmap(bitmap && other) noexcept : ptr(std::move(other.ptr)) {}
+    bitmap(uint32_t nx, uint32_t ny, const unsigned char * data) {
+        ptr.reset(mtmd_bitmap_init(nx, ny, data));
+    }
+    ~bitmap() = default;
+    uint32_t nx() { return mtmd_bitmap_get_nx(ptr.get()); }
+    uint32_t ny() { return mtmd_bitmap_get_ny(ptr.get()); }
+    const unsigned char * data() { return mtmd_bitmap_get_data(ptr.get()); }
+    std::string id() { return mtmd_bitmap_get_id(ptr.get()); }
+    void set_id(const char * id) { mtmd_bitmap_set_id(ptr.get(), id); }
+};
+
+struct bitmaps {
+    std::vector<bitmap> entries;
+    ~bitmaps() = default;
+    // return list of pointers to mtmd_bitmap
+    // example:
+    //   auto bitmaps_c_ptr = bitmaps.c_ptr();
+    //   int32_t res = mtmd_tokenize(... bitmaps_c_ptr.data(), bitmaps_c_ptr.size());
+    std::vector<const mtmd_bitmap *> c_ptr() {
+        std::vector<const mtmd_bitmap *> res(entries.size());
+        for (size_t i = 0; i < entries.size(); i++) {
+            res[i] = entries[i].ptr.get();
+        }
+        return res;
+    }
+};
+
+struct input_chunks {
+    input_chunks_ptr ptr;
+    input_chunks() = default;
+    input_chunks(mtmd_input_chunks * chunks) : ptr(chunks) {}
+    ~input_chunks() = default;
+    size_t size() { return mtmd_input_chunks_size(ptr.get()); }
+    const mtmd_input_chunk * operator[](size_t idx) {
+        return mtmd_input_chunks_get(ptr.get(), idx);
+    }
+};
 
-static_assert(false && "C header is not yet supported by this library");
+} // namespace mtmd
 
 #endif