]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
mtmd : move helpers to dedicated file (#13442)
authorXuan-Son Nguyen <redacted>
Sun, 11 May 2025 09:34:23 +0000 (11:34 +0200)
committerGitHub <redacted>
Sun, 11 May 2025 09:34:23 +0000 (11:34 +0200)
* mtmd : move helpers to dedicated file

* fix windows build

* rm redundant include

tools/mtmd/CMakeLists.txt
tools/mtmd/mtmd-helper.cpp [new file with mode: 0644]
tools/mtmd/mtmd.cpp
tools/mtmd/mtmd.h

index 27b6d27e5cac3fe27a6a0ff34f7b0a87802c3921..dfafa9cf8116e9762c1bc353418de283656c06b3 100644 (file)
@@ -28,6 +28,7 @@ endif()
 
 add_library(mtmd OBJECT
             mtmd.cpp
+            mtmd-helper.cpp
             mtmd.h
             clip.cpp
             clip.h
diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp
new file mode 100644 (file)
index 0000000..7a32886
--- /dev/null
@@ -0,0 +1,310 @@
+#include "mtmd.h"
+#include "llama.h"
+
+#include <algorithm>
+#include <cinttypes>
+#include <vector>
+
+#define LOG_INF(...) fprintf(stdout, __VA_ARGS__)
+#define LOG_ERR(...) fprintf(stderr, __VA_ARGS__)
+
+size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
+    size_t n_tokens = 0;
+    for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
+        auto chunk = mtmd_input_chunks_get(chunks, i);
+        auto chunk_type = mtmd_input_chunk_get_type(chunk);
+        if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+            size_t n_tokens_text;
+            mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
+            n_tokens += n_tokens_text;
+        } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+            auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
+            n_tokens += mtmd_image_tokens_get_n_tokens(tokens_image);
+        } else {
+            GGML_ASSERT(false && "chunk type not supported");
+        }
+    }
+    return n_tokens;
+}
+
+llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
+    llama_pos n_pos = 0;
+    for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
+        auto chunk = mtmd_input_chunks_get(chunks, i);
+        auto chunk_type = mtmd_input_chunk_get_type(chunk);
+        if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+            size_t n_tokens_text;
+            mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
+            n_pos += n_tokens_text;
+        } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+            auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
+            n_pos += mtmd_image_tokens_get_n_pos(tokens_image);
+        } else {
+            GGML_ASSERT(false && "chunk type not supported");
+        }
+    }
+    return n_pos;
+}
+
+// helper struct to make working with embd batch easier
+// note: this will be removed after llama_batch_ext refactoring
+struct decode_embd_batch {
+    int n_pos_per_embd;
+    int n_mmproj_embd;
+    std::vector<llama_pos>      pos;
+    std::vector<llama_pos>      pos_view; // used by mrope
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id>   seq_id_0;
+    std::vector<llama_seq_id *> seq_ids;
+    std::vector<int8_t>         logits;
+    llama_batch batch;
+    decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
+        pos     .resize(n_tokens * n_pos_per_embd);
+        n_seq_id.resize(n_tokens);
+        seq_ids .resize(n_tokens + 1);
+        logits  .resize(n_tokens);
+        seq_id_0.resize(1);
+        seq_ids [n_tokens] = nullptr;
+        batch = {
+            /*n_tokens       =*/ n_tokens,
+            /*tokens         =*/ nullptr,
+            /*embd           =*/ embd,
+            /*pos            =*/ pos.data(),
+            /*n_seq_id       =*/ n_seq_id.data(),
+            /*seq_id         =*/ seq_ids.data(),
+            /*logits         =*/ logits.data(),
+        };
+    }
+
+    void set_position_normal(llama_pos pos_0, llama_seq_id seq_id) {
+        seq_id_0[0] = seq_id;
+        for (int i = 0; i < batch.n_tokens; i++) {
+            batch.pos     [i] = pos_0 + i;
+            batch.n_seq_id[i] = 1;
+            batch.seq_id  [i] = seq_id_0.data();
+            batch.logits  [i] = false;
+        }
+    }
+
+    void set_position_mrope(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
+        GGML_ASSERT(n_pos_per_embd == 4);
+        seq_id_0[0] = seq_id;
+        for (int y = 0; y < ny; y++) {
+            for (int x = 0; x < nx; x++) {
+                int i = y * nx + x;
+                pos[i                     ] = pos_0;
+                pos[i + batch.n_tokens    ] = pos_0 + y;
+                pos[i + batch.n_tokens * 2] = pos_0 + x;
+                pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
+            }
+        }
+        for (int i = 0; i < batch.n_tokens; i++) {
+            batch.n_seq_id[i] = 1;
+            batch.seq_id  [i] = seq_id_0.data();
+            batch.logits  [i] = false;
+        }
+    }
+
+    llama_batch get_view(int offset, int n_tokens) {
+        llama_pos * pos_ptr;
+        pos_view.clear();
+        pos_view.reserve(n_tokens * n_pos_per_embd);
+        if (n_pos_per_embd > 1) {
+            // mrope
+            // for example, with layout of src: 1234...1234...1234...1234...
+            //       offset 2 will give us dst: 34...34...34...34...
+            for (int i = 0; i < n_pos_per_embd; i++) {
+                // assume n_tokens is less than or equal to batch.n_tokens
+                // batch.n_tokens is number of **total** tokens
+                // n_tokens is number of viewed token
+                size_t src_idx = i * batch.n_tokens + offset;
+                pos_view.insert(pos_view.end(),
+                    pos.data() + src_idx,
+                    pos.data() + src_idx + n_tokens);
+            }
+            pos_ptr = pos_view.data();
+        } else {
+            // normal
+            pos_ptr = pos.data() + offset;
+        }
+        return {
+            /*n_tokens       =*/ n_tokens,
+            /*tokens         =*/ nullptr,
+            /*embd           =*/ batch.embd     + offset * n_mmproj_embd,
+            /*pos            =*/ pos_ptr,
+            /*n_seq_id       =*/ batch.n_seq_id + offset,
+            /*seq_id         =*/ batch.seq_id   + offset,
+            /*logits         =*/ batch.logits   + offset,
+        };
+    }
+};
+
+// Helper function for decoding an image whose embeddings have already been calculated
+int32_t mtmd_helper_decode_image_chunk(
+        mtmd_context * ctx,
+        struct llama_context * lctx,
+        const mtmd_input_chunk * chunk,
+        float * encoded_embd,
+        llama_pos n_past,
+        llama_seq_id seq_id,
+        int32_t n_batch,
+        llama_pos * new_n_past) {
+    if (mtmd_input_chunk_get_type(chunk) != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+        LOG_ERR("failed to decode image chunk: input chunk not of image type\n");
+        return -1;
+    }
+    const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
+    if (!image_tokens) {
+        LOG_ERR("failed to decode image chunk: image tokens are null\n");
+        return -1;
+    }
+
+    const llama_model * model = llama_get_model(lctx);
+    int n_mmproj_embd = llama_model_n_embd(model);
+    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
+
+    int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
+    int32_t i_batch = 0;
+    int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
+    decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
+
+    const int nx = mtmd_image_tokens_get_nx(image_tokens);
+    const int ny = mtmd_image_tokens_get_ny(image_tokens);
+
+    if (mtmd_decode_use_mrope(ctx)) {
+        batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
+    } else {
+        batch_embd.set_position_normal(n_past, seq_id);
+    }
+
+    if (mtmd_decode_use_non_causal(ctx)) {
+        llama_set_causal_attn(lctx, false);
+        // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
+    }
+
+    while (i_batch < n_img_batches) { // split into batches
+        int pos_offset = i_batch*n_batch;
+        int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
+        llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
+
+        LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
+
+        int64_t t1 = ggml_time_ms();
+        int32_t ret = llama_decode(lctx, batch_embd_view);
+        if (ret != 0) {
+            LOG_ERR("failed to decode image\n");
+            llama_set_causal_attn(lctx, true); // restore causal attn
+            return ret;
+        }
+
+        LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
+
+        i_batch++;
+    }
+
+    n_past += mtmd_image_tokens_get_n_pos(image_tokens);
+    *new_n_past = n_past;
+
+    if (mtmd_decode_use_non_causal(ctx)) {
+        llama_set_causal_attn(lctx, true);
+    }
+    return 0;
+}
+
+int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
+        struct llama_context * lctx,
+        const mtmd_input_chunk * chunk,
+        llama_pos n_past,
+        llama_seq_id seq_id,
+        int32_t n_batch,
+        bool logits_last,
+        llama_pos * new_n_past) {
+    int32_t ret;
+    llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
+    auto chunk_type = mtmd_input_chunk_get_type(chunk);
+
+    if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+        size_t n_tokens;
+        const auto tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
+        // LOG_INF("decoding text chunk, n_tokens = %zu\n", n_tokens);
+        size_t i = 0;
+        while (i < n_tokens) { // split into batches
+            text_batch.n_tokens = 0; // clear the batch
+            for (; i < n_tokens && text_batch.n_tokens < n_batch; i++) {
+                text_batch.n_tokens++;
+                text_batch.token   [i]    = tokens[i];
+                text_batch.pos     [i]    = n_past++;
+                text_batch.n_seq_id[i]    = 1;
+                text_batch.seq_id  [i][0] = seq_id;
+                text_batch.logits  [i]    = false;
+            }
+            bool is_last_token = (i == n_tokens);
+            if (logits_last && is_last_token) {
+                text_batch.logits[text_batch.n_tokens - 1] = true;
+            }
+            ret = llama_decode(lctx, text_batch);
+            if (ret != 0) {
+                LOG_ERR("failed to decode text\n");
+                llama_batch_free(text_batch);
+                return ret;
+            }
+            *new_n_past += text_batch.n_tokens;
+        }
+
+    } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+        const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
+        int64_t t0 = ggml_time_ms();
+
+        LOG_INF("encoding image or slice...\n");
+
+        ret = mtmd_encode(ctx, image_tokens);
+        if (ret != 0) {
+            LOG_ERR("failed to encode image\n");
+            llama_batch_free(text_batch);
+            return ret;
+        }
+
+        LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
+
+        float * embd = mtmd_get_output_embd(ctx);
+        ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
+        if (ret != 0) {
+            LOG_ERR("failed to decode image\n");
+            llama_batch_free(text_batch);
+            return ret;
+        }
+    } else {
+        GGML_ABORT("chunk type not supported");
+    }
+
+    return 0;
+}
+
+int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
+                                struct llama_context * lctx,
+                                const mtmd_input_chunks * chunks,
+                                llama_pos n_past,
+                                llama_seq_id seq_id,
+                                int32_t n_batch,
+                                bool logits_last,
+                                llama_pos * new_n_past) {
+    size_t n_chunks = mtmd_input_chunks_size(chunks);
+    if (n_chunks == 0) {
+        LOG_ERR("no chunks to eval\n");
+        return 0;
+    }
+
+    for (size_t i = 0; i < n_chunks; i++) {
+        bool chunk_logits_last = (i == n_chunks - 1) && logits_last;
+        auto chunk = mtmd_input_chunks_get(chunks, i);
+
+        int32_t res = mtmd_helper_eval_chunk_single(ctx, lctx, chunk, n_past, seq_id, n_batch, chunk_logits_last, &n_past);
+        if (res != 0) {
+            LOG_ERR("failed to eval chunk %zu\n", i);
+            return res;
+        }
+        *new_n_past = n_past;
+    }
+
+    return 0;
+}
index f1b957394767148e8a959314e50473b0029459a8..2a852d9c19bd29fb469db74e73527d40a9c93dad 100644 (file)
@@ -461,308 +461,27 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
     return ctx->image_embd_v.data();
 }
 
