git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : improve mtmd ctx checkpoints (#20726)
author     Georgi Gerganov <redacted>
Fri, 20 Mar 2026 09:13:12 +0000 (11:13 +0200)
committer  GitHub <redacted>
Fri, 20 Mar 2026 09:13:12 +0000 (11:13 +0200)
* server : improve mtmd ctx checkpoints

* server : fix off-by-one in pos_min_thold

tools/server/server-context.cpp

index 2e43bebd6f959797da389ecbe5592c185302d514..9de554e9007bc13d6e99e7b945a9f45904bc99e4 100644
@@ -2307,8 +2307,8 @@ private:
 
                             llama_pos pos_next = slot.prompt.tokens.pos_next(n_past);
 
-                            // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
-                            const auto n_swa = std::max(1, llama_model_n_swa(model));
+                            // note: when n_swa == 0, the model does not use SWA
+                            const auto n_swa = std::max(0, llama_model_n_swa(model));
 
                             // the largest pos_min required for a checkpoint to be useful
                             const auto pos_min_thold = std::max(0, pos_next - n_swa);
@@ -2363,7 +2363,7 @@ private:
                                     SLT_WRN(slot, "%s\n", st1.str().c_str());
                                 }
 
-                                if (pos_min > pos_min_thold) {
+                                if (pos_min >= pos_min_thold) {
                                     SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
 
                                     // search for a context checkpoint
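For reference, a minimal standalone sketch of the threshold logic after these two hunks. The helper function and its bare-int signature are illustrative only; in server-context.cpp these values come from the slot state and the llama_memory sequence positions:

    #include <algorithm>

    // pos_next : position of the next token to decode for this sequence
    // n_swa    : SWA window of the model; 0 means the model does not use SWA
    // pos_min  : smallest position still present in the KV cache for this sequence
    static bool should_search_checkpoint(int pos_next, int n_swa, int pos_min) {
        // the largest pos_min required for a checkpoint to be useful
        const int pos_min_thold = std::max(0, pos_next - n_swa);

        // off-by-one fix from this commit: the boundary case
        // pos_min == pos_min_thold now also triggers the checkpoint search
        return pos_min >= pos_min_thold;
    }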
@@ -2459,31 +2459,6 @@ private:
                         slot.n_prompt_tokens_cache = 0;
                     }
 
-                    bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
-
-                    // check if we should process the image
-                    if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
-                        // process the image
-                        size_t n_tokens_out = 0;
-                        int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
-                        if (res != 0) {
-                            SLT_ERR(slot, "failed to process image, res = %d\n", res);
-                            send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
-                            slot.release();
-                            continue;
-                        }
-
-                        slot.n_prompt_tokens_processed += n_tokens_out;
-
-                        // add the image chunk to cache
-                        {
-                            const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens());
-                            slot.prompt.tokens.push_back(chunk.get()); // copy
-                        }
-
-                        do_checkpoint = false; // do not checkpoint right after an image chunk
-                    }
-
                     // If using an alora, there may be uncached tokens that come
                     // before the invocation sequence. When this happens, the
                     // tokens before the invocation sequence need to be
@@ -2498,6 +2473,8 @@ private:
                         alora_disabled_id = enabled_loras[0];
                     }
 
+                    bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
+
                     // make checkpoints only for completion tasks
                     do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
 
@@ -2513,6 +2490,31 @@ private:
                             (llama_model_n_swa(model) > 0 && !params_base.swa_full)
                             );
 
+                    bool has_mtmd = false;
+
+                    // check if we should process the image
+                    if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
+                        // process the image
+                        size_t n_tokens_out = 0;
+                        int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
+                        if (res != 0) {
+                            SLT_ERR(slot, "failed to process image, res = %d\n", res);
+                            send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
+                            slot.release();
+                            continue;
+                        }
+
+                        slot.n_prompt_tokens_processed += n_tokens_out;
+
+                        // add the image chunk to cache
+                        {
+                            const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens());
+                            slot.prompt.tokens.push_back(chunk.get()); // copy
+                        }
+
+                        has_mtmd = true;
+                    }
+
                     // add prompt tokens for processing in the current batch
                     while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
                         // get next token to process
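The relocation above means do_checkpoint is now initialized before the image chunk is handled, and the old in-place do_checkpoint = false is replaced by a has_mtmd flag that is folded into the checkpoint decision further down. A condensed sketch of the resulting flow, with the two bool parameters standing in for the LLAMA_TOKEN_NULL check and for a successful input_tokens.process_chunk(...) call (the names here are illustrative, not the server API):

    struct CheckpointFlags {
        bool do_checkpoint;
        bool has_mtmd;
    };

    static CheckpointFlags plan_checkpoint(int n_ctx_checkpoints,
                                           bool next_is_image_chunk,
                                           bool image_chunk_ok) {
        CheckpointFlags f;
        f.do_checkpoint = n_ctx_checkpoints > 0; // decided before the image is handled
        f.has_mtmd      = false;

        if (next_is_image_chunk && image_chunk_ok) {
            // previously this path set do_checkpoint = false on the spot;
            // now it only records that an mtmd chunk was just decoded
            f.has_mtmd = true;
        }
        return f;
    }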
@@ -2544,13 +2546,13 @@ private:
                         //  - 4 + n_ubatch
                         //  - 4
                         // ref: https://github.com/ggml-org/llama.cpp/pull/20288
-                        {
+                        if (do_checkpoint) {
                             static const int checkpoint_offsets[] = {4 + n_ubatch, 4};
 
                             bool should_break = false;
                             for (int offset : checkpoint_offsets) {
                                 const int n_last = std::min(n_batch, offset);
-                                if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
+                                if (slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
                                     should_break = true;
                                     break;
                                 }
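The same check, with the do_checkpoint test hoisted out of the loop, can be written as a small self-contained helper (illustrative names; n_batch and n_ubatch mirror the diff):

    #include <algorithm>

    // Returns true when the batch should stop early so that a checkpoint can be
    // created 4 + n_ubatch or 4 tokens before the end of the prompt (see the
    // referenced PR for why these offsets are used).
    static bool should_break_for_checkpoint(bool do_checkpoint,
                                            int n_tokens_task, int n_tokens_prompt,
                                            int n_batch, int n_ubatch) {
        if (!do_checkpoint) {
            return false; // the whole block is now skipped when checkpoints are disabled
        }
        const int checkpoint_offsets[] = {4 + n_ubatch, 4};
        for (int offset : checkpoint_offsets) {
            const int n_last = std::min(n_batch, offset);
            if (n_tokens_task == n_tokens_prompt + n_last) {
                return true;
            }
        }
        return false;
    }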
@@ -2607,10 +2609,13 @@ private:
                     const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
 
                     // no need for empty or small checkpoints
-                    do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
+                    do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64);
+
+                    // do not checkpoint after mtmd chunks
+                    do_checkpoint = do_checkpoint && !has_mtmd;
 
                     // no need to create checkpoints that are too close together
-                    do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
+                    do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || slot.prompt.n_tokens() - n_tokens_cur > slot.prompt.checkpoints.back().n_tokens + 64);
 
                     // note: we create the checkpoint before calling llama_decode(), so the current batch is not
                     //       yet processed and therefore it is not part of the checkpoint.
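Putting the final conditions together, a sketch of the gating after this commit, with checkpoint spacing now measured in prompt tokens rather than KV-cache positions. All names are illustrative stand-ins for the slot/checkpoint state; n_tokens_cur is assumed to be the size of the current, not-yet-decoded batch, per the note above:

    #include <vector>

    struct Checkpoint {
        int n_tokens; // prompt size at the time the checkpoint was taken
    };

    static bool want_checkpoint(bool do_checkpoint, bool has_mtmd,
                                int pos_min, int n_prompt_tokens, int n_tokens_cur,
                                const std::vector<Checkpoint> & checkpoints) {
        // no need for empty or small checkpoints
        do_checkpoint = do_checkpoint && (pos_min >= 0 && n_prompt_tokens >= 64);

        // do not checkpoint right after an mtmd (image) chunk
        do_checkpoint = do_checkpoint && !has_mtmd;

        // keep checkpoints at least 64 tokens apart, excluding the current batch
        do_checkpoint = do_checkpoint &&
            (checkpoints.empty() ||
             n_prompt_tokens - n_tokens_cur > checkpoints.back().n_tokens + 64);

        return do_checkpoint;
    }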