]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : speed up tests (#15836)
authorXuan-Son Nguyen <redacted>
Sat, 6 Sep 2025 12:45:24 +0000 (19:45 +0700)
committerGitHub <redacted>
Sat, 6 Sep 2025 12:45:24 +0000 (14:45 +0200)
* server : speed up tests

* clean up

* restore timeout_seconds in some places

* flake8

* explicit offline

scripts/tool_bench.py
tools/server/tests/unit/test_basic.py
tools/server/tests/unit/test_template.py
tools/server/tests/unit/test_tool_call.py
tools/server/tests/unit/test_vision_api.py
tools/server/tests/utils.py

index 05d6dfc30a36ddd190473e8b46d5fb053fb69abe..e1512a49fd244a76e755ed1eac52e978fea58a65 100755 (executable)
@@ -53,7 +53,7 @@ import typer
 sys.path.insert(0, Path(__file__).parent.parent.as_posix())
 if True:
     from tools.server.tests.utils import ServerProcess
-    from tools.server.tests.unit.test_tool_call import TIMEOUT_SERVER_START, do_test_calc_result, do_test_hello_world, do_test_weather
+    from tools.server.tests.unit.test_tool_call import do_test_calc_result, do_test_hello_world, do_test_weather
 
 
 @contextmanager
@@ -335,7 +335,7 @@ def run(
                     # server.debug = True
 
                     with scoped_server(server):
-                        server.start(timeout_seconds=TIMEOUT_SERVER_START)
+                        server.start(timeout_seconds=15 * 60)
                         for ignore_chat_grammar in [False]:
                             run(
                                 server,
index c7b3af048916498b8c3c8d33b45039d9c02f6ca9..58ade52be655dd387824b69dcbb65168977a8bb1 100644 (file)
@@ -5,6 +5,12 @@ from utils import *
 server = ServerPreset.tinyllama2()
 
 
+@pytest.fixture(scope="session", autouse=True)
+def do_something():
+    # this will be run once per test session, before any tests
+    ServerPreset.load_all()
+
+
 @pytest.fixture(autouse=True)
 def create_server():
     global server
index c53eda5b884456da1d75a3124cf78d8bd9b716ec..e5185fcbfab850f5905331fd923b15bee911158d 100644 (file)
@@ -14,14 +14,11 @@ from utils import *
 
 server: ServerProcess
 
-TIMEOUT_SERVER_START = 15*60
-
 @pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
     server.model_alias = "tinyllama-2"
-    server.server_port = 8081
     server.n_slots = 1
 
 
@@ -45,7 +42,7 @@ def test_reasoning_budget(template_name: str, reasoning_budget: int | None, expe
     server.jinja = True
     server.reasoning_budget = reasoning_budget
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
 
     res = server.make_request("POST", "/apply-template", data={
         "messages": [
@@ -68,7 +65,7 @@ def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
     global server
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
 
     res = server.make_request("POST", "/apply-template", data={
         "messages": [
@@ -91,7 +88,7 @@ def test_add_generation_prompt(template_name: str, expected_generation_prompt: s
     global server
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
 
     res = server.make_request("POST", "/apply-template", data={
         "messages": [
index a3c3ccdf586ab138e8f5f45a2d6b8a64db784146..b8f0f10863fb83962368c2443480b98b3e92d6fe 100755 (executable)
@@ -12,7 +12,7 @@ from enum import Enum
 
 server: ServerProcess
 
-TIMEOUT_SERVER_START = 15*60
+TIMEOUT_START_SLOW = 15 * 60 # this is needed for real model tests
 TIMEOUT_HTTP_REQUEST = 60
 
 @pytest.fixture(autouse=True)
@@ -124,7 +124,7 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict,
     server.jinja = True
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0)
 
 
@@ -168,7 +168,7 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
     server.jinja = True
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED)
 
 
@@ -240,7 +240,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
@@ -295,7 +295,7 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
     server.n_predict = n_predict
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
 
 
@@ -317,7 +317,7 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
     server.n_predict = n_predict
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
 
 
@@ -377,7 +377,7 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
 
 
@@ -436,7 +436,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED)
 
 
@@ -524,7 +524,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
@@ -597,7 +597,7 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
 
     do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
 
index 36d14b3885175c636798d082c74814b4a404c681..9408116d1cff33e113429ea508445eca7f15694b 100644 (file)
@@ -5,18 +5,31 @@ import requests
 
 server: ServerProcess
 
-IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
-IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
-
-response = requests.get(IMG_URL_0)
-response.raise_for_status() # Raise an exception for bad status codes
-IMG_BASE64_URI_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
-IMG_BASE64_0 = base64.b64encode(response.content).decode("utf-8")
-
-response = requests.get(IMG_URL_1)
-response.raise_for_status() # Raise an exception for bad status codes
-IMG_BASE64_URI_1 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
-IMG_BASE64_1 = base64.b64encode(response.content).decode("utf-8")
+def get_img_url(id: str) -> str:
+    IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
+    IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
+    if id == "IMG_URL_0":
+        return IMG_URL_0
+    elif id == "IMG_URL_1":
+        return IMG_URL_1
+    elif id == "IMG_BASE64_URI_0":
+        response = requests.get(IMG_URL_0)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_0":
+        response = requests.get(IMG_URL_0)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_URI_1":
+        response = requests.get(IMG_URL_1)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_1":
+        response = requests.get(IMG_URL_1)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return base64.b64encode(response.content).decode("utf-8")
+    else:
+        return id
 
 JSON_MULTIMODAL_KEY = "multimodal_data"
 JSON_PROMPT_STRING_KEY = "prompt_string"
@@ -28,7 +41,7 @@ def create_server():
 
 def test_models_supports_multimodal_capability():
     global server
-    server.start() # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("GET", "/models", data={})
     assert res.status_code == 200
     model_info = res.body["models"][0]
@@ -38,7 +51,7 @@ def test_models_supports_multimodal_capability():
 
 def test_v1_models_supports_multimodal_capability():
     global server
-    server.start() # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("GET", "/v1/models", data={})
     assert res.status_code == 200
     model_info = res.body["models"][0]
@@ -50,10 +63,10 @@ def test_v1_models_supports_multimodal_capability():
     "prompt, image_url, success, re_content",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this:\n", IMG_URL_0,                True, "(cat)+"),
-        ("What is this:\n", "IMG_BASE64_URI_0",       True, "(cat)+"), # exceptional, so that we don't cog up the log
-        ("What is this:\n", IMG_URL_1,                True, "(frog)+"),
-        ("Test test\n",     IMG_URL_1,                True, "(frog)+"), # test invalidate cache
+        ("What is this:\n", "IMG_URL_0",              True, "(cat)+"),
+        ("What is this:\n", "IMG_BASE64_URI_0",       True, "(cat)+"),
+        ("What is this:\n", "IMG_URL_1",              True, "(frog)+"),
+        ("Test test\n",     "IMG_URL_1",              True, "(frog)+"), # test invalidate cache
         ("What is this:\n", "malformed",              False, None),
         ("What is this:\n", "https://google.com/404", False, None), # non-existent image
         ("What is this:\n", "https://ggml.ai",        False, None), # non-image data
@@ -62,9 +75,7 @@ def test_v1_models_supports_multimodal_capability():
 )
 def test_vision_chat_completion(prompt, image_url, success, re_content):
     global server
-    server.start(timeout_seconds=60) # vision model may take longer to load due to download size
-    if image_url == "IMG_BASE64_URI_0":
-        image_url = IMG_BASE64_URI_0
+    server.start()
     res = server.make_request("POST", "/chat/completions", data={
         "temperature": 0.0,
         "top_k": 1,
@@ -72,7 +83,7 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
             {"role": "user", "content": [
                 {"type": "text", "text": prompt},
                 {"type": "image_url", "image_url": {
-                    "url": image_url,
+                    "url": get_img_url(image_url),
                 }},
             ]},
         ],
@@ -90,19 +101,22 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
     "prompt, image_data, success, re_content",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this: <__media__>\n", IMG_BASE64_0,           True, "(cat)+"),
-        ("What is this: <__media__>\n", IMG_BASE64_1,           True, "(frog)+"),
+        ("What is this: <__media__>\n", "IMG_BASE64_0",         True, "(cat)+"),
+        ("What is this: <__media__>\n", "IMG_BASE64_1",         True, "(frog)+"),
         ("What is this: <__media__>\n", "malformed",            False, None), # non-image data
         ("What is this:\n",             "",                     False, None), # empty string
     ]
 )
 def test_vision_completion(prompt, image_data, success, re_content):
     global server
-    server.start() # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("POST", "/completions", data={
         "temperature": 0.0,
         "top_k": 1,
-        "prompt": { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
+        "prompt": {
+            JSON_PROMPT_STRING_KEY: prompt,
+            JSON_MULTIMODAL_KEY: [ get_img_url(image_data) ],
+        },
     })
     if success:
         assert res.status_code == 200
@@ -116,17 +130,18 @@ def test_vision_completion(prompt, image_data, success, re_content):
     "prompt, image_data, success",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this: <__media__>\n", IMG_BASE64_0,           True), # exceptional, so that we don't cog up the log
-        ("What is this: <__media__>\n", IMG_BASE64_1,           True),
+        ("What is this: <__media__>\n", "IMG_BASE64_0",         True),
+        ("What is this: <__media__>\n", "IMG_BASE64_1",         True),
         ("What is this: <__media__>\n", "malformed",            False), # non-image data
         ("What is this:\n",             "base64",               False), # non-image data
     ]
 )
 def test_vision_embeddings(prompt, image_data, success):
     global server
-    server.server_embeddings=True
-    server.n_batch=512
-    server.start() # vision model may take longer to load due to download size
+    server.server_embeddings = True
+    server.n_batch = 512
+    server.start()
+    image_data = get_img_url(image_data)
     res = server.make_request("POST", "/embeddings", data={
         "content": [
             { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
index cda7434d7c2011e353e2a7d58f0b53a425e782c1..10997ef57cac7b63ec80a7e33ad99029131e79f6 100644 (file)
@@ -26,7 +26,7 @@ from re import RegexFlag
 import wget
 
 
-DEFAULT_HTTP_TIMEOUT = 30
+DEFAULT_HTTP_TIMEOUT = 60
 
 
 class ServerResponse:
@@ -45,6 +45,7 @@ class ServerProcess:
     model_alias: str = "tinyllama-2"
     temperature: float = 0.8
     seed: int = 42
+    offline: bool = False
 
     # custom options
     model_alias: str | None = None
@@ -118,6 +119,8 @@ class ServerProcess:
             "--seed",
             self.seed,
         ]
+        if self.offline:
+            server_args.append("--offline")
         if self.model_file:
             server_args.extend(["--model", self.model_file])
         if self.model_url:
@@ -392,6 +395,19 @@ server_instances: Set[ServerProcess] = set()
 
 
 class ServerPreset:
+    @staticmethod
+    def load_all() -> None:
+        """ Load all server presets to ensure model files are cached. """
+        servers: List[ServerProcess] = [
+            method()
+            for name, method in ServerPreset.__dict__.items()
+            if callable(method) and name != "load_all"
+        ]
+        for server in servers:
+            server.offline = False
+            server.start()
+            server.stop()
+
     @staticmethod
     def tinyllama2() -> ServerProcess:
         server = ServerProcess()
@@ -408,6 +424,7 @@ class ServerPreset:
     @staticmethod
     def bert_bge_small() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
         server.model_alias = "bert-bge-small"
@@ -422,6 +439,7 @@ class ServerPreset:
     @staticmethod
     def bert_bge_small_with_fa() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
         server.model_alias = "bert-bge-small"
@@ -437,6 +455,7 @@ class ServerPreset:
     @staticmethod
     def tinyllama_infill() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
         server.model_alias = "tinyllama-infill"
@@ -451,6 +470,7 @@ class ServerPreset:
     @staticmethod
     def stories15m_moe() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/stories15M_MOE"
         server.model_hf_file = "stories15M_MOE-F16.gguf"
         server.model_alias = "stories15m-moe"
@@ -465,6 +485,7 @@ class ServerPreset:
     @staticmethod
     def jina_reranker_tiny() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
         server.model_alias = "jina-reranker"
@@ -478,6 +499,7 @@ class ServerPreset:
     @staticmethod
     def tinygemma3() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         # mmproj is already provided by HF registry API
         server.model_hf_repo = "ggml-org/tinygemma3-GGUF"
         server.model_hf_file = "tinygemma3-Q8_0.gguf"