int slot_prompt_len = slot_prompt.size();
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
- int lcp_len = common_part(slot_prompt, prompt);
+ int lcp_len = longest_common_prefix(slot_prompt, prompt);
// fraction of the common prefix length compared to the current slot's prompt length
similarity = static_cast<float>(lcp_len) / slot_prompt_len;
}
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
- // if input prompt is too big, truncate it (if group attention self-extend is disabled)
+ // if input prompt is too big, truncate it
if (slot.n_prompt_tokens >= slot.n_ctx) {
const int n_left = slot.n_ctx - slot.params.n_keep;
if (slot.params.cache_prompt) {
// reuse any previously computed tokens that are common with the new prompt
- slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+ slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
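+ // n_past now counts the cached tokens whose KV entries can be kept as-is for the new prompt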
// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
}
+
+ // reuse chunks from the cached prompt by shifting their KV cache in the new position
+ if (params.n_cache_reuse > 0) {
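+ // unlike the prefix reuse above, this also matches chunks that sit at different offsets in the cache and in the new prompt
+ // both scan heads start at the end of the exact common prefix: head_c walks the old cache, head_p walks the new prompt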
+ size_t head_c = slot.n_past; // cache
+ size_t head_p = slot.n_past; // current prompt
+
+ SLT_DBG(slot, "trying to reuse chunks with size >= %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
+
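+ // scan the remaining cache and prompt tokens; stop as soon as either head hits a control token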
+ while (head_c < slot.cache_tokens.size() &&
+ head_p < prompt_tokens.size()) {
+ if (llama_token_is_control(model, slot.cache_tokens[head_c])) {
+ break;
+ }
+
+ if (llama_token_is_control(model, prompt_tokens[head_p])) {
+ break;
+ }
+
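+ // measure how many consecutive tokens match at the current heads, without crossing a control token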
+ size_t n_match = 0;
+
+ while (head_c + n_match < slot.cache_tokens.size() &&
+ head_p + n_match < prompt_tokens.size() &&
+ slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
+ if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match])) {
+ break;
+ }
+
+ if (llama_token_is_control(model, prompt_tokens[head_p + n_match])) {
+ break;
+ }
+
+ n_match++;
+ }
+
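+ // only chunks of at least n_cache_reuse tokens are shifted; shorter matches are skipped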
+ if (n_match >= (size_t) params.n_cache_reuse) {
+ SLT_DBG(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
+ //for (size_t i = head_p; i < head_p + n_match; i++) {
+ // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+ //}
+
+ const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
+
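+ // drop the stale KV entries in [head_p, head_c) and slide the positions from head_c onward left into the freed range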
+ llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
+ llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
+
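+ // mirror the reused tokens into their new positions in cache_tokens, keep the sampler state in sync, and extend n_past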
+ for (size_t i = 0; i < n_match; i++) {
+ slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+
+ common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
+
+ slot.n_past++;
+ }
+
+ head_c += n_match;
+ head_p += n_match;
+ } else {
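+ // no long-enough match at this cache position, advance the cache head and try again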
+ head_c += 1;
+ }
+ }
+
+ SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
+ }
}
}
ctx_server.queue_tasks.on_new_task(std::bind(
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
+
ctx_server.queue_tasks.on_update_slots(std::bind(
&server_context::update_slots, &ctx_server));