]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server: benchmark: chat/completions scenario and other llm servers comparison (#5941)
authorPierrick Hymbert <redacted>
Sat, 9 Mar 2024 22:41:49 +0000 (23:41 +0100)
committerGitHub <redacted>
Sat, 9 Mar 2024 22:41:49 +0000 (23:41 +0100)
* server: bench: Init a bench scenario with K6
See #5827

* server: bench: EOL EOF

* server: bench: PR feedback and improved k6 script configuration

* server: bench: remove llamacpp_completions_tokens_seconds as it include prompt processing time and it's misleading

server: bench: add max_tokens from SERVER_BENCH_MAX_TOKENS

server: bench: increase truncated rate to 80% before failing

* server: bench: fix doc

* server: bench: change gauge custom metrics to trend

* server: bench: change gauge custom metrics to trend
server: bench: add trend custom metrics for total tokens per second average

* server: bench: doc add an option to debug http request

* server: bench: filter dataset too short and too long sequences

* server: bench: allow to filter out conversation in the dataset based on env variable

* server: bench: fix assistant message sent instead of user message

* server: bench: fix assistant message sent instead of user message

* server : add defrag thold parameter

* server: bench: select prompts based on the current iteration id not randomly to make the bench more reproducible

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/server/bench/README.md [new file with mode: 0644]
examples/server/bench/script.js [new file with mode: 0644]
examples/server/server.cpp

diff --git a/examples/server/bench/README.md b/examples/server/bench/README.md
new file mode 100644 (file)
index 0000000..a53ad64
--- /dev/null
@@ -0,0 +1,88 @@
+### Server benchmark tools
+
+Benchmark is using [k6](https://k6.io/).
+
+##### Install k6
+
+Follow instruction from: https://k6.io/docs/get-started/installation/
+
+Example for ubuntu:
+```shell
+snap install k6
+```
+
+#### Download a dataset
+
+This dataset was originally proposed in [vLLM benchmarks](https://github.com/vllm-project/vllm/blob/main/benchmarks/README.md).
+
+```shell
+wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
+```
+
+#### Download a model
+Example for PHI-2
+
+```shell
+../../../scripts/hf.sh --repo ggml-org/models --file phi-2/ggml-model-q4_0.gguf
+```
+
+#### Start the server
+The server must answer OAI Chat completion requests on `http://localhost:8080/v1` or according to the environment variable `SERVER_BENCH_URL`.
+
+Example:
+```shell
+server --host localhost --port 8080 \
+  --model ggml-model-q4_0.gguf \
+  --cont-batching \
+  --metrics \
+  --parallel 8 \
+  --batch-size 512 \
+  --ctx-size 4096 \
+  --log-format text \
+  -ngl 33
+```
+
+#### Run the benchmark
+
+For 500 chat completions request with 8 concurrent users during maximum 10 minutes, run:
+```shell
+k6 run script.js --duration 10m --iterations 500 --vus 8
+```
+
+The benchmark values can be overridden with:
+- `SERVER_BENCH_URL` server url prefix for chat completions, default `http://localhost:8080/v1`
+- `SERVER_BENCH_N_PROMPTS` total prompts to randomly select in the benchmark, default `480`
+- `SERVER_BENCH_MODEL_ALIAS` model alias to pass in the completion request, default `my-model`
+- `SERVER_BENCH_MAX_TOKENS` max tokens to predict, default: `512`
+- `SERVER_BENCH_DATASET` path to the benchmark dataset file
+- `SERVER_BENCH_MAX_PROMPT_TOKENS` maximum prompt tokens to filter out in the dataset: default `1024`
+- `SERVER_BENCH_MAX_CONTEXT` maximum context size of the completions request to filter out in the dataset: prompt + predicted tokens, default `2048`
+
+Note: the local tokenizer is just a string space split, real number of tokens will differ.
+
+Or with [k6 options](https://k6.io/docs/using-k6/k6-options/reference/):
+
+```shell
+SERVER_BENCH_N_PROMPTS=500 k6 run script.js --duration 10m --iterations 500 --vus 8
+```
+
+To [debug http request](https://k6.io/docs/using-k6/http-debugging/) use `--http-debug="full"`.
+
+#### Metrics
+
+Following metrics are available computed from the OAI chat completions response `usage`:
+- `llamacpp_tokens_second` Trend of `usage.total_tokens / request duration`
+- `llamacpp_prompt_tokens` Trend of `usage.prompt_tokens`
+- `llamacpp_prompt_tokens_total_counter` Counter of `usage.prompt_tokens`
+- `llamacpp_completion_tokens` Trend of `usage.completion_tokens`
+- `llamacpp_completion_tokens_total_counter` Counter of `usage.completion_tokens`
+- `llamacpp_completions_truncated_rate` Rate of completions truncated, i.e. if `finish_reason === 'length'`
+- `llamacpp_completions_stop_rate` Rate of completions stopped by the model, i.e. if `finish_reason === 'stop'`
+
+The script will fail if too many completions are truncated, see `llamacpp_completions_truncated_rate`.
+
+K6 metrics might be compared against [server metrics](../README.md), with:
+
+```shell
+curl http://localhost:8080/metrics
+```
diff --git a/examples/server/bench/script.js b/examples/server/bench/script.js
new file mode 100644 (file)
index 0000000..a4f5ac5
--- /dev/null
@@ -0,0 +1,120 @@
+import http from 'k6/http'
+import {check, sleep} from 'k6'
+import {SharedArray} from 'k6/data'
+import {Counter, Rate, Trend} from 'k6/metrics'
+import exec from 'k6/execution';
+
+// Server chat completions prefix
+const server_url = __ENV.SERVER_BENCH_URL ? __ENV.SERVER_BENCH_URL : 'http://localhost:8080/v1'
+
+// Number of total prompts in the dataset - default 10m / 10 seconds/request * number of users
+const n_prompt = __ENV.SERVER_BENCH_N_PROMPTS ? parseInt(__ENV.SERVER_BENCH_N_PROMPTS) : 600 / 10 * 8
+
+// Model name to request
+const model = __ENV.SERVER_BENCH_MODEL_ALIAS ? __ENV.SERVER_BENCH_MODEL_ALIAS : 'my-model'
+
+// Dataset path
+const dataset_path = __ENV.SERVER_BENCH_DATASET ? __ENV.SERVER_BENCH_DATASET : './ShareGPT_V3_unfiltered_cleaned_split.json'
+
+// Max tokens to predict
+const max_tokens = __ENV.SERVER_BENCH_MAX_TOKENS ? parseInt(__ENV.SERVER_BENCH_MAX_TOKENS) : 512
+
+// Max prompt tokens
+const n_prompt_tokens = __ENV.SERVER_BENCH_MAX_PROMPT_TOKENS ? parseInt(__ENV.SERVER_BENCH_MAX_PROMPT_TOKENS) : 1024
+
+// Max slot context
+const n_ctx_slot = __ENV.SERVER_BENCH_MAX_CONTEXT ? parseInt(__ENV.SERVER_BENCH_MAX_CONTEXT) : 2048
+
+export function setup() {
+    console.info(`Benchmark config: server_url=${server_url} n_prompt=${n_prompt} model=${model} dataset_path=${dataset_path} max_tokens=${max_tokens}`)
+}
+
+const data = new SharedArray('conversations', function () {
+    const tokenizer = (message) => message.split(/[\s,'".?]/)
+
+    return JSON.parse(open(dataset_path))
+        // Filter out the conversations with less than 2 turns.
+        .filter(data => data["conversations"].length >= 2)
+        .filter(data => data["conversations"][0]["from"] === "human")
+        .map(data => {
+            return {
+                prompt: data["conversations"][0]["value"],
+                n_prompt_tokens: tokenizer(data["conversations"][0]["value"]).length,
+                n_completion_tokens: tokenizer(data["conversations"][1]["value"]).length,
+            }
+        })
+        // Filter out too short sequences
+        .filter(conv => conv.n_prompt_tokens >= 4 && conv.n_completion_tokens >= 4)
+        // Filter out too long sequences.
+        .filter(conv => conv.n_prompt_tokens <= n_prompt_tokens && conv.n_prompt_tokens + conv.n_completion_tokens <= n_ctx_slot)
+        // Keep only first n prompts
+        .slice(0, n_prompt)
+})
+
+const llamacpp_prompt_tokens = new Trend('llamacpp_prompt_tokens')
+const llamacpp_completion_tokens = new Trend('llamacpp_completion_tokens')
+const llamacpp_tokens_second = new Trend('llamacpp_tokens_second')
+
+const llamacpp_prompt_tokens_total_counter = new Counter('llamacpp_prompt_tokens_total_counter')
+const llamacpp_completion_tokens_total_counter = new Counter('llamacpp_completion_tokens_total_counter')
+
+const llamacpp_completions_truncated_rate = new Rate('llamacpp_completions_truncated_rate')
+const llamacpp_completions_stop_rate = new Rate('llamacpp_completions_stop_rate')
+
+export const options = {
+    thresholds: {
+        llamacpp_completions_truncated_rate: [
+            // more than 80% of truncated input will abort the test
+            {threshold: 'rate < 0.8', abortOnFail: true, delayAbortEval: '1m'},
+        ],
+    },
+    duration: '10m',
+    vus: 8,
+}
+
+export default function () {
+    const conversation = data[exec.scenario.iterationInInstance % data.length]
+    const payload = {
+        "messages": [
+            {
+                "role": "system",
+                "content": "You are ChatGPT, an AI assistant.",
+            },
+            {
+                "role": "user",
+                "content": conversation.prompt,
+            }
+        ],
+        "model": model,
+        "stream": false,
+        "max_tokens": max_tokens
+    }
+
+    const body = JSON.stringify(payload)
+
+    let res = http.post(`${server_url}/chat/completions`, body, {
+        headers: {'Content-Type': 'application/json'},
+        timeout: '300s'
+    })
+
+    check(res, {'success completion': (r) => r.status === 200})
+
+    if (res.status === 200) {
+        const completions = res.json()
+
+        llamacpp_prompt_tokens.add(completions.usage.prompt_tokens)
+        llamacpp_prompt_tokens_total_counter.add(completions.usage.prompt_tokens)
+
+        llamacpp_completion_tokens.add(completions.usage.completion_tokens)
+        llamacpp_completion_tokens_total_counter.add(completions.usage.completion_tokens)
+
+        llamacpp_completions_truncated_rate.add(completions.choices[0].finish_reason === 'length')
+        llamacpp_completions_stop_rate.add(completions.choices[0].finish_reason === 'stop')
+
+        llamacpp_tokens_second.add(completions.usage.total_tokens / res.timings.duration * 1.e3)
+    } else {
+        console.error(`response: ${res.body} request=${payload}`)
+    }
+
+    sleep(0.3)
+}
index b14cca61b153062c1a51ef24336c960d90217526..c7d3ed01b63470e86267e3a465b1b3d528d6460e 100644 (file)
@@ -2133,6 +2133,8 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
     printf("  --yarn-beta-slow N        YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
     printf("  --yarn-beta-fast N        YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
     printf("  --pooling {none,mean,cls} pooling type for embeddings, use model default if unspecified\n");
+    printf("  -dt N, --defrag-thold N\n");
+    printf("                            KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
     printf("  -b N, --batch-size N      batch size for prompt processing (default: %d)\n", params.n_batch);
     printf("  --memory-f32              use f32 instead of f16 for memory key+value (default: disabled)\n");
     printf("                            not recommended: doubles context memory required and no measurable increase in quality\n");
@@ -2355,6 +2357,12 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
             else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
             else if (value == "cls")  { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
             else { invalid_param = true; break; }
+        } else if (arg == "--defrag-thold" || arg == "-dt") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.defrag_thold = std::stof(argv[i]);
         } else if (arg == "--threads" || arg == "-t") {
             if (++i >= argc)
             {