}
} else {
if (slot.n_prompt_tokens() >= slot.n_ctx) {
- send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
+ send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
slot.release();
continue;
}
// Everything else, including multimodal completions.
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
}
-
+ const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
tasks.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
+ auto n_prompt_tokens = inputs[i].size();
+ if (n_prompt_tokens >= n_ctx_slot) {
+ json error_data = format_error_response("the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
+ error_data["n_prompt_tokens"] = n_prompt_tokens;
+ error_data["n_ctx"] = n_ctx_slot;
+ res_error(res, error_data);
+ return;
+ }
server_task task = server_task(type);
task.id = ctx_server.queue_tasks.get_new_id();
assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
def test_context_size_exceeded_stream():
    """Check that a streaming chat request with an oversized prompt is rejected.

    The streaming request is expected to raise ``ServerError`` with status
    code 400 and an ``exceed_context_size_error`` payload. The payload must
    report a positive ``n_prompt_tokens`` and an ``n_ctx`` equal to the
    server context size divided by the number of slots.
    """
    global server
    server.start()

    # Repeat the conversation 100 times to make the prompt too long.
    oversized_messages = [
        {"role": "system", "content": "Book"},
        {"role": "user", "content": "What is the best book"},
    ] * 100
    payload = {"messages": oversized_messages, "stream": True}

    try:
        # Drain the stream; the chunks themselves are irrelevant here.
        for _chunk in server.make_stream_request("POST", "/chat/completions", data=payload):
            pass
    except ServerError as err:
        assert err.code == 400
        assert "error" in err.body
        details = err.body["error"]
        assert details["type"] == "exceed_context_size_error"
        assert details["n_prompt_tokens"] > 0
        # Both values must be set before the per-slot size can be computed.
        assert server.n_ctx is not None
        assert server.n_slots is not None
        assert details["n_ctx"] == server.n_ctx // server.n_slots
    else:
        # Reaching this point means the request was not rejected.
        assert False, "Should have failed"
+
+
@pytest.mark.parametrize(
"n_batch,batch_count,reuse_cache",
[
body: dict | Any
class ServerError(Exception):
    """Error raised when the server answers a request with a failure status.

    Attributes:
        code: Status code passed by the raiser (the streaming request helper
            passes the HTTP status code of the response).
        body: Response body passed by the raiser (the streaming request
            helper passes the decoded JSON payload).
    """

    def __init__(self, code: int, body: Any) -> None:
        # Forward both values to the base class so ``args``, ``str()`` and
        # pickling keep working exactly as for a plain two-argument exception.
        super().__init__(code, body)
        self.code = code
        self.body = body
+
+
class ServerProcess:
# default options
debug: bool = False
response = requests.post(url, headers=headers, json=data, stream=True)
else:
raise ValueError(f"Unimplemented method: {method}")
+ if response.status_code != 200:
+ raise ServerError(response.status_code, response.json())
for line_bytes in response.iter_lines():
line = line_bytes.decode("utf-8")
if '[DONE]' in line: