]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : support multi-modal context checkpoints (#19849)
authorGeorgi Gerganov <redacted>
Wed, 25 Feb 2026 13:14:27 +0000 (15:14 +0200)
committerGitHub <redacted>
Wed, 25 Feb 2026 13:14:27 +0000 (15:14 +0200)
* Modify llama-memory-hybrid-iswa.cpp

* Modify llama-memory-recurrent.cpp

* Modify server-common.cpp

* Modify server-common.h

* Modify server-context.cpp

* Modify server-task.h

* Added comment to llama-memory-hybrid-iswa.cpp

* Remove comment from server-context.cpp

* Stylistic fix server-context.cpp

* Fix an issue when seqrm isn't called in server-context.cpp

* cont : alternative impl

* cont : cleanup

* cont : n_tokens -> int64_t

---------

Co-authored-by: timkhronos <redacted>
src/llama-memory-recurrent.cpp
tools/server/server-common.cpp
tools/server/server-common.h
tools/server/server-context.cpp
tools/server/server-task.h

index f0038036dcb7f591e9e97d296245e2f009d868cc..6e8413f493d072a0447a90904fe617b1b1d6e397 100644 (file)
@@ -163,7 +163,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
             const auto & cell = cells[tail_id];
             // partial intersection is invalid if it includes the final pos
             if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
-                //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
+                //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1);
                 return false;
             }
             // invalidate tails which will be cleared
index 88b6e77d82b55fbff715b35f589fa2b0ecdba7e1..ff3c6d3c2b054040efb9acb315b3791b84be5799 100644 (file)
@@ -231,19 +231,77 @@ server_tokens::server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) :
 server_tokens::server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
 }
 
-llama_pos server_tokens::pos_next() const {
+llama_pos server_tokens::pos_next(int64_t n_tokens) const {
     if (!has_mtmd) {
-        return tokens.size();
+        if (n_tokens < 0) {
+            return tokens.size();
+        }
+
+        return n_tokens;
     }
 
-    llama_pos res = tokens.size();
+    if (n_tokens < 0) {
+        llama_pos res = tokens.size();
 
-    for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
-        const auto & chunk = it->second;
-        res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
+        for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
+            const auto & chunk = it->second;
+            res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
+        }
+
+        return res;
     }
 
-    return res;
+    int64_t idx = 0;
+    llama_pos pos = 0;
+
+    GGML_ASSERT(n_tokens <= (int64_t)tokens.size());
+
+    while (idx < n_tokens) {
+        const auto media_it = map_idx_to_media.find(idx);
+        if (media_it != map_idx_to_media.end()) {
+            const auto & chunk = media_it->second;
+            const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
+            const size_t n_tok = mtmd_input_chunk_get_n_tokens(chunk.get());
+
+            pos += n_pos;
+            idx += n_tok;
+        } else {
+            pos++;
+            idx++;
+        }
+    }
+
+    return pos;
+}
+
+size_t server_tokens::size_up_to_pos(llama_pos max_pos) const {
+    if (!has_mtmd) {
+        return std::min((size_t)(max_pos + 1), tokens.size());
+    }
+
+    size_t idx = 0;
+    llama_pos pos = 0;
+
+    while (idx < tokens.size()) {
+        const auto media_it = map_idx_to_media.find(idx);
+        if (media_it != map_idx_to_media.end()) {
+            const auto & chunk = media_it->second;
+            const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
+            const size_t n_tok = mtmd_input_chunk_get_n_tokens(chunk.get());
+
+            pos += n_pos;
+            idx += n_tok;
+        } else {
+            pos++;
+            idx++;
+        }
+
+        if (pos > max_pos) {
+            break;
+        }
+    }
+
+    return idx;
 }
 
 std::string server_tokens::str() const {
index 2629a6bee9274fa785755b66a085673db1bdf56f..4fb9e488dfdec6c6b3ab9fc86f4a42dfb59906da 100644 (file)
@@ -167,7 +167,12 @@ public:
     // for debugging
     std::string str() const;
 
-    llama_pos pos_next() const;
+    // the next position after n_tokens. if n_tokens < 0, return the next position after all tokens.
+    llama_pos pos_next(int64_t n_tokens = -1) const;
+
+    // number of tokens with position <= max_pos
+    size_t size_up_to_pos(llama_pos max_pos) const;
+
     const mtmd::input_chunk_ptr & find_chunk(size_t idx) const;
 
     void push_back(llama_token tok);
index 0f2f3a45aaae6460cc2d21436b7e907b0e90e46d..67c3988bd05402117ce0b189ef603e8adb4c94bc 100644 (file)
@@ -1442,7 +1442,7 @@ private:
         res->id      = slot.task->id;
         res->id_slot = slot.id;
 
-        res->index           = slot.task->index;
+        res->index = slot.task->index;
 
         // keep copy of last generated text for debugging purposes
         if (slots_debug) {
@@ -2282,15 +2282,15 @@ private:
                                 n_past = 0;
                             }
 
+                            llama_pos pos_next = slot.prompt.tokens.pos_next(n_past);
+
                             // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
                             const auto n_swa = std::max(1, llama_model_n_swa(model));
 
                             // the largest pos_min required for a checkpoint to be useful
-                            const auto pos_min_thold = std::max(0, n_past - n_swa);
+                            const auto pos_min_thold = std::max(0, pos_next - n_swa);
 
-                            // note: disallow with mtmd contexts for now
-                            //       https://github.com/ggml-org/llama.cpp/issues/17043
-                            if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) {
+                            if (n_past > 0 && n_past < slot.prompt.n_tokens()) {
                                 const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                                 if (pos_min == -1) {
                                     SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
@@ -2341,9 +2341,6 @@ private:
                                 }
 
                                 if (pos_min > pos_min_thold) {
-                                    // TODO: support can be added in the future when corresponding vision models get released
-                                    GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
-
                                     SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
 
                                     // search for a context checkpoint
@@ -2364,18 +2361,20 @@ private:
                                         const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
                                         if (n != checkpoint_size) {
-                                            SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
+                                            SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024);
                                             do_reset = true;
                                             //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                                         } else {
-                                            n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
-                                            SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
+                                            pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
+                                            n_past = slot.prompt.tokens.size_up_to_pos(pos_next);
+                                            SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024);
                                         }
                                     }
 
                                     if (do_reset) {
                                         SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
                                                 "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+                                        pos_next = 0;
                                         n_past = 0;
                                     }
                                 }
