]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
scripts: synthetic prompt mode for server-bench.py (#14695)
authorJohannes Gäßler <redacted>
Wed, 16 Jul 2025 07:33:28 +0000 (09:33 +0200)
committerGitHub <redacted>
Wed, 16 Jul 2025 07:33:28 +0000 (09:33 +0200)
scripts/server-bench.py [changed mode: 0644->0755]
tools/server/README.md

old mode 100644 (file)
new mode 100755 (executable)
index 52163d6..3afad66
@@ -2,9 +2,11 @@
 
 import argparse
 import json
+import os
+import random
 import subprocess
 from time import sleep, time
-from typing import Optional
+from typing import Optional, Union
 
 import datasets
 import logging
@@ -18,31 +20,39 @@ logging.basicConfig(level=logging.INFO, format='%(message)s')
 logger = logging.getLogger("server-bench")
 
 
-def get_prompts(n_prompts: int) -> list[str]:
-    logger.info("Loading MMLU dataset...")
-    ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
+def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
+    ret = []
+    if dataset_name.lower() == "mmlu":
+        logger.info("Loading MMLU dataset...")
+        ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
+    else:
+        return None
     if n_prompts >= 0:
         ret = ret[:n_prompts]
     return ret
 
 
-def get_server(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int) -> dict:
+def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]:
+    assert n_prompts >= 0
+    ret: list[int] = []
+    for i in range(n_prompts):
+        random.seed(13 * i + 0)
+        ret.append(random.randint(prompt_length_min, prompt_length_max))
+    return ret
+
+
+def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
+    return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths]
+
+
+def get_server(path_server: str, path_log: Optional[str]) -> dict:
     logger.info("Starting the llama.cpp server...")
-    address = f"http://localhost:{port}"
-
-    popen_args: list[str] = [
-        path_server,
-        "--flash-attn",
-        "--n-gpu-layers", str(n_gpu_layers),
-        "--parallel", str(parallel),
-        "--ctx-size", str(parallel * ctx_size),
-        "--model", path_model,
-        "--port", str(port),
-        "--swa-full",  # FIXME performance bad otherwise
-        # "--attn-streams",
-    ]
-    fout = open("bench.log", "w") if path_log is not None else subprocess.DEVNULL
-    process = subprocess.Popen(popen_args, stdout=fout, stderr=subprocess.STDOUT)
+    hostname: str = os.environ.get("LLAMA_ARG_HOST", "127.0.0.1")
+    port: str = os.environ.get("LLAMA_ARG_PORT", "8080")
+    address: str = f"http://{hostname}:{port}"
+
+    fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
+    process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)
 
     n_failures: int = 0
     while True:
