server_task(server_task_type type) : type(type) {}
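+ // total number of tokens in the task prompt (counting both text tokens and media placeholder tokens)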
+ int32_t n_tokens() const {
+ return tokens.size();
+ }
+
static slot_params params_from_json_cmpl(
const llama_context * ctx,
const common_params & params_base,
uint64_t n_tokens_predicted_total = 0;
uint64_t t_tokens_generation_total = 0;
- uint64_t n_past_max = 0;
+ uint64_t n_tokens_max = 0;
uint64_t n_prompt_tokens_processed = 0;
uint64_t t_prompt_processing = 0;
{ "n_tokens_predicted_total", n_tokens_predicted_total },
{ "t_prompt_processing_total", t_prompt_processing_total },
- { "n_past_max", n_past_max },
+ { "n_tokens_max", n_tokens_max },
{ "n_prompt_tokens_processed", n_prompt_tokens_processed },
{ "t_prompt_processing", t_prompt_processing },
// generation props
int32_t n_ctx = 0; // context size per slot
- int32_t n_past = 0;
int32_t n_keep = 0;
int32_t n_decoded = 0;
int32_t n_remaining = -1;
int32_t n_prompt_tokens_cache = 0;
int32_t n_prompt_tokens_processed = 0;
- int32_t n_prompt_tokens() const {
- return task->tokens.size();
- }
-
size_t last_nl_pos = 0;
std::string generated_text;
truncated = false;
stop = STOP_TYPE_NONE;
stopping_word = "";
- n_past = 0;
n_sent_text = 0;
chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
if (is_processing()) {
GGML_ASSERT(task);
- SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated);
+ SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated);
t_last_used = ggml_time_us();
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
uint64_t n_tokens_predicted_total = 0;
uint64_t t_tokens_generation_total = 0;
- uint64_t n_past_max = 0;
+ uint64_t n_tokens_max = 0;
uint64_t n_prompt_tokens_processed = 0;
uint64_t t_prompt_processing = 0;
t_prompt_processing += slot.t_prompt_processing;
t_prompt_processing_total += slot.t_prompt_processing;
- if (slot.n_past > 0) {
- n_past_max = std::max(n_past_max, (uint64_t) slot.n_past);
- }
+ n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
}
void on_prediction(const server_slot & slot) {
if (slot.is_processing()) {
n_busy_slots_total++;
}
- if (slot.n_past > 0) {
- n_past_max = std::max(n_past_max, (uint64_t) slot.n_past);
- }
+ n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
}
}
}
// if context shifting is disabled, make sure that we don't run out of context
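+ // e.g. (illustrative): with n_ctx = 512 for this slot, generation stops once prompt.n_tokens() reaches 511,
+ // since there would be no room left to place the next token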
- if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) {
+ if (!params_base.ctx_shift && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
slot.truncated = true;
slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
- SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
- slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx);
+ SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = %d\n",
+ slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx);
}
// check the limits
}
void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
- send_error(slot.task->id, error, type, slot.n_prompt_tokens(), slot.n_ctx);
+ send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx);
}
void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
if (is_progress) {
res->is_progress = true;
- res->progress.total = slot.n_prompt_tokens();
+ res->progress.total = slot.task->n_tokens();
res->progress.cache = slot.n_prompt_tokens_cache;
res->progress.processed = slot.prompt.tokens.size();
res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt) / 1000;
}
res->n_decoded = slot.n_decoded;
- res->n_prompt_tokens = slot.n_prompt_tokens();
+ res->n_prompt_tokens = slot.task->n_tokens();
res->post_sampling_probs = slot.task->params.post_sampling_probs;
res->verbose = slot.task->params.verbose;
res->truncated = slot.truncated;
res->n_decoded = slot.n_decoded;
- res->n_prompt_tokens = slot.n_prompt_tokens();
- res->n_tokens_cached = slot.n_past;
+ res->n_prompt_tokens = slot.task->n_tokens();
+ res->n_tokens_cached = slot.prompt.n_tokens();
res->has_new_line = slot.has_new_line;
res->stopping_word = slot.stopping_word;
res->stop = slot.stop;
auto res = std::make_unique<server_task_result_embd>();
res->id = slot.task->id;
res->index = slot.task->index;
- res->n_tokens = slot.n_prompt_tokens();
+ res->n_tokens = slot.task->n_tokens();
res->oaicompat = slot.task->params.oaicompat;
const int n_embd = llama_model_n_embd(model);
auto res = std::make_unique<server_task_result_rerank>();
res->id = slot.task->id;
res->index = slot.task->index;
- res->n_tokens = slot.n_prompt_tokens();
+ res->n_tokens = slot.task->n_tokens();
for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
res->t_tokens_generation_total = metrics.t_tokens_generation_total;
- res->n_past_max = metrics.n_past_max;
+ res->n_tokens_max = metrics.n_tokens_max;
res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
res->t_prompt_processing = metrics.t_prompt_processing;
// apply context-shift if needed
// TODO: simplify and improve
for (server_slot & slot : slots) {
- if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
+ if (slot.is_processing() && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
if (!params_base.ctx_shift) {
// this check is redundant (for good measure)
// we should never get here, because generation should have already stopped in process_token()
}
// Shift context
- int n_keep = slot.task->params.n_keep < 0 ? slot.n_prompt_tokens() : slot.task->params.n_keep;
+ int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;
if (add_bos_token) {
n_keep += 1;
n_keep = std::min(slot.n_ctx - 4, n_keep);
- const int n_left = slot.n_past - n_keep;
+ const int n_left = slot.prompt.n_tokens() - n_keep;
const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2);
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard);
- llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.n_past, -n_discard);
+ llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
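+ // illustrative example: with prompt.n_tokens() = 100 and n_keep = 10, n_left = 90 and n_discard = 45,
+ // so positions [10, 55) are removed and positions [55, 100) are shifted down to [10, 55)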
// add generated tokens to cache
+ // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481
{
+ GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
+
llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy
for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
new_tokens[i - n_discard] = new_tokens[i];
}
new_tokens.resize(slot.prompt.tokens.size() - n_discard);
+
slot.prompt.tokens.clear();
slot.prompt.tokens.insert(new_tokens);
}
- slot.n_past -= n_discard;
-
slot.truncated = true;
}
}
slot.i_batch = batch.n_tokens;
- common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
+ common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
- slot.n_past += 1;
slot.prompt.tokens.push_back(slot.sampled);
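+ // note: the batch position comes from pos_next() rather than the token count, because media chunks
+ // can occupy a different number of positions than tokens (see server_tokens)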
- SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
- slot.n_ctx, slot.n_past, (int) slot.prompt.tokens.size(), slot.truncated);
+ SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
+ slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
}
// process in chunks of params.n_batch
slot.t_start_process_prompt = ggml_time_us();
slot.t_start_generation = 0;
- slot.n_past = 0;
slot.state = SLOT_STATE_PROCESSING_PROMPT;
- SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n",
- slot.n_ctx, slot.task->params.n_keep, slot.n_prompt_tokens());
+ SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n",
+ slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());
// print prompt tokens (for debugging)
/*if (1) {
}
}*/
+ // keep track of how many tokens we can reuse from the previous state
+ int n_past = 0;
+
// empty prompt passed -> release the slot and send empty response
if (input_tokens.empty()) {
SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
}
if (!slot.can_split()) {
- if (slot.n_prompt_tokens() > n_ubatch) {
+ if (slot.task->n_tokens() > n_ubatch) {
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
slot.release();
continue;
}
- if (slot.n_prompt_tokens() > slot.n_ctx) {
+ if (slot.task->n_tokens() > slot.n_ctx) {
send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
slot.release();
continue;
}
} else {
- if (slot.n_prompt_tokens() >= slot.n_ctx) {
+ if (slot.task->n_tokens() >= slot.n_ctx) {
send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
slot.release();
continue;
if (slot.task->params.cache_prompt) {
// reuse any previously computed tokens that are common with the new prompt
- slot.n_past = slot.prompt.tokens.get_common_prefix(input_tokens);
+ n_past = slot.prompt.tokens.get_common_prefix(input_tokens);
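+ // illustrative example: cached prompt [A B C D E], new prompt [A B C X Y] -> n_past = 3,
+ // so only the suffix [X Y] has to be (re)processed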
// if there is an alora invoked, don't cache after the invocation start
- if (slot.alora_invocation_start >= 0) {
- SLT_DBG(slot, "only caching to alora invocation start (n_past=%d, alora_invocation_start=%d)\n", slot.n_past, slot.alora_invocation_start);
- slot.n_past = std::min(slot.n_past, slot.alora_invocation_start - 1);
+ if (slot.alora_invocation_start > 0) {
+ SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start);
+ n_past = std::min(n_past, slot.alora_invocation_start - 1);
}
// reuse chunks from the cached prompt by shifting their KV cache in the new position
if (params_base.n_cache_reuse > 0) {
- size_t head_c = slot.n_past; // cache
- size_t head_p = slot.n_past; // current prompt
+ GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
+
+ size_t head_c = n_past; // cache
+ size_t head_p = n_past; // current prompt
if (mctx) {
// we should never reach this
GGML_ABORT("not supported by multimodal");
}
- SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
+ SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past);
while (head_c < slot.prompt.tokens.size() &&
head_p < input_tokens.size()) {
size_t n_match = 0;
while (head_c + n_match < slot.prompt.tokens.size() &&
- head_p + n_match < input_tokens.size() &&
+ head_p + n_match < input_tokens.size() &&
slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) {
n_match++;
for (size_t i = 0; i < n_match; i++) {
slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
- slot.n_past++;
+ n_past++;
}
head_c += n_match;
}
}
- SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
+ SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past);
}
} else {
- // if we don't cache the prompt, we have to remove the entire KV cache
- slot.n_past = 0;
+ // if we don't cache the prompt, we have to remove all previous tokens
+ n_past = 0;
}
// note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
const auto n_swa = std::max(1, llama_model_n_swa(model));
// the largest pos_min required for a checkpoint to be useful
- const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
+ const auto pos_min_thold = std::max(0, n_past - n_swa);
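+ // illustrative example: with n_past = 100 and n_swa = 16, pos_min_thold = 84; a cached sequence whose
+ // pos_min is greater than 84 no longer holds the full sliding window needed to continue from n_past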
- if (slot.n_past > 0 && slot.n_past < (int) slot.prompt.tokens.size()) {
+ if (n_past > 0 && n_past < slot.prompt.n_tokens()) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
if (pos_min == -1) {
- SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
+ SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
}
// when the prompt prefix does not match, print the tokens around the mismatch
// this is useful for debugging prompt caching
{
- const int np0 = std::max<int>(slot.n_past - 4, 0);
- const int np1 = std::min<int>(slot.n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size()));
+ const int np0 = std::max<int>(n_past - 4, 0);
+ const int np1 = std::min<int>(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size()));
std::stringstream ss0;
std::stringstream ss1;
ss1 << "new: ... ";
for (int i = np0; i < np1; i++) {
- if (i == slot.n_past) {
+ if (i == n_past) {
ss0 << " | ";
ss1 << " | ";
}
}
if (pos_min > pos_min_thold) {
- SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
+ // TODO: support can be added in the future when corresponding vision models get released
+ GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
+
+ SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
// search for a context checkpoint
const auto it = std::find_if(
do_reset = true;
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
} else {
- slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
+ n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
}
}
if (do_reset) {
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
- slot.n_past = 0;
+ n_past = 0;
}
}
}
}
// [TAG_PROMPT_LOGITS]
- if (slot.n_past == slot.n_prompt_tokens() && slot.n_past > 0) {
- SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens());
- slot.n_past--;
- SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
+ if (n_past == slot.task->n_tokens() && n_past > 0) {
+ SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens());
+ n_past--;
+ SLT_WRN(slot, "n_past was set to %d\n", n_past);
}
- slot.n_prompt_tokens_cache = slot.n_past;
+ slot.n_prompt_tokens_cache = n_past;
slot.n_prompt_tokens_processed = 0;
+
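+ // keep only the reusable prefix in the prompt cache; media chunks starting at idx >= n_past are dropped as well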
+ slot.prompt.tokens.keep_first(n_past);
}
if (!slot.can_split()) {
// cannot fit the prompt in the current batch - will try next iter
- if (batch.n_tokens + slot.n_prompt_tokens() > n_batch) {
+ if (batch.n_tokens + slot.task->n_tokens() > n_batch) {
continue;
}
}
// remove from the KV cache any entries that are beyond the reused prompt prefix for this slot
- if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) {
- SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past);
+ const llama_pos p0 = slot.prompt.tokens.pos_next();
+ if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
+ SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0);
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
// there is no common part left
- slot.n_past = 0;
slot.n_prompt_tokens_cache = 0;
- }
- SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past);
+ slot.prompt.tokens.clear();
+ }
- // remove the non-common part from the cache
- slot.prompt.tokens.keep_first(slot.n_past);
+ SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
// check if we should process the image
- if (slot.n_past < slot.n_prompt_tokens() && input_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
+ if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
// process the image
- int32_t new_n_past;
- int32_t res = input_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
+ size_t n_tokens_out = 0;
+ int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
if (res != 0) {
SLT_ERR(slot, "failed to process image, res = %d\n", res);
send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
continue;
}
+ slot.n_prompt_tokens_processed += n_tokens_out;
+
// add the image chunk to cache
{
- const auto & chunk = input_tokens.find_chunk(slot.n_past);
+ const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens());
slot.prompt.tokens.push_back(chunk.get()); // copy
}
-
- const int32_t n_pos = new_n_past - slot.n_past;
-
- slot.n_past += n_pos;
- slot.n_prompt_tokens_processed += n_pos;
}
// If using an alora, there may be uncached tokens that come
// before the invocation sequence. In that case, the tokens before the
// invocation sequence need to be processed without the adapter in a separate batch, then
// the adapter needs to be enabled for the remaining tokens.
- if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.n_past) {
- SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
+ if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) {
+ SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
GGML_ASSERT(enabled_loras.size() == 1);
alora_scale = slot.lora[enabled_loras[0]].scale;
);
// add prompt tokens for processing in the current batch
- while (slot.n_past < slot.n_prompt_tokens() && batch.n_tokens < n_batch) {
+ while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
// get next token to process
- llama_token cur_tok = input_tokens[slot.n_past];
+ llama_token cur_tok = input_tokens[slot.prompt.n_tokens()];
if (cur_tok == LLAMA_TOKEN_NULL) {
break; // end of text chunk
}
// if this is an alora request with pre-invocation
// tokens that are not cached, we need to stop filling
// this batch at those pre-invocation tokens.
- if (alora_scale > 0 && slot.n_past == slot.alora_invocation_start - 1) {
- SLT_DBG(slot, "stop prompt batch filling at (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
+ if (alora_scale > 0 && slot.prompt.n_tokens() == slot.alora_invocation_start - 1) {
+ SLT_DBG(slot, "stop prompt batch filling at (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
break;
}
// embedding requires all tokens in the batch to be output
- common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, slot.need_embd());
+ common_batch_add(batch,
+ cur_tok,
+ slot.prompt.tokens.pos_next(),
+ { slot.id },
+ slot.need_embd());
slot.prompt.tokens.push_back(cur_tok);
slot.n_prompt_tokens_processed++;
- slot.n_past++;
// process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
- if (do_checkpoint && slot.n_prompt_tokens() - slot.n_past == 64) {
+ if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) {
break;
}
}
// SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.prompt.tokens.str().c_str());
- SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_past / slot.n_prompt_tokens());
+ SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
// entire prompt has been processed
- if (slot.n_past == slot.n_prompt_tokens()) {
+ if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
slot.state = SLOT_STATE_DONE_PROMPT;
GGML_ASSERT(batch.n_tokens > 0);
common_sampler_reset(slot.smpl);
// Process all prompt tokens through sampler system
- for (int i = 0; i < slot.n_prompt_tokens(); ++i) {
+ for (int i = 0; i < slot.task->n_tokens(); ++i) {
llama_token id = input_tokens[i];
if (id != LLAMA_TOKEN_NULL) {
common_sampler_accept(slot.smpl, id, false);
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
- SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+ SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
// determine the max draft that fits the current slot state
int n_draft_max = slot.task->params.speculative.n_max;
- // note: n_past is not yet increased for the `id` token sampled above
+ // note: slot.prompt is not yet expanded with the `id` token sampled above
// also, need to leave space for 1 extra token to allow context shifts
- n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+ n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2);
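+ // illustrative example: with n_ctx = 4096 and prompt.n_tokens() = 4000, at most 94 draft tokens fit
+ // (leaving room for the token sampled above plus 1 extra token for a potential context shift)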
if (slot.n_remaining > 0) {
n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
// construct the speculation batch
common_batch_clear(slot.batch_spec);
- common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+ common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true);
for (size_t i = 0; i < draft.size(); ++i) {
- common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+ common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true);
}
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
- slot.n_past += ids.size();
slot.n_decoded += ids.size();
// update how many tokens out of those tested were accepted
slot.prompt.tokens.push_back(id);
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
- llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
+ llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
}
}
- SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
+ SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens());
}
}
{"help", "Total number of llama_decode() calls"},
{"value", res_task->n_decode_total}
}, {
- {"name", "n_past_max"},
- {"help", "Largest observed n_past."},
- {"value", res_task->n_past_max}
+ {"name", "n_tokens_max"},
+ {"help", "Largest observed n_tokens."},
+ {"value", res_task->n_tokens_max}
}, {
{"name", "n_busy_slots_per_decode"},
{"help", "Average number of busy slots per llama_decode() call"},
private: // disallow accessing these members directly, risking out-of-sync
- // map a **start** position in tokens to the image chunk
- std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_media;
+ // map a **start** index in tokens to the image chunk
+ // note: the order needs to be kept in sync with tokens
+ std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;
// list of tokens
- // it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token
- // a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position**
- // important: for models using mrope, an image can contain multiple tokens but will use only one **position**
+ // if the token is LLAMA_TOKEN_NULL, it indicates that this index is occupied by a media chunk
+ // otherwise, it is a normal text token
+ // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
+ // note(2): with M-RoPE, an image can occupy a different number of positions than tokens; do not assume a 1-to-1 mapping tokens <-> pos
llama_tokens tokens;
- // for ex. with input of 5 text tokens and 2 images:
- // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
- // pos 0 1 2 3 4 5 6 7 8 9
- // map_pos_to_media will contain: {5, img0}, {8, img1}
+ // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
+ // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
+ // idx 0 1 2 3 4 5 6 7 8 9 10
+ // pos 0 1 2 3 4 5 5 5 7 7 7
+ // map_idx_to_media will contain: {5, img0}, {8, img1}
public:
server_tokens() = default;
}
}
- server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
+ server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
+ }
+
+ llama_pos pos_next() const {
+ if (!has_mtmd) {
+ return tokens.size();
+ }
+
+ llama_pos res = tokens.size();
+
+ for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
+ const auto & chunk = it->second;
+ res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
+ }
+
+ return res;
+ }
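+ // e.g. with the 5-text-token + 2-image example above: tokens.size() == 11 and each image spans
+ // 3 tokens but only 2 positions, so pos_next() == 11 + 2*(2 - 3) == 9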
// for debugging
std::string str() const {
std::ostringstream oss;
oss << "tokens: ";
- for (const auto & t : tokens) {
+ for (size_t idx = 0; idx < tokens.size(); ++idx) {
+ llama_token t = tokens[idx];
+ oss << "idx:" << idx << " ";
if (t == LLAMA_TOKEN_NULL) {
oss << "<embd> ";
} else {
}
}
oss << "\n";
- oss << "image pos: ";
- for (const auto & it : map_pos_to_media) {
+ oss << "image idx: ";
+ for (const auto & it : map_idx_to_media) {
oss << it.first << ", ";
}
return oss.str();
}
- const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const {
- auto it = map_pos_to_media.find(pos);
- if (it != map_pos_to_media.end()) {
+ const mtmd::input_chunk_ptr & find_chunk(size_t idx) const {
+ auto it = map_idx_to_media.find(idx);
+ if (it != map_idx_to_media.end()) {
return it->second;
}
throw std::runtime_error("Chunk not found");
auto type = mtmd_input_chunk_get_type(chunk);
if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
GGML_ASSERT(has_mtmd);
- const int n_pos = mtmd_input_chunk_get_n_pos(chunk);
- llama_pos start_pos = tokens.size();
- for (int i = 0; i < n_pos; ++i) {
+ const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
+ size_t start_idx = tokens.size();
+ for (size_t i = 0; i < n_tokens; ++i) {
tokens.emplace_back(LLAMA_TOKEN_NULL);
}
mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
- map_pos_to_media[start_pos] = std::move(new_chunk);
+ map_idx_to_media[start_idx] = std::move(new_chunk);
} else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
size_t n_tokens;
const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
// appends server tokens, updates the media map. copies media chunks.
void push_back(server_tokens & tokens) {
- size_t start_pos = size();
+ size_t start_idx = size();
for (size_t i = 0; i < tokens.size(); i++) {
push_back(tokens[i]);
}
// Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd.
// We could also just check, but this will prevent silently dropping MTMD data.
GGML_ASSERT(has_mtmd);
- for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) {
- auto * chunk = tokens.map_pos_to_media[it->first].get();
+ for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) {
+ auto * chunk = tokens.map_idx_to_media[it->first].get();
mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
- map_pos_to_media[start_pos+it->first] = std::move(new_chunk);
+ map_idx_to_media[start_idx+it->first] = std::move(new_chunk);
}
}
}
}
}
// remove all image chunks that are not used anymore
- for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ) {
- llama_pos pos = it->first;
- if (pos >= (llama_pos)n) {
- it = map_pos_to_media.erase(it);
+ for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) {
+ size_t idx = it->first;
+ if (idx >= n) {
+ it = map_idx_to_media.erase(it);
} else {
++it;
}
const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());
- const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get());
- const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get());
+ const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get());
+ const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get());
- if (id_ai == id_bi && pos_a == pos_b) {
- GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen
- i += pos_a - 1; // will be +1 by the for loop
+ if (id_ai == id_bi && n_tok_a == n_tok_b) {
+ GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen
+ i += n_tok_a - 1; // will be +1 by the for loop
continue;
}
if (t == LLAMA_TOKEN_NULL) {
try {
const auto & chunk = find_chunk(i);
- size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
- i += n_pos - 1; // will be +1 by the for loop
+ size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get());
+ i += n_tokens - 1; // will be +1 by the for loop
} catch (const std::exception & e) {
return false;
}
int32_t process_chunk(
llama_context * ctx,
mtmd_context * mctx,
- llama_pos n_past,
+ size_t idx,
+ llama_pos pos,
int32_t seq_id,
- llama_pos & n_pos_out) const {
- const auto & chunk = find_chunk(n_past);
+ size_t & n_tokens_out) const {
+ const auto & chunk = find_chunk(idx);
const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
? "image" : "audio";
SRV_INF("processing %s...\n", name);
int32_t n_batch = llama_n_batch(ctx);
int64_t t0 = ggml_time_ms();
- llama_pos new_n_past = n_past;
+ llama_pos new_n_past; // unused for now
int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
chunk.get(),
- n_past,
+ pos,
seq_id,
n_batch,
true, // logits last
SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
if (result != 0) {
LOG_ERR("mtmd_helper_eval failed with status %d", result);
- n_pos_out = n_past;
+ n_tokens_out = 0;
return result;
}
- n_pos_out = new_n_past;
+ n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get());
return 0;
}
};