]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : remove n_past (#16818)
authorGeorgi Gerganov <redacted>
Thu, 30 Oct 2025 16:42:57 +0000 (18:42 +0200)
committerGitHub <redacted>
Thu, 30 Oct 2025 16:42:57 +0000 (18:42 +0200)
* server : remove n_past

* server : replace slot.n_prompt_tokens() with slot.task->n_tokens()

* server : fixes + clean-up

* cont : fix context shift

* server : add server_tokens::pos_next()

Co-authored-by: Xuan-Son Nguyen <redacted>
* server : fix pos_next() usage

Co-authored-by: Xuan-Son Nguyen <redacted>
---------

Co-authored-by: Xuan-Son Nguyen <redacted>
tools/server/README.md
tools/server/server.cpp
tools/server/utils.hpp

index f5ab9236d52167e34cebf869f78b1ba67eb17ec8..c16d0bd6dcd7f616d5bb3f7a8e645170a1b46933 100644 (file)
@@ -587,7 +587,7 @@ These words will not be included in the completion, so make sure to add them to
   - `word`: Stopped due to encountering a stopping word from `stop` JSON array provided
 - `stopping_word`: The stopping word encountered which stopped the generation (or "" if not stopped due to a stopping word)
 - `timings`: Hash of timing information about the completion such as the number of tokens `predicted_per_second`
-- `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion (`n_past`)
+- `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion
 - `tokens_evaluated`: Number of tokens evaluated in total from the prompt
 - `truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens predicted`) exceeded the context size (`n_ctx`)
 
@@ -1045,7 +1045,7 @@ Available metrics:
 - `llamacpp:kv_cache_tokens`: KV-cache tokens.
 - `llamacpp:requests_processing`: Number of requests processing.
 - `llamacpp:requests_deferred`: Number of requests deferred.
-- `llamacpp:n_past_max`: High watermark of the context size observed.
+- `llamacpp:n_tokens_max`: High watermark of the context size observed.
 
 ### POST `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file.
 
