assert res.status_code == 200
assert len(res.body["data"]) == 1
assert res.body["data"][0]["id"] == server.model_alias
+
+def test_load_split_model():
+ global server
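+    # the model is split into multiple gguf shards (…-00001-of-00003.gguf); the server should fetch and load all of them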
+ server.model_hf_repo = "ggml-org/models"
+ server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf"
+ server.model_alias = "tinyllama-split"
+ server.start()
+ res = server.make_request("POST", "/completion", data={
+ "n_predict": 16,
+ "prompt": "Hello",
+ "temperature": 0.0,
+ })
+ assert res.status_code == 200
+ assert match_regex("(little|girl)+", res.body["content"])
assert res.status_code != 200
assert "error" in res.body
+
+@pytest.mark.parametrize("messages", [
+ None,
+ "string",
+ [123],
+ [{}],
+ [{"role": 123}],
+ [{"role": "system", "content": 123}],
+ # [{"content": "hello"}], # TODO: should not be a valid case
+ [{"role": "system", "content": "test"}, {}],
+])
+def test_invalid_chat_completion_req(messages):
+ global server
+ server.start()
+ res = server.make_request("POST", "/chat/completions", data={
+ "messages": messages,
+ })
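+    # malformed messages may be rejected up front (400) or fail later while applying the chat template (500)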
+ assert res.status_code == 400 or res.status_code == 500
+ assert "error" in res.body
global server
server = ServerPreset.tinyllama_infill()
+
def test_infill_without_input_extra():
global server
server.start()
assert res.status_code == 200
assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])
+
def test_infill_with_input_extra():
global server
server.start()
})
assert res.status_code == 200
assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"])
+
+
+@pytest.mark.parametrize("input_extra", [
+ {},
+ {"filename": "ok"},
+ {"filename": 123},
+ {"filename": 123, "text": "abc"},
+ {"filename": 123, "text": 456},
+])
+def test_invalid_input_extra_req(input_extra):
+ global server
+ server.start()
+ res = server.make_request("POST", "/infill", data={
+ "prompt": "Complete this",
+ "input_extra": [input_extra],
+ "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
+ "input_suffix": "}\n",
+ })
+ assert res.status_code == 400
+ assert "error" in res.body
assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
assert most_relevant["index"] == 2
assert least_relevant["index"] == 3
+
+
+@pytest.mark.parametrize("documents", [
+ [],
+ None,
+ 123,
+ [1, 2, 3],
+])
+def test_invalid_rerank_req(documents):
+ global server
+ server.start()
+ res = server.make_request("POST", "/rerank", data={
+ "query": "Machine learning is",
+ "documents": documents,
+ })
+ assert res.status_code == 400
+ assert "error" in res.body
--- /dev/null
+import os
+
+import pytest
+import requests
+from utils import *
+
+# We use an F16 MoE GGUF as the main model, and a q4_0 model as the draft model
+
+server = ServerPreset.stories15m_moe()
+
+MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf"
+
+def create_server():
+ global server
+ server = ServerPreset.stories15m_moe()
+ # download draft model file if needed
+    file_name = MODEL_DRAFT_FILE_URL.split('/')[-1]
+ model_draft_file = f'../../../{file_name}'
+ if not os.path.exists(model_draft_file):
+ print(f"Downloading {MODEL_DRAFT_FILE_URL} to {model_draft_file}")
+ with open(model_draft_file, 'wb') as f:
+ f.write(requests.get(MODEL_DRAFT_FILE_URL).content)
+ print(f"Done downloading draft model file")
+ # set default values
+ server.model_draft = model_draft_file
+ server.draft_min = 4
+ server.draft_max = 8
+
+
+@pytest.fixture(scope="module", autouse=True)
+def fixture_create_server():
+ return create_server()
+
+
+def test_with_and_without_draft():
+ global server
+ server.model_draft = None # disable draft model
+ server.start()
+ res = server.make_request("POST", "/completion", data={
+ "prompt": "I believe the meaning of life is",
+ "temperature": 0.0,
+ "top_k": 1,
+ })
+ assert res.status_code == 200
+ content_no_draft = res.body["content"]
+ server.stop()
+
+ # create new server with draft model
+ create_server()
+ server.start()
+ res = server.make_request("POST", "/completion", data={
+ "prompt": "I believe the meaning of life is",
+ "temperature": 0.0,
+ "top_k": 1,
+ })
+ assert res.status_code == 200
+ content_draft = res.body["content"]
+
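+    # with temperature 0.0 and top_k 1, sampling is greedy, so speculative decoding must not change the output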
+ assert content_no_draft == content_draft
+
+
+def test_different_draft_min_draft_max():
+ global server
+ test_values = [
+ (1, 2),
+ (1, 4),
+ (4, 8),
+ (4, 12),
+ (8, 16),
+ ]
+ last_content = None
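+    # greedy output should be identical for every (draft_min, draft_max) combination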
+ for draft_min, draft_max in test_values:
+ server.stop()
+ server.draft_min = draft_min
+ server.draft_max = draft_max
+ server.start()
+ res = server.make_request("POST", "/completion", data={
+ "prompt": "I believe the meaning of life is",
+ "temperature": 0.0,
+ "top_k": 1,
+ })
+ assert res.status_code == 200
+ if last_content is not None:
+ assert last_content == res.body["content"]
+ last_content = res.body["content"]
+
+
+@pytest.mark.parametrize("n_slots,n_requests", [
+ (1, 2),
+ (2, 2),
+])
+def test_multi_requests_parallel(n_slots: int, n_requests: int):
+ global server
+ server.n_slots = n_slots
+ server.start()
+ tasks = []
+ for _ in range(n_requests):
+ tasks.append((server.make_request, ("POST", "/completion", {
+ "prompt": "I believe the meaning of life is",
+ "temperature": 0.0,
+ "top_k": 1,
+ })))
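+    # every parallel request should succeed and produce a plausible continuation of the prompt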
+ results = parallel_function_calls(tasks)
+ for res in results:
+ assert res.status_code == 200
+ assert match_regex("(wise|kind|owl|answer)+", res.body["content"])
model_alias: str | None = None
model_url: str | None = None
model_file: str | None = None
+ model_draft: str | None = None
n_threads: int | None = None
n_gpu_layer: int | None = None
n_batch: int | None = None
response_format: str | None = None
lora_files: List[str] | None = None
disable_ctx_shift: int | None = False
+ draft_min: int | None = None
+ draft_max: int | None = None
# session variables
process: subprocess.Popen | None = None
server_args.extend(["--model", self.model_file])
if self.model_url:
server_args.extend(["--model-url", self.model_url])
+ if self.model_draft:
+ server_args.extend(["--model-draft", self.model_draft])
if self.model_hf_repo:
server_args.extend(["--hf-repo", self.model_hf_repo])
if self.model_hf_file:
server_args.extend(["--no-context-shift"])
if self.api_key:
server_args.extend(["--api-key", self.api_key])
+ if self.draft_max:
+ server_args.extend(["--draft-max", self.draft_max])
+ if self.draft_min:
+ server_args.extend(["--draft-min", self.draft_min])
args = [str(arg) for arg in [server_path, *server_args]]
print(f"bench: starting server with: {' '.join(args)}")
raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")
def stop(self) -> None:
- server_instances.remove(self)
+ if self in server_instances:
+ server_instances.remove(self)
if self.process:
print(f"Stopping server with pid={self.process.pid}")
self.process.kill()