]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : fix format_infill (#10724)
authorXuan Son Nguyen <redacted>
Sun, 8 Dec 2024 22:04:29 +0000 (23:04 +0100)
committerGitHub <redacted>
Sun, 8 Dec 2024 22:04:29 +0000 (23:04 +0100)
* server : fix format_infill

* fix

* rename

* update test

* use another model

* update test

* update test

* test_invalid_input_extra_req

examples/server/server.cpp
examples/server/tests/unit/test_infill.py
examples/server/tests/utils.py

index 1d9c0533d4c404c16427e2e093a0f244df7926bf..47bfd6c4aac9007b125033ac610b39b69fec52fa 100644 (file)
@@ -3484,6 +3484,11 @@ int main(int argc, char ** argv) {
         json data = json::parse(req.body);
 
         // validate input
+        if (data.contains("prompt") && !data.at("prompt").is_string()) {
+            // prompt is optional
+            res_error(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+        }
+
         if (!data.contains("input_prefix")) {
             res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
         }
@@ -3493,9 +3498,11 @@ int main(int argc, char ** argv) {
         }
 
         if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
+            // input_extra is optional
             res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
+
         json input_extra = json_value(data, "input_extra", json::array());
         for (const auto & chunk : input_extra) {
             // { "text": string, "filename": string }
@@ -3511,6 +3518,21 @@ int main(int argc, char ** argv) {
         }
         data["input_extra"] = input_extra; // default to empty array if it's not exist
 
+        std::string prompt = json_value(data, "prompt", std::string());
+        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
+        SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
+        data["prompt"] = format_infill(
+            ctx_server.ctx,
+            data.at("input_prefix"),
+            data.at("input_suffix"),
+            data.at("input_extra"),
+            ctx_server.params_base.n_batch,
+            ctx_server.params_base.n_predict,
+            ctx_server.slots[0].n_ctx, // TODO: there should be a better way
+            ctx_server.params_base.spm_infill,
+            tokenized_prompts[0]
+        );
+
         return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
     };
 
index 6a6d40a1cbc8b81b34787d9e5484623f219f3570..ad4b8192a78756c884f95adeb04fb7c8bf887e00 100644 (file)
@@ -13,28 +13,28 @@ def test_infill_without_input_extra():
     global server
     server.start()
     res = server.make_request("POST", "/infill", data={
-        "prompt": "Complete this",
-        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n    int n_threads = llama_",
+        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
+        "prompt": "    int n_threads = llama_",
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])
+    assert match_regex("(Ann|small|shiny)+", res.body["content"])
 
 
 def test_infill_with_input_extra():
     global server
     server.start()
     res = server.make_request("POST", "/infill", data={
-        "prompt": "Complete this",
         "input_extra": [{
             "filename": "llama.h",
             "text": "LLAMA_API int32_t llama_n_threads();\n"
         }],
-        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n    int n_threads = llama_",
+        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
+        "prompt": "    int n_threads = llama_",
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"])
+    assert match_regex("(Dad|excited|park)+", res.body["content"])
 
 
 @pytest.mark.parametrize("input_extra", [
@@ -48,10 +48,30 @@ def test_invalid_input_extra_req(input_extra):
     global server
     server.start()
     res = server.make_request("POST", "/infill", data={
-        "prompt": "Complete this",
         "input_extra": [input_extra],
-        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n    int n_threads = llama_",
+        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
+        "prompt": "    int n_threads = llama_",
         "input_suffix": "}\n",
     })
     assert res.status_code == 400
     assert "error" in res.body
+
+
+@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
+def test_with_qwen_model():
+    global server
+    server.model_file = None
+    server.model_hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-IQ3_XXS-GGUF"
+    server.model_hf_file = "qwen2.5-coder-1.5b-iq3_xxs-imat.gguf"
+    server.start(timeout_seconds=600)
+    res = server.make_request("POST", "/infill", data={
+        "input_extra": [{
+            "filename": "llama.h",
+            "text": "LLAMA_API int32_t llama_n_threads();\n"
+        }],
+        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
+        "prompt": "    int n_threads = llama_",
+        "input_suffix": "}\n",
+    })
+    assert res.status_code == 200
+    assert res.body["content"] == "n_threads();\n    printf(\"Number of threads: %d\\n\", n_threads);\n    return 0;\n"
index 69215eaa4ebb70eb087c04191ff97a0ee1ae0f35..7c89b9cd3750513db14cbb65c56cdb626750f452 100644 (file)
@@ -371,3 +371,6 @@ def match_regex(regex: str, text: str) -> bool:
         ).search(text)
         is not None
     )
+
+def is_slow_test_allowed():
+    return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"