-size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
-    size_t n_tokens = 0;
-    for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
-        auto chunk = mtmd_input_chunks_get(chunks, i);
-        auto chunk_type = mtmd_input_chunk_get_type(chunk);
-        if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
-            size_t n_tokens_text;
-            mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
-            n_tokens += n_tokens_text;
-        } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
-            auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
-            n_tokens += mtmd_image_tokens_get_n_tokens(tokens_image);
-        } else {
-            GGML_ASSERT(false && "chunk type not supported");
-        }
-    }
-    return n_tokens;
-}
-
-llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
-    llama_pos n_pos = 0;
-    for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
-        auto chunk = mtmd_input_chunks_get(chunks, i);
-        auto chunk_type = mtmd_input_chunk_get_type(chunk);
-        if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
-            size_t n_tokens_text;
-            mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
-            n_pos += n_tokens_text;
-        } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
-            auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
-            n_pos += mtmd_image_tokens_get_n_pos(tokens_image);
-        } else {
-            GGML_ASSERT(false && "chunk type not supported");
-        }
-    }
-    return n_pos;
-}
-
-// helper struct to make working with embd batch easier
-// note: this will be removed after llama_batch_ext refactoring
-struct decode_embd_batch {
-    int n_pos_per_embd;
-    int n_mmproj_embd;
-    std::vector<llama_pos>      pos;
-    std::vector<llama_pos>      pos_view; // used by mrope
-    std::vector<int32_t>        n_seq_id;
-    std::vector<llama_seq_id>   seq_id_0;
-    std::vector<llama_seq_id *> seq_ids;
-    std::vector<int8_t>         logits;
-    llama_batch batch;
-    decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
-        pos     .resize(n_tokens * n_pos_per_embd);
-        n_seq_id.resize(n_tokens);
-        seq_ids .resize(n_tokens + 1);
-        logits  .resize(n_tokens);
-        seq_id_0.resize(1);
-        seq_ids [n_tokens] = nullptr;
-        batch = {
-            /*n_tokens       =*/ n_tokens,
-            /*tokens         =*/ nullptr,
-            /*embd           =*/ embd,
-            /*pos            =*/ pos.data(),
-            /*n_seq_id       =*/ n_seq_id.data(),
-            /*seq_id         =*/ seq_ids.data(),
-            /*logits         =*/ logits.data(),
-        };
-    }
-
-    void set_position_normal(llama_pos pos_0, llama_seq_id seq_id) {
-        seq_id_0[0] = seq_id;
-        for (int i = 0; i < batch.n_tokens; i++) {
-            batch.pos     [i] = pos_0 + i;
-            batch.n_seq_id[i] = 1;
-            batch.seq_id  [i] = seq_id_0.data();
-            batch.logits  [i] = false;
-        }
-    }
-
-    void set_position_mrope(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
-        GGML_ASSERT(n_pos_per_embd == 4);
-        seq_id_0[0] = seq_id;
-        for (int y = 0; y < ny; y++) {
-            for (int x = 0; x < nx; x++) {
-                int i = y * nx + x;
-                pos[i                     ] = pos_0;
-                pos[i + batch.n_tokens    ] = pos_0 + y;
-                pos[i + batch.n_tokens * 2] = pos_0 + x;
-                pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
-            }
-        }
-        for (int i = 0; i < batch.n_tokens; i++) {
-            batch.n_seq_id[i] = 1;
-            batch.seq_id  [i] = seq_id_0.data();
-            batch.logits  [i] = false;
-        }
-    }
-
-    llama_batch get_view(int offset, int n_tokens) {
-        llama_pos * pos_ptr;
-        pos_view.clear();
-        pos_view.reserve(n_tokens * n_pos_per_embd);
-        if (n_pos_per_embd > 1) {
-            // mrope
-            // for example, with layout of src: 1234...1234...1234...1234...
-            //       offset 2 will give us dst: 34...34...34...34...
-            for (int i = 0; i < n_pos_per_embd; i++) {
-                // assume n_tokens is less than or equal to batch.n_tokens
-                // batch.n_tokens is number of **total** tokens
-                // n_tokens is number of viewed token
-                size_t src_idx = i * batch.n_tokens + offset;
-                pos_view.insert(pos_view.end(),
-                    pos.data() + src_idx,
-                    pos.data() + src_idx + n_tokens);
-            }
-            pos_ptr = pos_view.data();
-        } else {
-            // normal
-            pos_ptr = pos.data() + offset;
-        }
-        return {
-            /*n_tokens       =*/ n_tokens,
-            /*tokens         =*/ nullptr,
-            /*embd           =*/ batch.embd     + offset * n_mmproj_embd,
-            /*pos            =*/ pos_ptr,
-            /*n_seq_id       =*/ batch.n_seq_id + offset,
-            /*seq_id         =*/ batch.seq_id   + offset,
-            /*logits         =*/ batch.logits   + offset,
-        };
-    }
-};
-
-// Helper function for decoding an image whose embeddings have already been calculated
-int32_t mtmd_helper_decode_image_chunk(
-        mtmd_context * ctx,
-        struct llama_context * lctx,
-        const mtmd_input_chunk * chunk,
-        float * encoded_embd,
-        llama_pos n_past,
-        llama_seq_id seq_id,
-        int32_t n_batch,
-        llama_pos * new_n_past) {
-    if (mtmd_input_chunk_get_type(chunk) != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
-        LOG_ERR("failed to decode image chunk: input chunk not of image type\n");
-        return -1;
-    }
-    const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
-    if (!image_tokens) {
-        LOG_ERR("failed to decode image chunk: image tokens are null\n");
-        return -1;
-    }
-
-    int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
-    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
-
-    int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
-    int32_t i_batch = 0;
-    int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
-    decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
-
-    const int nx = mtmd_image_tokens_get_nx(image_tokens);
-    const int ny = mtmd_image_tokens_get_ny(image_tokens);
-
-    if (mtmd_decode_use_mrope(ctx)) {
-        batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
-    } else {
-        batch_embd.set_position_normal(n_past, seq_id);
-    }
-
-    if (mtmd_decode_use_non_causal(ctx)) {
-        llama_set_causal_attn(lctx, false);
-        // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
-    }
-
-    while (i_batch < n_img_batches) { // split into batches
-        int pos_offset = i_batch*n_batch;
-        int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
-        llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
-
-        LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
-
-        int64_t t1 = ggml_time_ms();
-        int32_t ret = llama_decode(lctx, batch_embd_view);
-        if (ret != 0) {
-            LOG_ERR("failed to decode image\n");
-            llama_set_causal_attn(lctx, true); // restore causal attn
-            return ret;
-        }
-
-        if (ctx->print_timings) {
-            LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
-        }
-
-        i_batch++;
-    }
-
-    n_past += mtmd_image_tokens_get_n_pos(image_tokens);
-    *new_n_past = n_past;
-
-    if (mtmd_decode_use_non_causal(ctx)) {
-        llama_set_causal_attn(lctx, true);
+bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
+    projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
+    if (proj_type == PROJECTOR_TYPE_GEMMA3) {
+        return true;
     }
-    return 0;
+    return false;
 }
 
-int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
-        struct llama_context * lctx,
-        const mtmd_input_chunk * chunk,
-        llama_pos n_past,
-        llama_seq_id seq_id,
-        int32_t n_batch,
-        bool logits_last,
-        llama_pos * new_n_past) {
-    int32_t ret;
-    llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
-    auto chunk_type = mtmd_input_chunk_get_type(chunk);
-
-    if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
-        size_t n_tokens;
-        const auto tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
-        LOG_DBG("decoding text chunk, n_tokens = %zu\n", n_tokens);
-        size_t i = 0;
-        while (i < n_tokens) { // split into batches
-            text_batch.n_tokens = 0; // clear the batch
-            for (; i < n_tokens && text_batch.n_tokens < n_batch; i++) {
-                text_batch.n_tokens++;
-                text_batch.token   [i]    = tokens[i];
-                text_batch.pos     [i]    = n_past++;
-                text_batch.n_seq_id[i]    = 1;
-                text_batch.seq_id  [i][0] = seq_id;
-                text_batch.logits  [i]    = false;
-            }
-            bool is_last_token = (i == n_tokens);
-            if (logits_last && is_last_token) {
-                text_batch.logits[text_batch.n_tokens - 1] = true;
-            }
-            ret = llama_decode(lctx, text_batch);
-            if (ret != 0) {
-                LOG_ERR("failed to decode text\n");
-                llama_batch_free(text_batch);
-                return ret;
-            }
-            *new_n_past += text_batch.n_tokens;
-        }
-
-    } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
-        const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
-        int64_t t0 = ggml_time_ms();
-        if (ctx->print_timings) {
-            LOG_INF("encoding image or slice...\n");
-        }
-        ret = mtmd_encode(ctx, image_tokens);
-        if (ret != 0) {
-            LOG_ERR("failed to encode image\n");
-            llama_batch_free(text_batch);
-            return ret;
-        }
-        if (ctx->print_timings) {
-            LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
-        }
-        float * embd = mtmd_get_output_embd(ctx);
-        ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
-        if (ret != 0) {
-            LOG_ERR("failed to decode image\n");
-            llama_batch_free(text_batch);
-            return ret;
-        }
-    } else {
-        GGML_ABORT("chunk type not supported");
-    }
-
-    return 0;
+bool mtmd_decode_use_mrope(mtmd_context * ctx) {
+    return ctx->use_mrope;
 }
 
