ERROR_TYPE_PERMISSION,
ERROR_TYPE_UNAVAILABLE, // custom error
ERROR_TYPE_NOT_SUPPORTED, // custom error
+ ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error
};
@@ ... @@ static bool server_task_type_need_embd(server_task_type task_type) {
type_str = "unavailable_error";
code = 503;
break;
+ case ERROR_TYPE_EXCEED_CONTEXT_SIZE:
+ type_str = "exceed_context_size_error";
+ code = 400;
+ break;
}
return json {
{"code", code},
error_type err_type = ERROR_TYPE_SERVER;
std::string err_msg;
+ // for ERROR_TYPE_EXCEED_CONTEXT_SIZE
+ int32_t n_prompt_tokens = 0;
+ int32_t n_ctx = 0;
+
virtual bool is_error() override {
return true;
}
virtual json to_json() override {
- return format_error_response(err_msg, err_type);
+ json res = format_error_response(err_msg, err_type);
+ if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
+ res["n_prompt_tokens"] = n_prompt_tokens;
+ res["n_ctx"] = n_ctx;
+ }
+ return res;
}
};
}
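Taken together, the hunks above produce an error payload that carries the token counts alongside the usual fields. A rough sketch of the parsed body a client would see (field names follow `format_error_response()` and `to_json()` above, the outer `error` wrapper matches what the test below asserts, and the numeric values are placeholders, not real output):

```python
# Sketch of the parsed error body for an oversized prompt.
# The "message" text comes from the send_error() call sites further down;
# the numeric values here are placeholders.
expected_body = {
    "error": {
        "code": 400,
        "type": "exceed_context_size_error",
        "message": "the request exceeds the available context size. ...",
        "n_prompt_tokens": 5000,  # placeholder
        "n_ctx": 4096,            # placeholder
    }
}
```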
void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
- send_error(slot.id_task, error, type);
+ send_error(slot.id_task, error, type, slot.n_prompt_tokens, slot.n_ctx);
}
- void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
+ void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
+ if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
+ GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0);
+ }
+
auto res = std::make_unique<server_task_result_error>();
- res->id = id_task;
- res->err_type = type;
- res->err_msg = error;
+ res->id = id_task;
+ res->err_type = type;
+ res->err_msg = error;
+ res->n_prompt_tokens = n_prompt_tokens;
+ res->n_ctx = n_ctx;
queue_results.send(std::move(res));
}
if (slot.n_prompt_tokens > slot.n_ctx) {
slot.release();
- send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER);
+ send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
continue;
}
} else {
// context shift should be applied only during the generation phase
if (slot.n_prompt_tokens >= slot.n_ctx) {
slot.release();
- send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
+ send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
continue;
}
}
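One detail worth noting about the test that follows: the `n_ctx` reported in the error is the per-slot context (`slot.n_ctx`), not the server-wide total. With multiple slots, the total context is split evenly across them, which is exactly the relationship the new test asserts. A quick sketch with hypothetical numbers:

```python
# Hypothetical numbers: a 4096-token context shared by 2 slots.
n_ctx_total = 4096
n_slots     = 2
n_ctx_slot  = n_ctx_total // n_slots  # 2048 -- the n_ctx a slot reports in the error
assert n_ctx_slot == 2048
```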
output_text = res.choices[0].message.content
assert output_text
assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
+
+def test_context_size_exceeded():
+ global server
+ server.start()
+ res = server.make_request("POST", "/chat/completions", data={
+ "messages": [
+ {"role": "system", "content": "Book"},
+ {"role": "user", "content": "What is the best book"},
+ ] * 100, # make the prompt too long
+ })
+ assert res.status_code == 400
+ assert "error" in res.body
+ assert res.body["error"]["type"] == "exceed_context_size_error"
+ assert res.body["error"]["n_prompt_tokens"] > 0
+ assert server.n_ctx is not None
+ assert server.n_slots is not None
+ assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
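For completeness, a minimal client-side sketch of consuming the new error (assumptions: a local llama-server instance at `127.0.0.1:8080` and the `requests` package; only the endpoint, status code, and field names are taken from the diff and test above):

```python
import requests

# Hypothetical local endpoint; adjust host/port to your server instance.
resp = requests.post(
    "http://127.0.0.1:8080/chat/completions",
    json={"messages": [{"role": "user", "content": "book " * 100000}]},
)
if resp.status_code == 400:
    err = resp.json()["error"]
    if err["type"] == "exceed_context_size_error":
        # The two new fields let a client trim its prompt and retry.
        print(f"prompt is {err['n_prompt_tokens']} tokens, slot context is {err['n_ctx']}")
```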