res->tokens = { tkn.tok };
}
- res->n_decoded = slot.n_decoded;
- res->n_prompt_tokens = slot.task->n_tokens();
- res->post_sampling_probs = slot.task->params.post_sampling_probs;
+ res->n_decoded = slot.n_decoded;
+ res->n_prompt_tokens = slot.task->n_tokens();
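+ // number of prompt tokens that were reused from the slot's cache for this request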
+ res->n_prompt_tokens_cache = slot.n_prompt_tokens_cache;
+ res->post_sampling_probs = slot.task->params.post_sampling_probs;
res->verbose = slot.task->params.verbose;
res->res_type = slot.task->params.res_type;
res->prompt = slot.task->tokens.detokenize(ctx, true);
res->response_fields = std::move(slot.task->params.response_fields);
- res->truncated = slot.truncated;
- res->n_decoded = slot.n_decoded;
- res->n_prompt_tokens = slot.task->n_tokens();
- res->n_tokens_cached = slot.prompt.n_tokens();
- res->has_new_line = slot.has_new_line;
- res->stopping_word = slot.stopping_word;
- res->stop = slot.stop;
- res->post_sampling_probs = slot.task->params.post_sampling_probs;
+ res->truncated = slot.truncated;
+ res->n_decoded = slot.n_decoded;
+ res->n_prompt_tokens = slot.task->n_tokens();
+ res->n_prompt_tokens_cache = slot.n_prompt_tokens_cache;
+ res->n_tokens_cached = slot.prompt.n_tokens();
+ res->has_new_line = slot.has_new_line;
+ res->stopping_word = slot.stopping_word;
+ res->stop = slot.stop;
+ res->post_sampling_probs = slot.task->params.post_sampling_probs;
res->verbose = slot.task->params.verbose;
res->stream = slot.task->params.stream;
return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
}
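+// Builds the OAI-compatible "usage" object shared by the text_completion,
+// chat.completion and chat.completion.chunk responses, including the cached
+// prompt token count under prompt_tokens_details.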
+json server_task_result_cmpl_final::usage_json_oaicompat() {
+ return json {
+ {"completion_tokens", n_decoded},
+ {"prompt_tokens", n_prompt_tokens},
+ {"total_tokens", n_decoded + n_prompt_tokens},
+ {"prompt_tokens_details", json { {"cached_tokens", n_prompt_tokens_cache} }},
+ };
+}
+
json server_task_result_cmpl_final::to_json_oaicompat() {
std::time_t t = std::time(0);
json logprobs = json(nullptr); // OAI default to null
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "text_completion"},
- {"usage", json {
- {"completion_tokens", n_decoded},
- {"prompt_tokens", n_prompt_tokens},
- {"total_tokens", n_decoded + n_prompt_tokens}
- }},
+ {"usage", usage_json_oaicompat()},
{"id", oaicompat_cmpl_id}
};
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion"},
- {"usage", json {
- {"completion_tokens", n_decoded},
- {"prompt_tokens", n_prompt_tokens},
- {"total_tokens", n_decoded + n_prompt_tokens}
- }},
+ {"usage", usage_json_oaicompat()},
{"id", oaicompat_cmpl_id}
};
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"},
- {"usage", json {
- {"completion_tokens", n_decoded},
- {"prompt_tokens", n_prompt_tokens},
- {"total_tokens", n_decoded + n_prompt_tokens},
- }},
+ {"usage", usage_json_oaicompat()},
});
}
{"input_tokens", n_prompt_tokens},
{"output_tokens", n_decoded},
{"total_tokens", n_decoded + n_prompt_tokens},
+ {"input_tokens_details", json { {"cached_tokens", n_prompt_tokens_cache} }},
}},
};
{"usage", json {
{"input_tokens", n_prompt_tokens},
{"output_tokens", n_decoded},
- {"total_tokens", n_decoded + n_prompt_tokens}
+ {"total_tokens", n_decoded + n_prompt_tokens},
+ {"input_tokens_details", json { {"cached_tokens", n_prompt_tokens_cache} }},
}}
}},
}}
{"stop_reason", stop_reason},
{"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)},
{"usage", {
- {"input_tokens", n_prompt_tokens},
+ {"cache_read_input_tokens", n_prompt_tokens_cache},
+ {"input_tokens", n_prompt_tokens - n_prompt_tokens_cache},
{"output_tokens", n_decoded}
}}
};
{"stop_reason", nullptr},
{"stop_sequence", nullptr},
{"usage", {
- {"input_tokens", n_prompt_tokens},
+ {"cache_read_input_tokens", n_prompt_tokens_cache},
+ {"input_tokens", n_prompt_tokens - n_prompt_tokens_cache},
{"output_tokens", 0}
}}
}}
assert choice["finish_reason"] == finish_reason
+def test_chat_completion_cached_tokens():
+    global server
+    server.n_slots = 1
+    server.start()
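+    # (user_prompt, expected prompt_tokens, expected cached_tokens):
+    # the first request starts with an empty cache, an identical prompt reuses
+    # almost the whole cached prefix, and later-diverging prompts reuse shorter prefixes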
+    seq = [
+        ("1 2 3 4 5 6", 77, 0),
+        ("1 2 3 4 5 6", 77, 76),
+        ("1 2 3 4 5 9", 77, 51),
+        ("1 2 3 9 9 9", 77, 47),
+    ]
+    for user_prompt, n_prompt, n_cache in seq:
+        res = server.make_request("POST", "/chat/completions", data={
+            "max_tokens": 8,
+            "messages": [
+                {"role": "system", "content": "Test"},
+                {"role": "user", "content": user_prompt},
+            ],
+        })
+        assert res.body["usage"]["prompt_tokens"] == n_prompt
+        assert res.body["usage"]["prompt_tokens_details"]["cached_tokens"] == n_cache
+
@pytest.mark.parametrize(
"system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
[