]
)
def test_router_chat_completion_stream(model: str, success: bool):
- # TODO: make sure the model is in cache (ie. ServerProcess.load_all()) before starting the router server
global server
server.start()
content = ""
else:
assert ex is not None
assert content == ""
+
+
+def _get_model_status(model_id: str) -> str:
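+    """Return the status value reported by GET /models for model_id."""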
+ res = server.make_request("GET", "/models")
+ assert res.status_code == 200
+ for item in res.body.get("data", []):
+ if item.get("id") == model_id or item.get("model") == model_id:
+ return item["status"]["value"]
+ raise AssertionError(f"Model {model_id} not found in /models response")
+
+
+def _wait_for_model_status(model_id: str, desired: set[str], timeout: int = 60) -> str:
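+    """Poll /models once per second until model_id reaches one of the desired statuses."""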
+ deadline = time.time() + timeout
+ last_status = None
+ while time.time() < deadline:
+ last_status = _get_model_status(model_id)
+ if last_status in desired:
+ return last_status
+ time.sleep(1)
+ raise AssertionError(
+ f"Timed out waiting for {model_id} to reach {desired}, last status: {last_status}"
+ )
+
+
+def _load_model_and_wait(
+ model_id: str, timeout: int = 60, headers: dict | None = None
+) -> None:
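+    """POST /models/load for model_id and wait until it is reported as loaded."""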
+ load_res = server.make_request(
+ "POST", "/models/load", data={"model": model_id}, headers=headers
+ )
+ assert load_res.status_code == 200
+ assert isinstance(load_res.body, dict)
+ assert load_res.body.get("success") is True
+ _wait_for_model_status(model_id, {"loaded"}, timeout=timeout)
+
+
+def test_router_unload_model():
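+    # Load a model, then unload it and verify the reported status transitions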
+ global server
+ server.start()
+ model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
+
+ _load_model_and_wait(model_id)
+
+ unload_res = server.make_request("POST", "/models/unload", data={"model": model_id})
+ assert unload_res.status_code == 200
+ assert unload_res.body.get("success") is True
+ _wait_for_model_status(model_id, {"unloaded"})
+
+
+def test_router_models_max_evicts_lru():
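+    # With the cache capped at 2 models, loading a third must evict the least recently used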
+ global server
+ server.models_max = 2
+ server.start()
+
+ candidate_models = [
+ "ggml-org/tinygemma3-GGUF:Q8_0",
+ "ggml-org/test-model-stories260K",
+ "ggml-org/test-model-stories260K-infill",
+ ]
+
+    first, second, third = candidate_models
+
+    # Load only the first two models to fill the cache (models_max = 2)
+
+ _load_model_and_wait(first, timeout=120)
+ _load_model_and_wait(second, timeout=120)
+
+ # Verify both models are loaded
+ assert _get_model_status(first) == "loaded"
+ assert _get_model_status(second) == "loaded"
+
+ # Load the third model - this should trigger LRU eviction of the first model
+ _load_model_and_wait(third, timeout=120)
+
+ # Verify eviction: third is loaded, first was evicted
+ assert _get_model_status(third) == "loaded"
+ assert _get_model_status(first) == "unloaded"
+
+
+def test_router_no_models_autoload():
+ global server
+ server.no_models_autoload = True
+ server.start()
+ model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
+
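+    # Without autoload, requesting a model that is not loaded must fail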
+ res = server.make_request(
+ "POST",
+ "/v1/chat/completions",
+ data={
+ "model": model_id,
+ "messages": [{"role": "user", "content": "hello"}],
+ "max_tokens": 4,
+ },
+ )
+ assert res.status_code == 400
+ assert "error" in res.body
+
+ _load_model_and_wait(model_id)
+
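+    # After an explicit load, the same request should succeed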
+ success_res = server.make_request(
+ "POST",
+ "/v1/chat/completions",
+ data={
+ "model": model_id,
+ "messages": [{"role": "user", "content": "hello"}],
+ "max_tokens": 4,
+ },
+ )
+ assert success_res.status_code == 200
+ assert "error" not in success_res.body
+
+
+def test_router_api_key_required():
+ global server
+ server.api_key = "sk-router-secret"
+ server.start()
+
+ model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
+ auth_headers = {"Authorization": f"Bearer {server.api_key}"}
+
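+    # Requests without the API key must be rejected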
+ res = server.make_request(
+ "POST",
+ "/v1/chat/completions",
+ data={
+ "model": model_id,
+ "messages": [{"role": "user", "content": "hello"}],
+ "max_tokens": 4,
+ },
+ )
+ assert res.status_code == 401
+ assert res.body.get("error", {}).get("type") == "authentication_error"
+
+ _load_model_and_wait(model_id, headers=auth_headers)
+
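+    # The same request with the Bearer token should succeed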
+ authed = server.make_request(
+ "POST",
+ "/v1/chat/completions",
+ headers=auth_headers,
+ data={
+ "model": model_id,
+ "messages": [{"role": "user", "content": "hello"}],
+ "max_tokens": 4,
+ },
+ )
+ assert authed.status_code == 200
+ assert "error" not in authed.body
import os
import re
import json
+from json import JSONDecodeError
import sys
import requests
import time
pooling: str | None = None
draft: int | None = None
api_key: str | None = None
+ models_dir: str | None = None
+ models_max: int | None = None
+ no_models_autoload: bool | None = None
lora_files: List[str] | None = None
enable_ctx_shift: int | None = False
draft_min: int | None = None
server_args.extend(["--hf-repo", self.model_hf_repo])
if self.model_hf_file:
server_args.extend(["--hf-file", self.model_hf_file])
+ if self.models_dir:
+ server_args.extend(["--models-dir", self.models_dir])
+ if self.models_max is not None:
+ server_args.extend(["--models-max", self.models_max])
if self.n_batch:
server_args.extend(["--batch-size", self.n_batch])
if self.n_ubatch:
server_args.extend(["--draft-min", self.draft_min])
if self.no_webui:
server_args.append("--no-webui")
+ if self.no_models_autoload:
+ server_args.append("--no-models-autoload")
if self.jinja:
server_args.append("--jinja")
else:
result = ServerResponse()
result.headers = dict(response.headers)
result.status_code = response.status_code
- result.body = response.json() if parse_body else None
+ if parse_body:
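+        # Fall back to the raw text body when the response is not valid JSON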
+ try:
+ result.body = response.json()
+ except JSONDecodeError:
+ result.body = response.text
+ else:
+ result.body = None
print("Response from server", json.dumps(result.body, indent=2))
return result
@staticmethod
def tinyllama2() -> ServerProcess:
server = ServerProcess()
- server.model_hf_repo = "ggml-org/models"
- server.model_hf_file = "tinyllamas/stories260K.gguf"
+ server.offline = True # will be downloaded by load_all()
+ server.model_hf_repo = "ggml-org/test-model-stories260K"
+ server.model_hf_file = None
server.model_alias = "tinyllama-2"
server.n_ctx = 512
server.n_batch = 32
def tinyllama_infill() -> ServerProcess:
server = ServerProcess()
server.offline = True # will be downloaded by load_all()
- server.model_hf_repo = "ggml-org/models"
- server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
+ server.model_hf_repo = "ggml-org/test-model-stories260K-infill"
+ server.model_hf_file = None
server.model_alias = "tinyllama-infill"
server.n_ctx = 2048
server.n_batch = 1024
@staticmethod
def router() -> ServerProcess:
server = ServerProcess()
+    server.offline = True  # models must already be in cache (see ServerProcess.load_all())
# router server has no models
server.model_file = None
server.model_alias = None