-int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
-                                struct llama_context * lctx,
-                                const mtmd_input_chunks * chunks,
-                                llama_pos n_past,
-                                llama_seq_id seq_id,
-                                int32_t n_batch,
-                                bool logits_last,
-                                llama_pos * new_n_past) {
-    size_t n_chunks = mtmd_input_chunks_size(chunks);
-    if (n_chunks == 0) {
-        LOG_WRN("no chunks to eval\n");
-        return 0;
-    }
-
-    for (size_t i = 0; i < n_chunks; i++) {
-        bool chunk_logits_last = (i == n_chunks - 1) && logits_last;
-        auto chunk = mtmd_input_chunks_get(chunks, i);
-
-        int32_t res = mtmd_helper_eval_chunk_single(ctx, lctx, chunk, n_past, seq_id, n_batch, chunk_logits_last, &n_past);
-        if (res != 0) {
-            LOG_ERR("failed to eval chunk %zu\n", i);
-            return res;
-        }
-        *new_n_past = n_past;
-    }
-
-    return 0;
+void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
+    mtmd_image_tokens_free(val);
 }
 
+// these 2 helpers below use internal clip_image_u8_ptr,
+// so unfortunately they cannot moved to mtmd-helper.h
+// however, in theory, user can decode image file to bitmap using
+// whichever library they want, and then use mtmd_bitmap_init() to create bitmap
+
 mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len) {
     clip_image_u8_ptr img_u8(clip_image_u8_init());
     bool ok = clip_image_load_from_bytes(buf, len, img_u8.get());
@@ -787,23 +506,6 @@ mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname) {
     return mtmd_bitmap_init(nx, ny, data);
 }
 
-bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
-    projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
-    if (proj_type == PROJECTOR_TYPE_GEMMA3) {
-        return true;
-    }
-    return false;
-}
-
-bool mtmd_decode_use_mrope(mtmd_context * ctx) {
-    return ctx->use_mrope;
-}
-
-void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
-    mtmd_image_tokens_free(val);
-}
-
-
 //
 // public API functions
 //
index 54cf481b6aa9483dfd64714a4a66da88badeea03..0ada78c90f6788d8f9206021a2cebbd0852da13c 100644 (file)
@@ -10,6 +10,7 @@
 #include <stdbool.h>
 
 #ifdef __cplusplus
+#include <string>
 #include <vector>
 #include <cinttypes>
 #include <memory>