index cb794ab647eba1ea4997f989959b8ca9a10e180c..d56468595dd00b984e31e4ffd7aa007bfdcdfd7a 100644 (file)
@@ -292,6 +292,10 @@ struct server_task {
 
     server_task(server_task_type type) : type(type) {}
 
+    int32_t n_tokens() const {
+        return tokens.size();
+    }
+
     static slot_params params_from_json_cmpl(
             const llama_context * ctx,
             const common_params & params_base,
@@ -1308,7 +1312,7 @@ struct server_task_result_metrics : server_task_result {
     uint64_t n_tokens_predicted_total        = 0;
     uint64_t t_tokens_generation_total       = 0;
 
-    uint64_t n_past_max = 0;
+    uint64_t n_tokens_max = 0;
 
     uint64_t n_prompt_tokens_processed = 0;
     uint64_t t_prompt_processing       = 0;
@@ -1335,7 +1339,7 @@ struct server_task_result_metrics : server_task_result {
             { "n_tokens_predicted_total",        n_tokens_predicted_total },
             { "t_prompt_processing_total",       t_prompt_processing_total },
 
-            { "n_past_max",                      n_past_max },
+            { "n_tokens_max",                    n_tokens_max },
 
             { "n_prompt_tokens_processed",       n_prompt_tokens_processed },
             { "t_prompt_processing",             t_prompt_processing },
@@ -1636,7 +1640,6 @@ struct server_slot {
 
     // generation props
     int32_t n_ctx       = 0;  // context size per slot
-    int32_t n_past      = 0;
     int32_t n_keep      = 0;
     int32_t n_decoded   = 0;
     int32_t n_remaining = -1;
@@ -1645,10 +1648,6 @@ struct server_slot {
     int32_t n_prompt_tokens_cache     = 0;
     int32_t n_prompt_tokens_processed = 0;
 
-    int32_t n_prompt_tokens() const {
-        return task->tokens.size();
-    }
-
     size_t last_nl_pos = 0;
 
     std::string  generated_text;
@@ -1733,7 +1732,6 @@ struct server_slot {
         truncated      = false;
         stop           = STOP_TYPE_NONE;
         stopping_word  = "";
-        n_past         = 0;
         n_sent_text    = 0;
         chat_format    = COMMON_CHAT_FORMAT_CONTENT_ONLY;
 
@@ -1818,7 +1816,7 @@ struct server_slot {
         if (is_processing()) {
             GGML_ASSERT(task);
 
-            SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated);
+            SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated);
 
             t_last_used = ggml_time_us();
             t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
@@ -1970,7 +1968,7 @@ struct server_metrics {
     uint64_t n_tokens_predicted_total        = 0;
     uint64_t t_tokens_generation_total       = 0;
 
-    uint64_t n_past_max = 0;
+    uint64_t n_tokens_max = 0;
 
     uint64_t n_prompt_tokens_processed = 0;
     uint64_t t_prompt_processing       = 0;
@@ -1991,9 +1989,7 @@ struct server_metrics {
         t_prompt_processing             += slot.t_prompt_processing;
         t_prompt_processing_total       += slot.t_prompt_processing;
 
-        if (slot.n_past > 0) {
-            n_past_max = std::max(n_past_max, (uint64_t) slot.n_past);
-        }
+        n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
     }
 
     void on_prediction(const server_slot & slot) {
@@ -2009,9 +2005,7 @@ struct server_metrics {
             if (slot.is_processing()) {
                 n_busy_slots_total++;
             }
-            if (slot.n_past > 0) {
-                n_past_max = std::max(n_past_max, (uint64_t) slot.n_past);
-            }
+            n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
         }
     }
 
@@ -2865,13 +2859,13 @@ struct server_context {
         }
 
         // if context shifting is disabled, make sure that we don't run out of context
-        if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) {
+        if (!params_base.ctx_shift && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
             slot.truncated      = true;
             slot.stop           = STOP_TYPE_LIMIT;
             slot.has_next_token = false;
 
-            SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
-                    slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx);
+            SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = %d\n",
+                    slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx);
         }
 
         // check the limits
@@ -2998,7 +2992,7 @@ struct server_context {
     }
 
     void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
-        send_error(slot.task->id, error, type, slot.n_prompt_tokens(), slot.n_ctx);
+        send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx);
     }
 
     void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
@@ -3035,7 +3029,7 @@ struct server_context {
 
         if (is_progress) {
             res->is_progress        = true;
-            res->progress.total     = slot.n_prompt_tokens();
+            res->progress.total     = slot.task->n_tokens();
             res->progress.cache     = slot.n_prompt_tokens_cache;
             res->progress.processed = slot.prompt.tokens.size();
             res->progress.time_ms   = (ggml_time_us() - slot.t_start_process_prompt / 1000);
@@ -3047,7 +3041,7 @@ struct server_context {
         }
 
         res->n_decoded           = slot.n_decoded;
-        res->n_prompt_tokens     = slot.n_prompt_tokens();
+        res->n_prompt_tokens     = slot.task->n_tokens();
         res->post_sampling_probs = slot.task->params.post_sampling_probs;
 
         res->verbose           = slot.task->params.verbose;
@@ -3083,8 +3077,8 @@ struct server_context {
 
         res->truncated           = slot.truncated;
         res->n_decoded           = slot.n_decoded;
-        res->n_prompt_tokens     = slot.n_prompt_tokens();
-        res->n_tokens_cached     = slot.n_past;
+        res->n_prompt_tokens     = slot.task->n_tokens();
+        res->n_tokens_cached     = slot.prompt.n_tokens();
         res->has_new_line        = slot.has_new_line;
         res->stopping_word       = slot.stopping_word;
         res->stop                = slot.stop;
@@ -3123,7 +3117,7 @@ struct server_context {
         auto res = std::make_unique<server_task_result_embd>();
         res->id        = slot.task->id;
         res->index     = slot.task->index;
-        res->n_tokens  = slot.n_prompt_tokens();
+        res->n_tokens  = slot.task->n_tokens();
         res->oaicompat = slot.task->params.oaicompat;
 
         const int n_embd = llama_model_n_embd(model);
@@ -3168,7 +3162,7 @@ struct server_context {
         auto res = std::make_unique<server_task_result_rerank>();
         res->id       = slot.task->id;
         res->index    = slot.task->index;
-        res->n_tokens = slot.n_prompt_tokens();
+        res->n_tokens = slot.task->n_tokens();
 
         for (int i = 0; i < batch.n_tokens; ++i) {
             if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
@@ -3375,7 +3369,7 @@ struct server_context {
                     res->n_tokens_predicted_total        = metrics.n_tokens_predicted_total;
                     res->t_tokens_generation_total       = metrics.t_tokens_generation_total;
 
-                    res->n_past_max = metrics.n_past_max;
+                    res->n_tokens_max = metrics.n_tokens_max;
 
                     res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
                     res->t_prompt_processing       = metrics.t_prompt_processing;
@@ -3551,7 +3545,7 @@ struct server_context {
         // apply context-shift if needed
         // TODO: simplify and improve
         for (server_slot & slot : slots) {
-            if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
+            if (slot.is_processing() && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
                 if (!params_base.ctx_shift) {
                     // this check is redundant (for good)
                     // we should never get here, because generation should already stopped in process_token()
@@ -3567,7 +3561,7 @@ struct server_context {
                 }
 
                 // Shift context
-                int n_keep = slot.task->params.n_keep < 0 ? slot.n_prompt_tokens() : slot.task->params.n_keep;
+                int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;
 
                 if (add_bos_token) {
                     n_keep += 1;
@@ -3575,28 +3569,30 @@ struct server_context {
 
                 n_keep = std::min(slot.n_ctx - 4, n_keep);
 
-                const int n_left    = slot.n_past - n_keep;
+                const int n_left    = slot.prompt.n_tokens() - n_keep;
                 const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2);
 
                 SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
 
                 llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep            , n_keep + n_discard);
-                llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.n_past,        -n_discard);
+                llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
 
                 // add generated tokens to cache
+                // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481
                 {
+                    GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
+
                     llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy
                     for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
                         new_tokens[i - n_discard] = new_tokens[i];
                     }
 
                     new_tokens.resize(slot.prompt.tokens.size() - n_discard);
+
                     slot.prompt.tokens.clear();
                     slot.prompt.tokens.insert(new_tokens);
                 }
 
-                slot.n_past -= n_discard;
-
                 slot.truncated = true;
             }
         }
@@ -3627,13 +3623,12 @@ struct server_context {
 
             slot.i_batch = batch.n_tokens;
 
-            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
+            common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
 
-            slot.n_past += 1;
             slot.prompt.tokens.push_back(slot.sampled);
 
-            SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
-                    slot.n_ctx, slot.n_past, (int) slot.prompt.tokens.size(), slot.truncated);
+            SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
+                    slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
         }
 
         // process in chunks of params.n_batch
@@ -3663,11 +3658,10 @@ struct server_context {
                         slot.t_start_process_prompt = ggml_time_us();
                         slot.t_start_generation = 0;
 
-                        slot.n_past = 0;
                         slot.state = SLOT_STATE_PROCESSING_PROMPT;
 
-                        SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n",
-                                slot.n_ctx, slot.task->params.n_keep, slot.n_prompt_tokens());
+                        SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n",
+                                slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());
 
                         // print prompt tokens (for debugging)
                         /*if (1) {
@@ -3682,6 +3676,9 @@ struct server_context {
                             }
                         }*/
 
+                        // keep track how many tokens we can reuse from the previous state
+                        int n_past = 0;
+
                         // empty prompt passed -> release the slot and send empty response
                         if (input_tokens.empty()) {
                             SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
@@ -3701,19 +3698,19 @@ struct server_context {
                         }
 
                         if (!slot.can_split()) {
-                            if (slot.n_prompt_tokens() > n_ubatch) {
+                            if (slot.task->n_tokens() > n_ubatch) {
                                 send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
                                 slot.release();
                                 continue;
                             }
 
-                            if (slot.n_prompt_tokens() > slot.n_ctx) {
+                            if (slot.task->n_tokens() > slot.n_ctx) {
                                 send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
                                 slot.release();
                                 continue;
                             }
                         } else {
-                            if (slot.n_prompt_tokens() >= slot.n_ctx) {
+                            if (slot.task->n_tokens() >= slot.n_ctx) {
                                 send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
                                 slot.release();
                                 continue;
@@ -3721,32 +3718,34 @@ struct server_context {
 
                             if (slot.task->params.cache_prompt) {
                                 // reuse any previously computed tokens that are common with the new prompt
-                                slot.n_past = slot.prompt.tokens.get_common_prefix(input_tokens);
+                                n_past = slot.prompt.tokens.get_common_prefix(input_tokens);
 
                                 // if there is an alora invoked, don't cache after the invocation start
-                                if (slot.alora_invocation_start >= 0) {
-                                    SLT_DBG(slot, "only caching to alora invocation start (n_past=%d, alora_invocation_start=%d)\n", slot.n_past, slot.alora_invocation_start);
-                                    slot.n_past = std::min(slot.n_past, slot.alora_invocation_start - 1);
+                                if (slot.alora_invocation_start > 0) {
+                                    SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start);
+                                    n_past = std::min(n_past, slot.alora_invocation_start - 1);
                                 }
 
                                 // reuse chunks from the cached prompt by shifting their KV cache in the new position
                                 if (params_base.n_cache_reuse > 0) {
-                                    size_t head_c = slot.n_past; // cache
-                                    size_t head_p = slot.n_past; // current prompt
+                                    GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
+
+                                    size_t head_c = n_past; // cache
+                                    size_t head_p = n_past; // current prompt
 
                                     if (mctx) {
                                         // we should never reach this
                                         GGML_ABORT("not supported by multimodal");
                                     }
 
-                                    SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
+                                    SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past);
 
                                     while (head_c < slot.prompt.tokens.size() &&
                                            head_p < input_tokens.size()) {
 
                                         size_t n_match = 0;
                                         while (head_c + n_match < slot.prompt.tokens.size() &&
-                                               head_p + n_match < input_tokens.size()     &&
+                                               head_p + n_match < input_tokens.size()       &&
                                                slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) {
 
                                             n_match++;
@@ -3765,7 +3764,7 @@ struct server_context {
 
                                             for (size_t i = 0; i < n_match; i++) {
                                                 slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
-                                                slot.n_past++;
+                                                n_past++;
                                             }
 
                                             head_c += n_match;
@@ -3775,31 +3774,31 @@ struct server_context {
                                         }
                                     }
 
-                                    SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
+                                    SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past);
                                 }
                             } else {
-                                // if we don't cache the prompt, we have to remove the entire KV cache
-                                slot.n_past = 0;
+                                // if we don't cache the prompt, we have to remove all previous tokens
+                                n_past = 0;
                             }
 
                             // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
                             const auto n_swa = std::max(1, llama_model_n_swa(model));
 
                             // the largest pos_min required for a checkpoint to be useful
-                            const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
+                            const auto pos_min_thold = std::max(0, n_past - n_swa);
 
-                            if (slot.n_past > 0 && slot.n_past < (int) slot.prompt.tokens.size()) {
+                            if (n_past > 0 && n_past < slot.prompt.n_tokens()) {
                                 const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                                 if (pos_min == -1) {
-                                    SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
+                                    SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
                                     GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
                                 }
 
                                 // when the prompt prefix does not match, print the tokens around the mismatch
                                 // this is useful for debugging prompt caching
                                 {
-                                    const int np0 = std::max<int>(slot.n_past - 4, 0);
-                                    const int np1 = std::min<int>(slot.n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size()));
+                                    const int np0 = std::max<int>(n_past - 4, 0);
+                                    const int np1 = std::min<int>(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size()));
 
                                     std::stringstream ss0;
                                     std::stringstream ss1;
@@ -3811,7 +3810,7 @@ struct server_context {
                                     ss1 << "new: ... ";
 
                                     for (int i = np0; i < np1; i++) {
-                                        if (i == slot.n_past) {
+                                        if (i == n_past) {
                                             ss0 << " | ";
                                             ss1 << " | ";
                                         }
@@ -3839,7 +3838,10 @@ struct server_context {
                                 }
 
                                 if (pos_min > pos_min_thold) {
-                                    SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
+                                    // TODO: support can be added in the future when corresponding vision models get released
+                                    GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
+
+                                    SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
 
                                     // search for a context checkpoint
                                     const auto it = std::find_if(
@@ -3863,7 +3865,7 @@ struct server_context {
                                             do_reset = true;
                                             //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                                         } else {
-                                            slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
+                                            n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
                                             SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
                                         }
                                     }
@@ -3871,7 +3873,7 @@ struct server_context {
                                     if (do_reset) {
                                         SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
                                                 "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-                                        slot.n_past = 0;
+                                        n_past = 0;
                                     }
                                 }
                             }
@@ -3891,43 +3893,44 @@ struct server_context {
                         }
 
                         // [TAG_PROMPT_LOGITS]
-                        if (slot.n_past == slot.n_prompt_tokens() && slot.n_past > 0) {
-                            SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens());
-                            slot.n_past--;
-                            SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
+                        if (n_past == slot.task->n_tokens() && n_past > 0) {
+                            SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens());
+                            n_past--;
+                            SLT_WRN(slot, "n_past was set to %d\n", n_past);
                         }
 
-                        slot.n_prompt_tokens_cache     = slot.n_past;
+                        slot.n_prompt_tokens_cache     = n_past;
                         slot.n_prompt_tokens_processed = 0;
+
+                        slot.prompt.tokens.keep_first(n_past);
                     }
 
                     if (!slot.can_split()) {
                         // cannot fit the prompt in the current batch - will try next iter
-                        if (batch.n_tokens + slot.n_prompt_tokens() > n_batch) {
+                        if (batch.n_tokens + slot.task->n_tokens() > n_batch) {
                             continue;
                         }
                     }
 
                     // truncate any tokens that are beyond n_past for this slot
-                    if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) {
-                        SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past);
+                    const llama_pos p0 = slot.prompt.tokens.pos_next();
+                    if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
+                        SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0);
                         llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
 
                         // there is no common part left
-                        slot.n_past                = 0;
                         slot.n_prompt_tokens_cache = 0;
-                    }
 
-                    SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past);
+                        slot.prompt.tokens.clear();
+                    }
 
-                    // remove the non-common part from the cache
-                    slot.prompt.tokens.keep_first(slot.n_past);
+                    SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
 
                     // check if we should process the image
-                    if (slot.n_past < slot.n_prompt_tokens() && input_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
+                    if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
                         // process the image
-                        int32_t new_n_past;
-                        int32_t res = input_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
+                        size_t n_tokens_out = 0;
+                        int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
                         if (res != 0) {
                             SLT_ERR(slot, "failed to process image, res = %d\n", res);
                             send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
@@ -3935,16 +3938,13 @@ struct server_context {
                             continue;
                         }
 
+                        slot.n_prompt_tokens_processed += n_tokens_out;
+
                         // add the image chunk to cache
                         {
-                            const auto & chunk = input_tokens.find_chunk(slot.n_past);
+                            const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens());
                             slot.prompt.tokens.push_back(chunk.get()); // copy
                         }
-
-                        const int32_t n_pos = new_n_past - slot.n_past;
-
-                        slot.n_past                    += n_pos;
-                        slot.n_prompt_tokens_processed += n_pos;
                     }
 
                     // If using an alora, there may be uncached tokens that come
@@ -3952,8 +3952,8 @@ struct server_context {
                     // tokens before the invocation sequence need to be
                     // processed without the adpter in a separate batch, then
                     // the adapter needs to be enabled for the remaining tokens.
-                    if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.n_past) {
-                        SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
+                    if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) {
+                        SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
                         const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
                         GGML_ASSERT(enabled_loras.size() == 1);
                         alora_scale = slot.lora[enabled_loras[0]].scale;
@@ -3979,9 +3979,9 @@ struct server_context {
                             );
 
                     // add prompt tokens for processing in the current batch
-                    while (slot.n_past < slot.n_prompt_tokens() && batch.n_tokens < n_batch) {
+                    while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
                         // get next token to process
-                        llama_token cur_tok = input_tokens[slot.n_past];
+                        llama_token cur_tok = input_tokens[slot.prompt.n_tokens()];
                         if (cur_tok == LLAMA_TOKEN_NULL) {
                             break; // end of text chunk
                         }
@@ -3989,30 +3989,33 @@ struct server_context {
                         // if this is an alora request with pre-invocation
                         // tokens that are not cached, we need to stop filling
                         // this batch at those pre-invocation tokens.
-                        if (alora_scale > 0 && slot.n_past == slot.alora_invocation_start - 1) {
-                            SLT_DBG(slot, "stop prompt batch filling at (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
+                        if (alora_scale > 0 && slot.prompt.n_tokens() == slot.alora_invocation_start - 1) {
+                            SLT_DBG(slot, "stop prompt batch filling at (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
                             break;
                         }
 
                         // embedding requires all tokens in the batch to be output
-                        common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, slot.need_embd());
+                        common_batch_add(batch,
+                            cur_tok,
+                            slot.prompt.tokens.pos_next(),
+                            { slot.id },
+                            slot.need_embd());
                         slot.prompt.tokens.push_back(cur_tok);
 
                         slot.n_prompt_tokens_processed++;
-                        slot.n_past++;
 
                         // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
-                        if (do_checkpoint && slot.n_prompt_tokens() - slot.n_past == 64) {
+                        if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) {
                             break;
                         }
                     }
 
                     // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());
 
-                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_past / slot.n_prompt_tokens());
+                    SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
 
                     // entire prompt has been processed
-                    if (slot.n_past == slot.n_prompt_tokens()) {
+                    if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
                         slot.state = SLOT_STATE_DONE_PROMPT;
 
                         GGML_ASSERT(batch.n_tokens > 0);
@@ -4020,7 +4023,7 @@ struct server_context {
                         common_sampler_reset(slot.smpl);
 
                         // Process all prompt tokens through sampler system
-                        for (int i = 0; i < slot.n_prompt_tokens(); ++i) {
+                        for (int i = 0; i < slot.task->n_tokens(); ++i) {
                             llama_token id = input_tokens[i];
                             if (id != LLAMA_TOKEN_NULL) {
                                 common_sampler_accept(slot.smpl, id, false);
@@ -4033,7 +4036,7 @@ struct server_context {
                         slot.n_decoded = 0;
                         slot.i_batch   = batch.n_tokens - 1;
 
-                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+                        SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
 
                         const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                         const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
@@ -4253,9 +4256,9 @@ struct server_context {
                 // determine the max draft that fits the current slot state
                 int n_draft_max = slot.task->params.speculative.n_max;
 
-                // note: n_past is not yet increased for the `id` token sampled above
+                // note: slot.prompt is not yet expanded with the `id` token sampled above
                 //       also, need to leave space for 1 extra token to allow context shifts
-                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2);
 
                 if (slot.n_remaining > 0) {
                     n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
@@ -4291,10 +4294,10 @@ struct server_context {
 
                 // construct the speculation batch
                 common_batch_clear(slot.batch_spec);
-                common_batch_add  (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+                common_batch_add  (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true);
 
                 for (size_t i = 0; i < draft.size(); ++i) {
-                    common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                    common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true);
                 }
 
                 SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
@@ -4304,7 +4307,6 @@ struct server_context {
                 // the accepted tokens from the speculation
                 const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
 
-                slot.n_past    += ids.size();
                 slot.n_decoded += ids.size();
 
                 // update how many tokens out of those tested were accepted
@@ -4313,7 +4315,7 @@ struct server_context {
                 slot.prompt.tokens.push_back(id);
                 slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
 
-                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
+                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
 
                 for (size_t i = 0; i < ids.size(); ++i) {
                     completion_token_output result;
@@ -4334,7 +4336,7 @@ struct server_context {
                     }
                 }
 
-                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
+                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens());
             }
         }
 
@@ -4662,9 +4664,9 @@ int main(int argc, char ** argv) {
                     {"help",  "Total number of llama_decode() calls"},
                     {"value",  res_task->n_decode_total}
             }, {
-                    {"name",  "n_past_max"},
-                    {"help",  "Largest observed n_past."},
-                    {"value",  res_task->n_past_max}
+                    {"name",  "n_tokens_max"},
+                    {"help",  "Largest observed n_tokens."},
+                    {"value",  res_task->n_tokens_max}
             }, {
                     {"name",  "n_busy_slots_per_decode"},
                     {"help",  "Average number of busy slots per llama_decode() call"},
index cc48f5a9d0ac7f03a14a0afa8fa506566739c649..fb95361b282771f588ea4d853b4f625d9db0efe3 100644 (file)
@@ -1080,19 +1080,22 @@ struct server_tokens {
 
 private: // disallow accessing these members directly, risking out-of-sync
 
-    // map a **start** position in tokens to the image chunk
-    std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_media;
+    // map a **start** index in tokens to the image chunk
+    // note: the order need to be in-sync with tokens
+    std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;
 
     // list of tokens
-    // it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token
-    // a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position**
-    // important: for models using mrope, an image can contain multiple tokens but will use only one **position**
+    //   if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk
+    //   otherwise, it is a normal text token
+    // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
+    // note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos
     llama_tokens tokens;
 
-    // for ex. with input of 5 text tokens and 2 images:
-    //      [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
-    // pos  0   1   2   3   4   5      6      7      8      9
-    // map_pos_to_media will contain: {5, img0}, {8, img1}
+    // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
+    //      [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
+    // idx  0   1   2   3   4   5      6      7      8      9      10
+    // pos  0   1   2   3   4   5      5      5      7      7      7
+    // map_idx_to_media will contain: {5, img0}, {8, img1}
 
 public:
     server_tokens() = default;
@@ -1117,13 +1120,31 @@ public:
         }
     }
 
-    server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
+    server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
+    }
+
+    llama_pos pos_next() const {
+        if (!has_mtmd) {
+            return tokens.size();
+        }
+
+        llama_pos res = tokens.size();
+
+        for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
+            const auto & chunk = it->second;
+            res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
+        }
+
+        return res;
+    }
 
     // for debugging
     std::string str() const {
         std::ostringstream oss;
         oss << "tokens: ";
-        for (const auto & t : tokens) {
+        for (size_t idx = 0; idx < tokens.size(); ++idx) {
+            llama_token t = tokens[idx];
+            oss << "idx:" << idx << " ";
             if (t == LLAMA_TOKEN_NULL) {
                 oss << "<embd> ";
             } else {
@@ -1131,16 +1152,16 @@ public:
             }
         }
         oss << "\n";
-        oss << "image pos: ";
-        for (const auto & it : map_pos_to_media) {
+        oss << "image idx: ";
+        for (const auto & it : map_idx_to_media) {
             oss << it.first << ", ";
         }
         return oss.str();
     }
 
-    const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const {
-        auto it = map_pos_to_media.find(pos);
-        if (it != map_pos_to_media.end()) {
+    const mtmd::input_chunk_ptr & find_chunk(size_t idx) const {
+        auto it = map_idx_to_media.find(idx);
+        if (it != map_idx_to_media.end()) {
             return it->second;
         }
         throw std::runtime_error("Chunk not found");
@@ -1158,13 +1179,13 @@ public:
         auto type = mtmd_input_chunk_get_type(chunk);
         if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
             GGML_ASSERT(has_mtmd);
-            const int n_pos = mtmd_input_chunk_get_n_pos(chunk);
-            llama_pos start_pos = tokens.size();
-            for (int i = 0; i < n_pos; ++i) {
+            const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
+            size_t start_idx = tokens.size();
+            for (size_t i = 0; i < n_tokens; ++i) {
                 tokens.emplace_back(LLAMA_TOKEN_NULL);
             }
             mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
-            map_pos_to_media[start_pos] = std::move(new_chunk);
+            map_idx_to_media[start_idx] = std::move(new_chunk);
         } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             size_t n_tokens;
             const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
@@ -1178,7 +1199,7 @@ public:
 
     // appends server tokens, updates the media map. copies media chunks.
     void push_back(server_tokens & tokens) {
-        size_t start_pos = size();
+        size_t start_idx = size();
         for (size_t i = 0; i < tokens.size(); i++) {
             push_back(tokens[i]);
         }
@@ -1186,10 +1207,10 @@ public:
             // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd.
             // We could also just check, but this will prevent silently dropping MTMD data.
             GGML_ASSERT(has_mtmd);
-            for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) {
-                auto * chunk = tokens.map_pos_to_media[it->first].get();
+            for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) {
+                auto * chunk = tokens.map_idx_to_media[it->first].get();
                 mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
-                map_pos_to_media[start_pos+it->first] = std::move(new_chunk);
+                map_idx_to_media[start_idx+it->first] = std::move(new_chunk);
             }
         }
     }
@@ -1245,10 +1266,10 @@ public:
                 }
             }
             // remove all image chunks that are not used anymore
-            for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ) {
-                llama_pos pos = it->first;
-                if (pos >= (llama_pos)n) {
-                    it = map_pos_to_media.erase(it);
+            for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) {
+                size_t idx = it->first;
+                if (idx >= n) {
+                    it = map_idx_to_media.erase(it);
                 } else {
                     ++it;
                 }
@@ -1296,12 +1317,12 @@ public:
                 const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
                 const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());
 
-                const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get());
-                const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get());
+                const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get());
+                const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get());
 
-                if (id_ai == id_bi && pos_a == pos_b) {
-                    GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen
-                    i += pos_a - 1; // will be +1 by the for loop
+                if (id_ai == id_bi && n_tok_a == n_tok_b) {
+                    GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen
+                    i += n_tok_a - 1; // will be +1 by the for loop
                     continue;
                 }
 
@@ -1329,8 +1350,8 @@ public:
             if (t == LLAMA_TOKEN_NULL) {
                 try {
                     const auto & chunk = find_chunk(i);
-                    size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
-                    i += n_pos - 1; // will be +1 by the for loop
+                    size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get());
+                    i += n_tokens - 1; // will be +1 by the for loop
                 } catch (const std::exception & e) {
                     return false;
                 }
@@ -1345,19 +1366,20 @@ public:
     int32_t process_chunk(
                 llama_context * ctx,
                 mtmd_context * mctx,
-                llama_pos n_past,
+                size_t idx,
+                llama_pos pos,
                 int32_t seq_id,
-                llama_pos & n_pos_out) const {
-        const auto & chunk = find_chunk(n_past);
+                size_t & n_tokens_out) const {
+        const auto & chunk = find_chunk(idx);
         const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
                             ? "image" : "audio";
         SRV_INF("processing %s...\n", name);
         int32_t n_batch = llama_n_batch(ctx);
         int64_t t0 = ggml_time_ms();
-        llama_pos new_n_past = n_past;
+        llama_pos new_n_past; // unused for now
         int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
             chunk.get(),
-            n_past,
+            pos,
             seq_id,
             n_batch,
             true, // logits last
@@ -1365,10 +1387,10 @@ public:
         SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
         if (result != 0) {
             LOG_ERR("mtmd_helper_eval failed with status %d", result);
-            n_pos_out = n_past;
+            n_tokens_out = 0;
             return result;
         }
-        n_pos_out = new_n_past;
+        n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get());
         return 0;
     }
 };