]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server-bench: external OAI servers, sqlite (#15179)
authorJohannes Gäßler <redacted>
Fri, 8 Aug 2025 21:04:36 +0000 (23:04 +0200)
committerGitHub <redacted>
Fri, 8 Aug 2025 21:04:36 +0000 (23:04 +0200)
* server-bench: external OAI servers, sqlite

* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* raise_for_status

---------

Co-authored-by: Sigbjørn Skjæret <redacted>
scripts/server-bench.py

index 9326be8d57bbee359da90e43d7e06b2d1fd1c9ee..a71602017340afd65adb9541c3bf2ca3750b85f9 100755 (executable)
@@ -4,6 +4,7 @@ import argparse
 import json
 import os
 import random
+import sqlite3
 import subprocess
 from time import sleep, time
 from typing import Optional, Union
@@ -47,6 +48,8 @@ def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
 
 
 def get_server(path_server: str, path_log: Optional[str]) -> dict:
+    if path_server.startswith("http://") or path_server.startswith("https://"):
+        return {"process": None, "address": path_server, "fout": None}
     if os.environ.get("LLAMA_ARG_HOST") is None:
         logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
         os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
@@ -89,15 +92,13 @@ def get_prompt_length(data: dict) -> int:
         f"{server_address}/apply-template",
         json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
     )
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+    response.raise_for_status()
     prompt: str = json.loads(response.text)["prompt"]
     response = session.post(
         f"{server_address}/tokenize",
         json={"content": prompt, "add_special": True}
     )
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+    response.raise_for_status()
     tokens: list[str] = json.loads(response.text)["tokens"]
     return len(tokens)
 
@@ -107,7 +108,12 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
     server_address: str = data["server_address"]
 
     t_submit = time()
-    if data["synthetic_prompt"]:
+    if data["external_server"]:
+        json_data: dict = {
+            "prompt": data["prompt"], "ignore_eos": True,
+            "seed": data["seed"], "max_tokens": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/v1/completions", json=json_data, stream=True)
+    elif data["synthetic_prompt"]:
         json_data: dict = {
             "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
             "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
@@ -117,34 +123,38 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
             f"{server_address}/apply-template",
             json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
         )
-        if response.status_code != 200:
-            raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+        response.raise_for_status()
         prompt: str = json.loads(response.text)["prompt"]
 
         json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
         response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+    response.raise_for_status()
 
+    lines = []
     token_arrival_times: list[float] = []
     for line in response.iter_lines(decode_unicode=False):
         if not line.startswith(b"data: "):
             continue
+        lines.append(line)
         token_arrival_times.append(time())
     token_arrival_times = token_arrival_times[:-1]
-
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+    if len(lines) > 1 and "timings" in json.loads(lines[-2][6:]):
+        token_arrival_times = token_arrival_times[:-1]
 
     return (t_submit, token_arrival_times)
 
 
-def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int, seed_offset: int):
+def benchmark(
+        path_server: str, path_log: Optional[str], path_db: Optional[str], name: Optional[str], prompt_source: str, n_prompts: int,
+        n_predict: int, n_predict_min: int, seed_offset: int):
+    external_server: bool = path_server.startswith("http://") or path_server.startswith("https://")
     if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
         logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
         os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
-    if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
+    if not external_server and os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
         logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
         os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
-    if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
+    if not external_server and os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
         logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
         os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
 
@@ -165,7 +175,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     else:
         n_predict_min = n_predict
 
-    if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
+    if not external_server and os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
         context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
         context_total: int = context_per_slot * parallel
         os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
@@ -176,6 +186,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     try:
         server = get_server(path_server, path_log)
         server_address: str = server["address"]
+        assert external_server == (server["process"] is None)
 
         adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
         session = requests.Session()
@@ -188,8 +199,9 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
             if seed_offset >= 0:
                 random.seed(3 * (seed_offset + 1000 * i) + 1)
             data.append({
-                "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
-                "n_predict": random.randint(n_predict_min, n_predict), "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})
+                "session": session, "server_address": server_address, "external_server": external_server, "prompt": p,
+                "synthetic_prompt": synthetic_prompts, "n_predict": random.randint(n_predict_min, n_predict),
+                "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})
 
         if not synthetic_prompts:
             logger.info("Getting the prompt lengths...")
@@ -199,7 +211,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
         t0 = time()
         results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
     finally:
-        if server is not None:
+        if server is not None and server["process"] is not None:
             server["process"].terminate()
             server["process"].wait()
         if session is not None:
@@ -233,15 +245,24 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     logger.info(f"Average generation depth:          {depth_sum / token_t.shape[0]:.2f} tokens")
     logger.info(f"Average total generation speed:    {token_t.shape[0] / token_t_last:.2f} tokens/s")
     logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
-    logger.info("")
-    logger.info(
-        "The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
-        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
+
+    if path_db is not None:
+        con = sqlite3.connect(path_db)
+        cursor = con.cursor()
+        cursor.execute(
+            "CREATE TABLE IF NOT EXISTS server_bench"
+            "(name TEXT, n_parallel INTEGER, prompt_source TEXT, n_prompts INTEGER, "
+            "n_predict INTEGER, n_predict_min INTEGER, seed_offset INTEGER, runtime REAL);")
+        cursor.execute(
+            "INSERT INTO server_bench VALUES (?, ?, ?, ?, ?, ?, ?, ?);",
+            [name, parallel, prompt_source, n_prompts, n_predict, n_predict_min, seed_offset, token_t_last])
+        con.commit()
 
     plt.figure()
     plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
     plt.xlim(0, 1.05e0 * np.max(prompt_n))
     plt.ylim(0, 1.05e3 * np.max(prompt_t))
+    plt.title(name or "")
     plt.xlabel("Prompt length [tokens]")
     plt.ylabel("Time to first token [ms]")
     plt.savefig("prompt_time.png", dpi=240)
@@ -250,6 +271,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     plt.figure()
     plt.hist(token_t, np.arange(0, bin_max))
     plt.xlim(0, bin_max + 1)
+    plt.title(name or "")
     plt.xlabel("Time [s]")
     plt.ylabel("Num. tokens generated per second")
     plt.savefig("gen_rate.png", dpi=240)
@@ -259,9 +281,13 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
         "Results are printed to console and visualized as plots (saved to current working directory). "
-        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
+        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help). "
+        "The reported numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
+        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
     parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
     parser.add_argument("--path_log", type=str, default="server-bench-{port}.log", help="Path to the model to use for the benchmark")
+    parser.add_argument("--path_db", type=str, default=None, help="Path to an sqlite database to store the benchmark results in")
+    parser.add_argument("--name", type=str, default=None, help="Name to label plots and database entries with")
     parser.add_argument(
         "--prompt_source", type=str, default="rng-1024-2048",
         help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "