git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : return HTTP 400 if prompt exceeds context length (#16486)
author: Radoslav Gerganov <redacted>
Fri, 10 Oct 2025 14:11:07 +0000 (17:11 +0300)
committer: GitHub <redacted>
Fri, 10 Oct 2025 14:11:07 +0000 (16:11 +0200)
In streaming mode, when the prompt exceeds the context length, the server
returns an HTTP 200 status code with a JSON error in the body. This is very
confusing and inconsistent with all other inference engines, which return an
HTTP 4xx error in this case.

This patch fixes this problem and makes the server return HTTP 400 in
such cases.

tools/server/server.cpp
tools/server/tests/unit/test_chat_completion.py
tools/server/tests/utils.py

index 39c950c15dfe0d40a1320eed978e2b7d459a804d..5293a98f034f1c5801e7b60d3e96efd1333a86c5 100644 (file)
@@ -3727,7 +3727,7 @@ struct server_context {
                             }
                         } else {
                             if (slot.n_prompt_tokens() >= slot.n_ctx) {
-                                send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
+                                send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
                                 slot.release();
                                 continue;
                             }
@@ -4955,9 +4955,17 @@ int main(int argc, char ** argv) {
                 // Everything else, including multimodal completions.
                 inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
             }
-
+            const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
             tasks.reserve(inputs.size());
             for (size_t i = 0; i < inputs.size(); i++) {
+                auto n_prompt_tokens = inputs[i].size();
+                if (n_prompt_tokens >= n_ctx_slot) {
+                    json error_data = format_error_response("the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
+                    error_data["n_prompt_tokens"] = n_prompt_tokens;
+                    error_data["n_ctx"] = n_ctx_slot;
+                    res_error(res, error_data);
+                    return;
+                }
                 server_task task = server_task(type);
 
                 task.id    = ctx_server.queue_tasks.get_new_id();
index 6e5a3488e789bac81c50c5684d84bce1fbaa73ef..d56d3d5f178b80d37ced9c95f4db65a82e05b364 100644 (file)
@@ -408,6 +408,28 @@ def test_context_size_exceeded():
     assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
 
 
+def test_context_size_exceeded_stream():
+    global server
+    server.start()
+    try:
+        for _ in server.make_stream_request("POST", "/chat/completions", data={
+            "messages": [
+                {"role": "system", "content": "Book"},
+                {"role": "user", "content": "What is the best book"},
+            ] * 100, # make the prompt too long
+            "stream": True}):
+                pass
+        assert False, "Should have failed"
+    except ServerError as e:
+        assert e.code == 400
+        assert "error" in e.body
+        assert e.body["error"]["type"] == "exceed_context_size_error"
+        assert e.body["error"]["n_prompt_tokens"] > 0
+        assert server.n_ctx is not None
+        assert server.n_slots is not None
+        assert e.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
+
+
 @pytest.mark.parametrize(
     "n_batch,batch_count,reuse_cache",
     [
index abd6fff10d0d1bf03b7bb4d7547e4a91345a72a3..4ba3d43c330442a2cfe305d790f12144a5b411a5 100644 (file)
@@ -35,6 +35,12 @@ class ServerResponse:
     body: dict | Any
 
 
+class ServerError(Exception):
+    def __init__(self, code, body):
+        self.code = code
+        self.body = body
+
+
 class ServerProcess:
     # default options
     debug: bool = False
@@ -297,6 +303,8 @@ class ServerProcess:
             response = requests.post(url, headers=headers, json=data, stream=True)
         else:
             raise ValueError(f"Unimplemented method: {method}")
+        if response.status_code != 200:
+            raise ServerError(response.status_code, response.json())
         for line_bytes in response.iter_lines():
             line = line_bytes.decode("utf-8")
             if '[DONE]' in line: