]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server: fix/test add_generation_prompt (#13770) upstream/latest
authorOlivier Chafik <redacted>
Sun, 25 May 2025 09:45:49 +0000 (10:45 +0100)
committerGitHub <redacted>
Sun, 25 May 2025 09:45:49 +0000 (10:45 +0100)
Co-authored-by: ochafik <redacted>
tools/server/tests/unit/test_template.py
tools/server/utils.hpp

index cf9f96a7fbc528eb6c2401a8d74997fde7e3d33a..7bb857b335bb6c91cb6943182f450714f2a84ba0 100644 (file)
@@ -47,3 +47,28 @@ def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
 
     today_str = datetime.date.today().strftime(format)
     assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
+
+
+@pytest.mark.parametrize("add_generation_prompt", [False, True])
+@pytest.mark.parametrize("template_name,expected_generation_prompt", [
+    ("meta-llama-Llama-3.3-70B-Instruct",    "<|start_header_id|>assistant<|end_header_id|>"),
+])
+def test_add_generation_prompt(template_name: str, expected_generation_prompt: str, add_generation_prompt: bool):
+    global server
+    server.jinja = True
+    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
+    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+
+    res = server.make_request("POST", "/apply-template", data={
+        "messages": [
+            {"role": "user", "content": "What is today?"},
+        ],
+        "add_generation_prompt": add_generation_prompt,
+    })
+    assert res.status_code == 200
+    prompt = res.body["prompt"]
+
+    if add_generation_prompt:
+        assert expected_generation_prompt in prompt, f"Expected generation prompt ({expected_generation_prompt}) in content ({prompt})"
+    else:
+        assert expected_generation_prompt not in prompt, f"Did not expect generation prompt ({expected_generation_prompt}) in content ({prompt})"
index 91efcfef067726a4dd95c831a686e97b86214924..ee33f76c2b06d3b9e224e426c094a92ab0c0f2e3 100644 (file)
@@ -731,6 +731,7 @@ static json oaicompat_chat_params_parse(
     inputs.grammar               = grammar;
     inputs.use_jinja             = opt.use_jinja;
     inputs.parallel_tool_calls   = json_value(body, "parallel_tool_calls", false);
+    inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
     inputs.reasoning_format      = opt.reasoning_format;
     if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
         throw std::runtime_error("Cannot use custom grammar constraints with tools.");