]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
mtmd : Expose helper_decode_image_chunk (#13366)
authorMatt Clayton <redacted>
Thu, 8 May 2025 18:25:39 +0000 (14:25 -0400)
committerGitHub <redacted>
Thu, 8 May 2025 18:25:39 +0000 (20:25 +0200)
* mtmd: Expose helper_decode_image, output_embd_copy, image_tokens_copy/free

* Slim down

* Cleanups

tools/mtmd/mtmd.cpp
tools/mtmd/mtmd.h

index b600e4341375fa30368779ee55081cba7b0ad985..5d18e8929b31f9a76b9fa0053d226e40bfcbd88d 100644 (file)
@@ -580,6 +580,79 @@ struct decode_embd_batch {
     }
 };
 
+// Helper function for decoding an image whose embeddings have already been calculated
+int32_t mtmd_helper_decode_image_chunk(
+        mtmd_context * ctx,
+        struct llama_context * lctx,
+        const mtmd_input_chunk * chunk,
+        float * encoded_embd,
+        llama_pos n_past,
+        llama_seq_id seq_id,
+        int32_t n_batch,
+        llama_pos * new_n_past) {
+    if (mtmd_input_chunk_get_type(chunk) != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+        LOG_ERR("failed to decode image chunk: input chunk not of image type\n");
+        return -1;
+    }
+    const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
+    if (!image_tokens) {
+        LOG_ERR("failed to decode image chunk: image tokens are null\n");
+        return -1;
+    }
+
+    int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
+    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
+
+    int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
+    int32_t i_batch = 0;
+    int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
+    decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
+
+    const int nx = mtmd_image_tokens_get_nx(image_tokens);
+    const int ny = mtmd_image_tokens_get_ny(image_tokens);
+
+    if (mtmd_decode_use_mrope(ctx)) {
+        batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
+    } else {
+        batch_embd.set_position_normal(n_past, seq_id);
+    }
+
+    if (mtmd_decode_use_non_causal(ctx)) {
+        llama_set_causal_attn(lctx, false);
+        // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
+    }
+
+    while (i_batch < n_img_batches) { // split into batches
+        int pos_offset = i_batch*n_batch;
+        int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
+        llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
+
+        LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
+
+        int64_t t1 = ggml_time_ms();
+        int32_t ret = llama_decode(lctx, batch_embd_view);
+        if (ret != 0) {
+            LOG_ERR("failed to decode image\n");
+            llama_set_causal_attn(lctx, true); // restore causal attn
+            return ret;
+        }
+
+        if (ctx->print_timings) {
+            LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
+        }
+
+        i_batch++;
+    }
+
+    n_past += mtmd_image_tokens_get_n_pos(image_tokens);
+    *new_n_past = n_past;
+
+    if (mtmd_decode_use_non_causal(ctx)) {
+        llama_set_causal_attn(lctx, true);
+    }
+    return 0;
+}
+
 int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
         struct llama_context * lctx,
         const mtmd_input_chunk * chunk,
@@ -591,8 +664,6 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
     int32_t ret;
     llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
     auto chunk_type = mtmd_input_chunk_get_type(chunk);
-    int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
-    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
 
     if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
         size_t n_tokens;
@@ -637,57 +708,13 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
         if (ctx->print_timings) {
             LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
         }
-
-        int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
-        int32_t i_batch = 0;
-        int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
         float * embd = mtmd_get_output_embd(ctx);
-        decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
-
-        const int nx = mtmd_image_tokens_get_nx(image_tokens);
-        const int ny = mtmd_image_tokens_get_ny(image_tokens);
-
-        if (mtmd_decode_use_mrope(ctx)) {
-            batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
-        } else {
-            batch_embd.set_position_normal(n_past, seq_id);
-        }
-
-        if (mtmd_decode_use_non_causal(ctx)) {
-            llama_set_causal_attn(lctx, false);
-            // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
-        }
-
-        while (i_batch < n_img_batches) { // split into batches
-            int pos_offset = i_batch*n_batch;
-            int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
-            llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
-
-            LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
-
-            int64_t t1 = ggml_time_ms();
-            ret = llama_decode(lctx, batch_embd_view);
-            if (ret != 0) {
-                LOG_ERR("failed to decode image\n");
-                llama_set_causal_attn(lctx, true); // restore causal attn
-                llama_batch_free(text_batch);
-                return ret;
-            }
-
-            if (ctx->print_timings) {
-                LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
-            }
-
-            i_batch++;
-        }
-
-        n_past += mtmd_image_tokens_get_n_pos(image_tokens);
-        *new_n_past = n_past;
-
-        if (mtmd_decode_use_non_causal(ctx)) {
-            llama_set_causal_attn(lctx, true);
+        ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
+        if (ret != 0) {
+            LOG_ERR("failed to decode image\n");
+            llama_batch_free(text_batch);
+            return ret;
         }
-
     } else {
         GGML_ABORT("chunk type not supported");
     }
index e2f76e2e8d346cd7b0970798d48d5b9bdbb34028..54cf481b6aa9483dfd64714a4a66da88badeea03 100644 (file)
@@ -231,6 +231,18 @@ MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
                                                bool logits_last,
                                                llama_pos * new_n_past);
 
+// helper function to decode an image whose embeddings have already been calculated
+// this helper will handle batching and pre/post decoding setup (for ex. gemma 3 requires non-causal attention)
+// ret 0 on success, -1 on chunk not being a valid image chunk, 1 on decode failure
+MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx,
+                                                struct llama_context * lctx,
+                                                const mtmd_input_chunk * chunk,
+                                                float * encoded_embd,
+                                                llama_pos n_past,
+                                                llama_seq_id seq_id,
+                                                int32_t n_batch,
+                                                llama_pos * new_n_past);
+
 /////////////////////////////////////////
 
 // test function, to be used in test-mtmd-c-api.c