@@ -50,14 +60,14 @@ def get_server(path_server: str, path_model: str, path_log: Optional[str], port:
             sleep(1.0)
             exit_code = process.poll()
             if exit_code is not None:
-                raise RuntimeError(f"llama.cpp server for {path_model} exited unexpectedly with exit code {exit_code}")
+                raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {path_log}")
             response = requests.get(f"{address}/health")
             if response.status_code == 200:
                 break
         except requests.ConnectionError:
             n_failures += 1
             if n_failures >= 10:
-                raise RuntimeError(f"llama.cpp server for {path_model} is not healthy after 10 seconds")
+                raise RuntimeError("llama.cpp server is not healthy after 10 seconds")
 
     return {"process": process, "address": address, "fout": fout}
 
@@ -87,58 +97,97 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
     session = data["session"]
     server_address: str = data["server_address"]
 
-    response = session.post(
-        f"{server_address}/apply-template",
-        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
-    )
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
-    prompt: str = json.loads(response.text)["prompt"]
-
-    json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
-    response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+    t_submit = time()
+    if data["synthetic_prompt"]:
+        json_data: dict = {
+            "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
+            "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+    else:
+        response = session.post(
+            f"{server_address}/apply-template",
+            json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
+        )
+        if response.status_code != 200:
+            raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+        prompt: str = json.loads(response.text)["prompt"]
+
+        json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
 
-    last_valid_line: str = ""
     token_arrival_times: list[float] = []
-    for line in response.iter_lines(decode_unicode=True):
-        if not line.startswith("data: "):
+    for line in response.iter_lines(decode_unicode=False):
+        if not line.startswith(b"data: "):
             continue
-        last_valid_line = line
         token_arrival_times.append(time())
     token_arrival_times = token_arrival_times[:-1]
 
     if response.status_code != 200:
         raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
-    timings: dict = json.loads(last_valid_line[6:])["timings"]
 
-    return (timings["prompt_ms"], token_arrival_times)
-
-
-def benchmark(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int, n_prompts: int, n_predict: int):
-    num_workers: int = parallel + 1
-    prompts: list[str] = get_prompts(n_prompts)
+    return (t_submit, token_arrival_times)
+
+
+def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
+    if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
+        logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
+        os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
+    if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
+        logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
+        os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
+    if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
+        logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
+        os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
+
+    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL", 1))
+    prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
+    synthetic_prompts: bool = prompts is None
+    prompt_n = []
+
+    if synthetic_prompts:
+        prompt_source_split: list[str] = prompt_source.split("-")
+        assert len(prompt_source_split) == 3
+        assert prompt_source_split[0].lower() == "rng"
+        prompt_length_min: int = int(prompt_source_split[1])
+        prompt_length_max: int = int(prompt_source_split[2])
+        logger.info("Generating random prompts...")
+        prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max)
+        prompts = get_prompts_rng(prompt_n)
+    else:
+        n_predict_min = n_predict
+
+    if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
+        context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
+        context_total: int = context_per_slot * parallel
+        os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
+        logger.info(f"LLAMA_ARG_CTX_SIZE not explicitly set, using {context_total} ({context_per_slot} per slot).")
 
     server: Optional[dict] = None
     session = None
     try:
-        server = get_server(path_server, path_model, path_log, port, n_gpu_layers, parallel, ctx_size)
+        server = get_server(path_server, path_log)
         server_address: str = server["address"]
 
-        adapter = requests.adapters.HTTPAdapter(pool_connections=num_workers, pool_maxsize=num_workers)  # type: ignore
+        adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
         session = requests.Session()
         session.mount("http://", adapter)
         session.mount("https://", adapter)
 
         data: list[dict] = []
+
         for i, p in enumerate(prompts):
-            data.append({"session": session, "server_address": server_address, "prompt": p, "n_predict": n_predict, "seed": i})
+            random.seed(13 * i + 1)
+            data.append({
+                "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
+                "n_predict": random.randint(n_predict_min, n_predict), "seed": 13 * i + 2})
 
-        logger.info("Getting the prompt lengths...")
-        prompt_n = [get_prompt_length(d) for d in data]
+        if not synthetic_prompts:
+            logger.info("Getting the prompt lengths...")
+            prompt_n = [get_prompt_length(d) for d in data]
 
         logger.info("Starting the benchmark...\n")
         t0 = time()
-        results: list[tuple[int, list[float]]] = thread_map(send_prompt, data, max_workers=num_workers, chunksize=1)
+        results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
     finally:
         if server is not None:
             server["process"].terminate()
@@ -146,17 +195,18 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
         if session is not None:
             session.close()
 
-    prompt_ms = []
+    prompt_t = []
     token_t = []
     depth_sum: int = 0
-    for pn, (pms, tat) in zip(prompt_n, results):
-        prompt_ms.append(pms)
+    for pn, (t_submit, tat) in zip(prompt_n, results):
+        prompt_t.append(tat[0] - t_submit)
         token_t += tat
         n_tokens: int = len(tat)
         depth_sum += n_tokens * pn
         depth_sum += n_tokens * (n_tokens + 1) // 2
+    assert len(token_t) > 0
     prompt_n = np.array(prompt_n, dtype=np.int64)
-    prompt_ms = np.array(prompt_ms, dtype=np.float64)
+    prompt_t = np.array(prompt_t, dtype=np.float64)
     token_t = np.array(token_t, dtype=np.float64)
 
     token_t -= t0
@@ -167,18 +217,21 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
     logger.info(f"Request throughput:                {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
     logger.info(f"Total prompt length:               {np.sum(prompt_n)} tokens")
     logger.info(f"Average prompt length:             {np.mean(prompt_n):.2f} tokens")
-    logger.info(f"Average prompt latency:            {np.mean(prompt_ms):.2f} ms")
-    logger.info(f"Average prompt speed:              {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
+    logger.info(f"Average prompt latency:            {1e3 * np.mean(prompt_t):.2f} ms")
+    logger.info(f"Average prompt speed:              {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s")
     logger.info(f"Total generated tokens:            {token_t.shape[0]}")
     logger.info(f"Average generation depth:          {depth_sum / token_t.shape[0]:.2f} tokens")
     logger.info(f"Average total generation speed:    {token_t.shape[0] / token_t_last:.2f} tokens/s")
     logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
+    logger.info("")
+    logger.info(
+        "The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
+        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
 
     plt.figure()
-    plt.scatter(prompt_n, prompt_ms, s=10.0, marker=".", alpha=0.25)
-    plt.xlim(0, 1.05 * np.max(prompt_n))
-    plt.ylim(0, 1.05 * np.max(prompt_ms))
-    plt.title(path_model)
+    plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
+    plt.xlim(0, 1.05e0 * np.max(prompt_n))
+    plt.ylim(0, 1.05e3 * np.max(prompt_t))
     plt.xlabel("Prompt length [tokens]")
     plt.ylabel("Time to first token [ms]")
     plt.savefig("prompt_time.png", dpi=240)
@@ -187,7 +240,6 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
     plt.figure()
     plt.hist(token_t, np.arange(0, bin_max))
     plt.xlim(0, bin_max + 1)
-    plt.title(path_model)
     plt.xlabel("Time [s]")
     plt.ylabel("Num. tokens generated per second")
     plt.savefig("gen_rate.png", dpi=240)
@@ -196,15 +248,18 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
-        "Results are printed to console and visualized as plots (saved to current working directory).")
+        "Results are printed to console and visualized as plots (saved to current working directory). "
+        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
     parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
-    parser.add_argument("--path_model", type=str, required=True, help="Path to the model to use for the benchmark")
-    parser.add_argument("--path_log", type=str, default=None, help="Path to the model to use for the benchmark")
-    parser.add_argument("--port", type=int, default=18725, help="Port to use for the server during the benchmark")
-    parser.add_argument("--n_gpu_layers", type=int, default=999, help="Number of GPU layers for the server")
-    parser.add_argument("--parallel", type=int, default=16, help="Number of slots for the server")
-    parser.add_argument("--ctx_size", type=int, default=4096, help="Server context size per slot")
-    parser.add_argument("--n_prompts", type=int, default=1000, help="Number of prompts to evaluate")
+    parser.add_argument("--path_log", type=str, default="server-bench.log", help="Path to the model to use for the benchmark")
+    parser.add_argument(
+        "--prompt_source", type=str, default="rng-1024-2048",
+        help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
+        "rng-MIN-MAX for synthetic prompts with random lengths in the interval [MIN, MAX]")
+    parser.add_argument("--n_prompts", type=int, default=100, help="Number of prompts to evaluate")
     parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
+    parser.add_argument(
+        "--n_predict_min", type=int, default=1024,
+        help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
     args = parser.parse_args()
     benchmark(**vars(args))
index 6f962664f67747f5c4386b646d22fa1dc650342c..e29511cb1b457b7873401db5b770e5c9e50ab4e4 100644 (file)
@@ -7,7 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp.
 **Features:**
  * LLM inference of F16 and quantized models on GPU and CPU
  * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes
- * Reranking endoint (https://github.com/ggml-org/llama.cpp/pull/9510)
+ * Reranking endpoint (https://github.com/ggml-org/llama.cpp/pull/9510)
  * Parallel decoding with multi-user support
  * Continuous batching
  * Multimodal ([documentation](../../docs/multimodal.md)) / with OpenAI-compatible API support