]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
scripts: benchmark for HTTP server throughput (#14668)
authorJohannes Gäßler <redacted>
Mon, 14 Jul 2025 11:14:30 +0000 (13:14 +0200)
committerGitHub <redacted>
Mon, 14 Jul 2025 11:14:30 +0000 (13:14 +0200)
* scripts: benchmark for HTTP server throughput

* fix server connection reset

requirements/requirements-all.txt
requirements/requirements-server-bench.txt [new file with mode: 0644]
scripts/server-bench.py [new file with mode: 0644]
tools/server/utils.hpp

index 9fa7d4d0abdec940437f2f693abcbe6148b86091..56b6752ac0645600b3a20aa97b91e85457399511 100644 (file)
@@ -3,6 +3,7 @@
 -r ../tools/server/tests/requirements.txt
 
 -r ./requirements-compare-llama-bench.txt
+-r ./requirements-server-bench.txt
 -r ./requirements-pydantic.txt
 -r ./requirements-test-tokenizer-random.txt
 
diff --git a/requirements/requirements-server-bench.txt b/requirements/requirements-server-bench.txt
new file mode 100644 (file)
index 0000000..ea5849f
--- /dev/null
@@ -0,0 +1,5 @@
+datasets~=3.2.0
+matplotlib~=3.10.0
+numpy~=1.26.4
+requests~=2.32.3
+tqdm~=4.67.1
diff --git a/scripts/server-bench.py b/scripts/server-bench.py
new file mode 100644 (file)
index 0000000..52163d6
--- /dev/null
@@ -0,0 +1,210 @@
+#!/usr/bin/env python3
+
+import argparse
+import json
+import subprocess
+from time import sleep, time
+from typing import Optional
+
+import datasets
+import logging
+import matplotlib.pyplot as plt
+import numpy as np
+import requests
+from tqdm.contrib.concurrent import thread_map
+
+
+logging.basicConfig(level=logging.INFO, format='%(message)s')
+logger = logging.getLogger("server-bench")
+
+
+def get_prompts(n_prompts: int) -> list[str]:
+    logger.info("Loading MMLU dataset...")
+    ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
+    if n_prompts >= 0:
+        ret = ret[:n_prompts]
+    return ret
+
+
+def get_server(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int) -> dict:
+    logger.info("Starting the llama.cpp server...")
+    address = f"http://localhost:{port}"
+
+    popen_args: list[str] = [
+        path_server,
+        "--flash-attn",
+        "--n-gpu-layers", str(n_gpu_layers),
+        "--parallel", str(parallel),
+        "--ctx-size", str(parallel * ctx_size),
+        "--model", path_model,
+        "--port", str(port),
+        "--swa-full",  # FIXME performance bad otherwise
+        # "--attn-streams",
+    ]
+    fout = open("bench.log", "w") if path_log is not None else subprocess.DEVNULL
+    process = subprocess.Popen(popen_args, stdout=fout, stderr=subprocess.STDOUT)
+
+    n_failures: int = 0
+    while True:
+        try:
+            sleep(1.0)
+            exit_code = process.poll()
+            if exit_code is not None:
+                raise RuntimeError(f"llama.cpp server for {path_model} exited unexpectedly with exit code {exit_code}")
+            response = requests.get(f"{address}/health")
+            if response.status_code == 200:
+                break
+        except requests.ConnectionError:
+            n_failures += 1
+            if n_failures >= 10:
+                raise RuntimeError(f"llama.cpp server for {path_model} is not healthy after 10 seconds")
+
+    return {"process": process, "address": address, "fout": fout}
+
+
+def get_prompt_length(data: dict) -> int:
+    session = data["session"]
+    server_address: str = data["server_address"]
+
+    response = session.post(
+        f"{server_address}/apply-template",
+        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
+    )
+    if response.status_code != 200:
+        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+    prompt: str = json.loads(response.text)["prompt"]
+    response = session.post(
+        f"{server_address}/tokenize",
+        json={"content": prompt, "add_special": True}
+    )
+    if response.status_code != 200:
+        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+    tokens: list[str] = json.loads(response.text)["tokens"]
+    return len(tokens)
+
+
+def send_prompt(data: dict) -> tuple[float, list[float]]:
+    session = data["session"]
+    server_address: str = data["server_address"]
+
+    response = session.post(
+        f"{server_address}/apply-template",
+        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
+    )
+    if response.status_code != 200:
+        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+    prompt: str = json.loads(response.text)["prompt"]
+
+    json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
+    response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+
+    last_valid_line: str = ""
+    token_arrival_times: list[float] = []
+    for line in response.iter_lines(decode_unicode=True):
+        if not line.startswith("data: "):
+            continue
+        last_valid_line = line
+        token_arrival_times.append(time())
+    token_arrival_times = token_arrival_times[:-1]
+
+    if response.status_code != 200:
+        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+    timings: dict = json.loads(last_valid_line[6:])["timings"]
+
+    return (timings["prompt_ms"], token_arrival_times)
+
+
+def benchmark(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int, n_prompts: int, n_predict: int):
+    num_workers: int = parallel + 1
+    prompts: list[str] = get_prompts(n_prompts)
+
+    server: Optional[dict] = None
+    session = None
+    try:
+        server = get_server(path_server, path_model, path_log, port, n_gpu_layers, parallel, ctx_size)
+        server_address: str = server["address"]
+
+        adapter = requests.adapters.HTTPAdapter(pool_connections=num_workers, pool_maxsize=num_workers)  # type: ignore
+        session = requests.Session()
+        session.mount("http://", adapter)
+        session.mount("https://", adapter)
+
+        data: list[dict] = []
+        for i, p in enumerate(prompts):
+            data.append({"session": session, "server_address": server_address, "prompt": p, "n_predict": n_predict, "seed": i})
+
+        logger.info("Getting the prompt lengths...")
+        prompt_n = [get_prompt_length(d) for d in data]
+
+        logger.info("Starting the benchmark...\n")
+        t0 = time()
+        results: list[tuple[int, list[float]]] = thread_map(send_prompt, data, max_workers=num_workers, chunksize=1)
+    finally:
+        if server is not None:
+            server["process"].terminate()
+            server["process"].wait()
+        if session is not None:
+            session.close()
+
+    prompt_ms = []
+    token_t = []
+    depth_sum: int = 0
+    for pn, (pms, tat) in zip(prompt_n, results):
+        prompt_ms.append(pms)
+        token_t += tat
+        n_tokens: int = len(tat)
+        depth_sum += n_tokens * pn
+        depth_sum += n_tokens * (n_tokens + 1) // 2
+    prompt_n = np.array(prompt_n, dtype=np.int64)
+    prompt_ms = np.array(prompt_ms, dtype=np.float64)
+    token_t = np.array(token_t, dtype=np.float64)
+
+    token_t -= t0
+    token_t_last = np.max(token_t)
+
+    logger.info("")
+    logger.info(f"Benchmark duration:                {token_t_last:.2f} s")
+    logger.info(f"Request throughput:                {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
+    logger.info(f"Total prompt length:               {np.sum(prompt_n)} tokens")
+    logger.info(f"Average prompt length:             {np.mean(prompt_n):.2f} tokens")
+    logger.info(f"Average prompt latency:            {np.mean(prompt_ms):.2f} ms")
+    logger.info(f"Average prompt speed:              {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
+    logger.info(f"Total generated tokens:            {token_t.shape[0]}")
+    logger.info(f"Average generation depth:          {depth_sum / token_t.shape[0]:.2f} tokens")
+    logger.info(f"Average total generation speed:    {token_t.shape[0] / token_t_last:.2f} tokens/s")
+    logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
+
+    plt.figure()
+    plt.scatter(prompt_n, prompt_ms, s=10.0, marker=".", alpha=0.25)
+    plt.xlim(0, 1.05 * np.max(prompt_n))
+    plt.ylim(0, 1.05 * np.max(prompt_ms))
+    plt.title(path_model)
+    plt.xlabel("Prompt length [tokens]")
+    plt.ylabel("Time to first token [ms]")
+    plt.savefig("prompt_time.png", dpi=240)
+
+    bin_max = np.ceil(token_t_last) + 1
+    plt.figure()
+    plt.hist(token_t, np.arange(0, bin_max))
+    plt.xlim(0, bin_max + 1)
+    plt.title(path_model)
+    plt.xlabel("Time [s]")
+    plt.ylabel("Num. tokens generated per second")
+    plt.savefig("gen_rate.png", dpi=240)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
+        "Results are printed to console and visualized as plots (saved to current working directory).")
+    parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
+    parser.add_argument("--path_model", type=str, required=True, help="Path to the model to use for the benchmark")
+    parser.add_argument("--path_log", type=str, default=None, help="Path to the model to use for the benchmark")
+    parser.add_argument("--port", type=int, default=18725, help="Port to use for the server during the benchmark")
+    parser.add_argument("--n_gpu_layers", type=int, default=999, help="Number of GPU layers for the server")
+    parser.add_argument("--parallel", type=int, default=16, help="Number of slots for the server")
+    parser.add_argument("--ctx_size", type=int, default=4096, help="Server context size per slot")
+    parser.add_argument("--n_prompts", type=int, default=1000, help="Number of prompts to evaluate")
+    parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
+    args = parser.parse_args()
+    benchmark(**vars(args))
index 6c2e91359a66379072ad3be250f5316499c5d72e..f3dfc8225da4dfb8651121b6d813c735c84e7dcf 100644 (file)
@@ -11,6 +11,8 @@
 
 // increase max payload length to allow use of larger context size
 #define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
+// increase backlog size to avoid connection resets for >> 1 slots
+#define CPPHTTPLIB_LISTEN_BACKLOG 512
 // disable Nagle's algorithm
 #define CPPHTTPLIB_TCP_NODELAY true
 #include <cpp-httplib/httplib.h>