@@ -2386,7 +2385,7 @@ private:
                                 for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
                                     const auto & cur = *it;
                                     if (cur.pos_min > pos_min_thold) {
-                                        SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
+                                        SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, (float) cur.data.size() / 1024 / 1024);
                                         it = slot.prompt.checkpoints.erase(it);
                                     } else {
                                         ++it;
@@ -2402,7 +2401,7 @@ private:
                             SLT_WRN(slot, "n_past was set to %d\n", n_past);
                         }
 
-                        slot.n_prompt_tokens_cache     = n_past;
+                        slot.n_prompt_tokens_cache = n_past;
                         slot.n_prompt_tokens_processed = 0;
 
                         slot.prompt.tokens.keep_first(n_past);
@@ -2520,10 +2519,6 @@ private:
                         }
                     }
 
-                    // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());
-
-                    SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
-
                     // entire prompt has been processed
                     if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
                         slot.state = SLOT_STATE_DONE_PROMPT;
@@ -2536,8 +2531,6 @@ private:
                         slot.n_decoded = 0;
                         slot.i_batch   = batch.n_tokens - 1;
 
-                        SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
-
                         slot.init_sampler();
 
                         const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
@@ -2549,13 +2542,15 @@ private:
                         // no need to create checkpoints that are too close together
                         do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
 
+                        // note: we create the checkpoint before calling llama_decode(), so the current batch is not
+                        //       yet processed and therefore it is not part of the checkpoint.
                         if (do_checkpoint) {
                             while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
                                 // make room for the new checkpoint, if needed
                                 const auto & cur = slot.prompt.checkpoints.front();
 
-                                SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                                        cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                                SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
+                                        cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
 
                                 slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
                             }
@@ -2563,16 +2558,21 @@ private:
                             const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
                             auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
-                                /*.pos_min = */ pos_min,
-                                /*.pos_max = */ pos_max,
-                                /*.data    = */ std::vector<uint8_t>(checkpoint_size),
+                                /*.pos_min  = */ pos_min,
+                                /*.pos_max  = */ pos_max,
+                                /*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens,
+                                /*.data     = */ std::vector<uint8_t>(checkpoint_size),
                             });
 
                             llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                            SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                                    (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                            SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
+                                    (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
                         }
+
+                        SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
+                    } else {
+                        SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
                     }
                 }
 
index a69e8f1a3d2b4a30b5db563165bfe9072090a6f5..e2e3e5a582846045edcbe8e3b87f291718ef3360 100644 (file)
@@ -557,6 +557,8 @@ struct server_prompt_checkpoint {
     llama_pos pos_min;
     llama_pos pos_max;
 
+    int64_t n_tokens;
+
     std::vector<uint8_t> data;
 
     size_t size() const {