json data = json::parse(req.body);
// validate input
+ if (data.contains("prompt") && !data.at("prompt").is_string()) {
+ // prompt is optional
+ res_error(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+ }
+
if (!data.contains("input_prefix")) {
res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
}
}
if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
+ // input_extra is optional
res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
return;
}
+
json input_extra = json_value(data, "input_extra", json::array());
for (const auto & chunk : input_extra) {
// { "text": string, "filename": string }
}
data["input_extra"] = input_extra; // default to empty array if it's not exist
+ std::string prompt = json_value(data, "prompt", std::string());
+ std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
+ SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
+ data["prompt"] = format_infill(
+ ctx_server.ctx,
+ data.at("input_prefix"),
+ data.at("input_suffix"),
+ data.at("input_extra"),
+ ctx_server.params_base.n_batch,
+ ctx_server.params_base.n_predict,
+ ctx_server.slots[0].n_ctx, // TODO: there should be a better way
+ ctx_server.params_base.spm_infill,
+ tokenized_prompts[0]
+ );
+
return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
};
# NOTE(review): this is a fragment of a test function — its `def` line is not
# visible in this chunk, and the body still carries unresolved diff markers
# ("+ "/"- ") plus flattened indentation. Resolve the markers (drop "- " lines,
# strip "+ " prefixes) and restore the missing `def`/indentation before running.
global server
server.start()
res = server.make_request("POST", "/infill", data={
- "prompt": "Complete this",
- "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
+ "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
+ "prompt": " int n_threads = llama_",
"input_suffix": "}\n",
})
assert res.status_code == 200
- assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])
+ assert match_regex("(Ann|small|shiny)+", res.body["content"])
def test_infill_with_input_extra():
    """POST /infill with an `input_extra` context chunk returns a 200 completion.

    NOTE(review): the original span contained unresolved diff markers
    ("+ "/"- ") and flattened indentation; this is the resolved post-diff state.
    """
    global server
    server.start()
    res = server.make_request("POST", "/infill", data={
        # extra repo context chunks: [{"filename": str, "text": str}, ...]
        "input_extra": [{
            "filename": "llama.h",
            "text": "LLAMA_API int32_t llama_n_threads();\n"
        }],
        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
        "prompt": " int n_threads = llama_",
        "input_suffix": "}\n",
    })
    assert res.status_code == 200
    # tokens the tiny test model is expected to emit for this prompt
    assert match_regex("(Dad|excited|park)+", res.body["content"])
# NOTE(review): fragment of a parametrized negative test — the parametrize
# argument list is cut off mid-decorator and the `def ...(input_extra):` line is
# missing entirely; the body also carries unresolved diff markers ("+ "/"- ")
# and flattened indentation. Reconstruct the decorator/def and resolve the
# markers (drop "- " lines, strip "+ " prefixes) before running.
@pytest.mark.parametrize("input_extra", [
global server
server.start()
res = server.make_request("POST", "/infill", data={
- "prompt": "Complete this",
"input_extra": [input_extra],
- "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
+ "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
+ "prompt": " int n_threads = llama_",
"input_suffix": "}\n",
})
# malformed input_extra chunks must be rejected with a 400 + error body
assert res.status_code == 400
assert "error" in res.body
+
+
@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
def test_with_qwen_model():
    """Slow end-to-end infill test against a real Qwen2.5-Coder GGUF model.

    Downloads the model from HF (hence the long timeout and the slow-test
    gate) and checks the exact completion text.

    NOTE(review): the original span was entirely "+"-prefixed unresolved diff
    lines; this is the resolved post-diff state.
    """
    global server
    server.model_file = None  # force download via the HF repo/file below
    server.model_hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-IQ3_XXS-GGUF"
    server.model_hf_file = "qwen2.5-coder-1.5b-iq3_xxs-imat.gguf"
    server.start(timeout_seconds=600)  # model download + load can be slow
    res = server.make_request("POST", "/infill", data={
        "input_extra": [{
            "filename": "llama.h",
            "text": "LLAMA_API int32_t llama_n_threads();\n"
        }],
        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
        "prompt": " int n_threads = llama_",
        "input_suffix": "}\n",
    })
    assert res.status_code == 200
    assert res.body["content"] == "n_threads();\n printf(\"Number of threads: %d\\n\", n_threads);\n return 0;\n"