From: Georgi Gerganov
Date: Fri, 20 Mar 2026 09:13:12 +0000 (+0200)
Subject: server : improve mtmd ctx checkpoints (#20726)
X-Git-Tag: upstream/0.0.8611~161
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=ab9d4c3678a6f8bb797610a27bc0af493fcf786c;p=pkg%2Fggml%2Fsources%2Fllama.cpp

server : improve mtmd ctx checkpoints (#20726)

* server : improve mtmd ctx checkpoints

* server : fix off-by-one in pos_min_thold
---

diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 2e43bebd6..9de554e90 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -2307,8 +2307,8 @@ private:
             llama_pos pos_next = slot.prompt.tokens.pos_next(n_past);
 
-            // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
-            const auto n_swa = std::max(1, llama_model_n_swa(model));
+            // note: when n_swa == 0, the model does not use SWA
+            const auto n_swa = std::max(0, llama_model_n_swa(model));
 
             // the largest pos_min required for a checkpoint to be useful
             const auto pos_min_thold = std::max(0, pos_next - n_swa);
 
@@ -2363,7 +2363,7 @@ private:
                     SLT_WRN(slot, "%s\n", st1.str().c_str());
                 }
 
-                if (pos_min > pos_min_thold) {
+                if (pos_min >= pos_min_thold) {
                     SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
 
                     // search for a context checkpoint
@@ -2459,31 +2459,6 @@ private:
                 slot.n_prompt_tokens_cache = 0;
             }
 
-            bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
-
-            // check if we should process the image
-            if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
-                // process the image
-                size_t n_tokens_out = 0;
-                int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
-                if (res != 0) {
-                    SLT_ERR(slot, "failed to process image, res = %d\n", res);
-                    send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
-                    slot.release();
-                    continue;
-                }
-
-                slot.n_prompt_tokens_processed += n_tokens_out;
-
-                // add the image chunk to cache
-                {
-                    const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens());
-                    slot.prompt.tokens.push_back(chunk.get()); // copy
-                }
-
-                do_checkpoint = false; // do not checkpoint right after an image chunk
-            }
-
             // If using an alora, there may be uncached tokens that come
             // before the invocation sequence. When this happens, the
             // tokens before the invocation sequence need to be
@@ -2498,6 +2473,8 @@ private:
                 alora_disabled_id = enabled_loras[0];
             }
 
+            bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
+
             // make checkpoints only for completion tasks
             do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
 
@@ -2513,6 +2490,31 @@ private:
                 (llama_model_n_swa(model) > 0 && !params_base.swa_full)
             );
 
+            bool has_mtmd = false;
+
+            // check if we should process the image
+            if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
+                // process the image
+                size_t n_tokens_out = 0;
+                int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
+                if (res != 0) {
+                    SLT_ERR(slot, "failed to process image, res = %d\n", res);
+                    send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
+                    slot.release();
+                    continue;
+                }
+
+                slot.n_prompt_tokens_processed += n_tokens_out;
+
+                // add the image chunk to cache
+                {
+                    const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens());
+                    slot.prompt.tokens.push_back(chunk.get()); // copy
+                }
+
+                has_mtmd = true;
+            }
+
             // add prompt tokens for processing in the current batch
             while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
                 // get next token to process
@@ -2544,13 +2546,13 @@ private:
                 // - 4 + n_ubatch
                 // - 4
                 // ref: https://github.com/ggml-org/llama.cpp/pull/20288
-                {
+                if (do_checkpoint) {
                     static const int checkpoint_offsets[] = {4 + n_ubatch, 4};
 
                     bool should_break = false;
                     for (int offset : checkpoint_offsets) {
                         const int n_last = std::min(n_batch, offset);
-                        if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
+                        if (slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
                             should_break = true;
                             break;
                         }
@@ -2607,10 +2609,13 @@ private:
                 const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
 
                 // no need for empty or small checkpoints
-                do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
+                do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64);
+
+                // do not checkpoint after mtmd chunks
+                do_checkpoint = do_checkpoint && !has_mtmd;
 
                 // no need to create checkpoints that are too close together
-                do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
+                do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || slot.prompt.n_tokens() - n_tokens_cur > slot.prompt.checkpoints.back().n_tokens + 64);
 
                 // note: we create the checkpoint before calling llama_decode(), so the current batch is not
                 // yet processed and therefore it is not part of the checkpoint.
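
The following minimal, self-contained C++ sketch (not part of the patch) illustrates the checkpoint-threshold arithmetic touched by the first two hunks. The helper names old_checkpoint_search and new_checkpoint_search are invented for the example; only the n_swa clamp, the pos_min_thold expression, and the comparison against pos_min are taken from the diff. It shows that for a model without SWA the old and new variants agree for every pos_min, while for an SWA model they differ only at the boundary pos_min == pos_next - n_swa, which is the off-by-one mentioned in the commit message.

// Standalone sketch of the threshold arithmetic; helper names are invented here.
#include <algorithm>
#include <cstdio>

// Old logic: n_swa clamped to at least 1, strict comparison against the threshold.
static bool old_checkpoint_search(int pos_min, int pos_next, int n_swa_model) {
    const int n_swa          = std::max(1, n_swa_model);
    const int pos_min_thold  = std::max(0, pos_next - n_swa);
    return pos_min > pos_min_thold;
}

// New logic: n_swa may stay 0 (model without SWA), comparison becomes >=.
static bool new_checkpoint_search(int pos_min, int pos_next, int n_swa_model) {
    const int n_swa          = std::max(0, n_swa_model);
    const int pos_min_thold  = std::max(0, pos_next - n_swa);
    return pos_min >= pos_min_thold;
}

int main() {
    // Non-SWA model (n_swa == 0): old and new agree for every pos_min,
    // since (pos_min > pos_next - 1) is equivalent to (pos_min >= pos_next).
    for (int pos_min = 97; pos_min <= 100; ++pos_min) {
        printf("n_swa=0  pos_min=%3d old=%d new=%d\n", pos_min,
               old_checkpoint_search(pos_min, 100, 0),
               new_checkpoint_search(pos_min, 100, 0));
    }

    // SWA model (window 32, pos_next 100): the two versions differ only at the
    // boundary pos_min == 100 - 32 == 68.
    for (int pos_min = 67; pos_min <= 69; ++pos_min) {
        printf("n_swa=32 pos_min=%3d old=%d new=%d\n", pos_min,
               old_checkpoint_search(pos_min, 100, 32),
               new_checkpoint_search(pos_min, 100, 32));
    }
    return 0;
}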