]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server: tests: passkey challenge / self-extend with context shift demo (#5832)
authorPierrick Hymbert <redacted>
Sat, 2 Mar 2024 21:00:14 +0000 (22:00 +0100)
committerGitHub <redacted>
Sat, 2 Mar 2024 21:00:14 +0000 (22:00 +0100)
* server: tests: add models endpoint scenario

* server: /v1/models add some metadata

* server: tests: add debug field in context before scenario

* server: tests: download model from HF, add batch size

* server: tests: add passkey test

* server: tests: add group attention params

* server: do not truncate prompt tokens if self-extend through group attention is enabled

* server: logs: do not truncate log values

* server: tests - passkey - first good working value of nga

* server: tests: fix server timeout

* server: tests: fix passkey, add doc, fix regex content matching, fix timeout

* server: tests: fix regex content matching

* server: tests: schedule slow tests on master

* server: metrics: fix when no prompt processed

* server: tests: self-extend add llama-2-7B and Mixtral-8x7B-v0.1

* server: tests: increase timeout for completion

* server: tests: keep only the PHI-2 test

* server: tests: passkey add a negative test

14 files changed:
.github/workflows/server.yml
examples/server/server.cpp
examples/server/tests/README.md
examples/server/tests/features/environment.py
examples/server/tests/features/issues.feature
examples/server/tests/features/parallel.feature
examples/server/tests/features/passkey.feature [new file with mode: 0644]
examples/server/tests/features/security.feature
examples/server/tests/features/server.feature
examples/server/tests/features/steps/steps.py
examples/server/tests/features/wrong_usages.feature
examples/server/tests/requirements.txt
examples/server/tests/tests.sh
examples/server/utils.hpp

index 0b6f6669b23c703e1a5ad5767b6a326214e1787e..8c63125087d625fc85cb055cbab06101e2322da5 100644 (file)
@@ -10,6 +10,8 @@ on:
   pull_request:
     types: [opened, synchronize, reopened]
     paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/tests/**.*']
+  schedule:
+    -  cron: '00 0 * * *'
 
 jobs:
   server:
@@ -70,14 +72,15 @@ jobs:
         run: |
           pip install -r examples/server/tests/requirements.txt
 
-      - name: Download models
-        id: download_models
+      - name: Tests
+        id: server_integration_tests
         run: |
           cd examples/server/tests
-          ../../../scripts/hf.sh --repo ggml-org/models --file tinyllamas/stories260K.gguf
+          PORT=8888 ./tests.sh
 
-      - name: Tests
-        id: server_integration_test
+      - name: Slow tests
+        id: server_integration_tests_slow
+        if: github.event.schedule != ''
         run: |
           cd examples/server/tests
-          PORT=8888 ./tests.sh
+          PORT=8888 ./tests.sh --stop --no-skipped --no-capture --tags slow
index 2b2f4a0f4a48b757e534e2b5f5a9d05bd41148f6..52daf9e7a3db4a2fbec51640f1fa5b32915a3202 100644 (file)
@@ -441,8 +441,8 @@ struct llama_server_context
             const int ga_w = params.grp_attn_w;
 
             if (ga_n != 1) {
-                GGML_ASSERT(ga_n > 0                    && "ga_n must be positive");                     // NOLINT
-                GGML_ASSERT(ga_w % ga_n == 0            && "ga_w must be a multiple of ga_n");     // NOLINT
+                GGML_ASSERT(ga_n > 0                    && "ga_n must be positive");                       // NOLINT
+                GGML_ASSERT(ga_w % ga_n == 0            && "ga_w must be a multiple of ga_n");             // NOLINT
                 //GGML_ASSERT(n_ctx_train % ga_w == 0     && "n_ctx_train must be a multiple of ga_w");    // NOLINT
                 //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
 
@@ -1709,8 +1709,8 @@ struct llama_server_context
                     }
                     slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
 
-                    // if input prompt is too big, truncate it
-                    if (slot.n_prompt_tokens >= slot.n_ctx)
+                    // if input prompt is too big, truncate it, if group attention self-extend is disabled
+                    if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx)
                     {
                         const int n_left = slot.n_ctx - slot.params.n_keep;
                         const int n_block_size = n_left / 2;
@@ -1785,9 +1785,11 @@ struct llama_server_context
                         }
 
                         LOG_INFO("slot progression", {
-                            { "slot_id", slot.id },
-                            { "task_id", slot.task_id },
-                            { "n_past",  slot.n_past },
+                            { "slot_id",    slot.id },
+                            { "task_id",    slot.task_id },
+                            { "n_past",     slot.n_past },
+                            { "n_past_se",  slot.n_past_se },
+                            { "ga_i",       slot.ga_i },
                             { "n_prompt_tokens_processed", slot.n_prompt_tokens_processed }
                         });
                     }
@@ -2001,6 +2003,17 @@ struct llama_server_context
         LOG_VERBOSE("slots updated", {});
         return true;
     }
+
+    json model_meta() {
+        return json{
+                {"vocab_type", llama_vocab_type(model)},
+                {"n_vocab", llama_n_vocab(model)},
+                {"n_ctx_train", llama_n_ctx_train(model)},
+                {"n_embd", llama_n_embd(model)},
+                {"n_params", llama_model_n_params(model)},
+                {"size", llama_model_size(model)},
+        };
+    }
 };
 
 static void server_print_usage(const char *argv0, const gpt_params &params,
@@ -2911,9 +2924,10 @@ int main(int argc, char **argv)
                 for (const auto& metric_def : metrics_def) {
                     std::string name = metric_def["name"];
                     std::string help = metric_def["help"];
-                    prometheus << "# HELP llamacpp:" << name << " " << help                << "\n"
-                               << "# TYPE llamacpp:" << name << " " << type                << "\n"
-                               << "llamacpp:"        << name << " " << metric_def["value"] << "\n";
+                    auto value = json_value(metric_def, "value", 0);
+                    prometheus << "# HELP llamacpp:" << name << " " << help  << "\n"
+                               << "# TYPE llamacpp:" << name << " " << type  << "\n"
+                               << "llamacpp:"        << name << " " << value << "\n";
                 }
             }
 
@@ -2994,6 +3008,7 @@ int main(int argc, char **argv)
         state.store(SERVER_STATE_READY);
         LOG_INFO("model loaded", {});
     }
+    const auto model_meta = llama.model_meta();
 
     if (sparams.chat_template.empty()) { // custom chat template is not supplied
         // check if the template comes with the model is supported by us
@@ -3143,7 +3158,7 @@ int main(int argc, char **argv)
                 }
             });
 
-    svr.Get("/v1/models", [&params](const httplib::Request& req, httplib::Response& res)
+    svr.Get("/v1/models", [&params, &model_meta](const httplib::Request& req, httplib::Response& res)
             {
                 res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
                 std::time_t t = std::time(0);
@@ -3152,10 +3167,11 @@ int main(int argc, char **argv)
                     {"object", "list"},
                     {"data", {
                         {
-                            {"id", params.model_alias},
-                            {"object", "model"},
-                            {"created", t},
-                            {"owned_by", "llamacpp"}
+                            {"id",       params.model_alias},
+                            {"object",   "model"},
+                            {"created",  t},
+                            {"owned_by", "llamacpp"},
+                            {"meta",     model_meta}
                         },
                     }}
                 };
index 0b9fdc4e726786f849181aae4932df29257c5c94..95a0353b6a9c5ba1a188f82373a13521b717d39b 100644 (file)
@@ -1,22 +1,30 @@
 # Server tests
 
-Python based server tests scenario using [BDD](https://en.wikipedia.org/wiki/Behavior-driven_development) and [behave](https://behave.readthedocs.io/en/latest/):
- * [issues.feature](./features/issues.feature) Pending issues scenario
- * [parallel.feature](./features/parallel.feature) Scenario involving multi slots and concurrent requests
- * [security.feature](./features/security.feature) Security, CORS and API Key
- * [server.feature](./features/server.feature) Server base scenario: completion, embedding, tokenization, etc...
+Python based server tests scenario using [BDD](https://en.wikipedia.org/wiki/Behavior-driven_development)
+and [behave](https://behave.readthedocs.io/en/latest/):
+
+* [issues.feature](./features/issues.feature) Pending issues scenario
+* [parallel.feature](./features/parallel.feature) Scenario involving multi slots and concurrent requests
+* [security.feature](./features/security.feature) Security, CORS and API Key
+* [server.feature](./features/server.feature) Server base scenario: completion, embedding, tokenization, etc...
 
 Tests target GitHub workflows job runners with 4 vCPU.
 
-Requests are using [aiohttp](https://docs.aiohttp.org/en/stable/client_reference.html), [asyncio](https://docs.python.org/fr/3/library/asyncio.html) based http client.
+Requests are
+using [aiohttp](https://docs.aiohttp.org/en/stable/client_reference.html), [asyncio](https://docs.python.org/fr/3/library/asyncio.html)
+based http client.
 
-Note: If the host architecture inference speed is faster than GitHub runners one, parallel scenario may randomly fail. To mitigate it, you can increase values in `n_predict`, `kv_size`.
+Note: If the host architecture inference speed is faster than GitHub runners one, parallel scenario may randomly fail.
+To mitigate it, you can increase values in `n_predict`, `kv_size`.
 
 ### Install dependencies
+
 `pip install -r requirements.txt`
 
 ### Run tests
+
 1. Build the server
+
 ```shell
 cd ../../..
 mkdir build
@@ -24,24 +32,36 @@ cd build
 cmake ../
 cmake --build . --target server
 ```
-2. download required models:
-   1. `../../../scripts/hf.sh --repo ggml-org/models --file tinyllamas/stories260K.gguf`
-3. Start the test: `./tests.sh`
+
+2. Start the test: `./tests.sh`
 
 It's possible to override some scenario steps values with environment variables:
- - `PORT` -> `context.server_port` to set the listening port of the server during scenario, default: `8080`
- - `LLAMA_SERVER_BIN_PATH` -> to change the server binary path, default: `../../../build/bin/server`
- - `DEBUG` -> "ON" to enable steps and server verbose mode `--verbose`
- - `SERVER_LOG_FORMAT_JSON` -> if set switch server logs to json format
+
+| variable                 | description                                                                                    |
+|--------------------------|------------------------------------------------------------------------------------------------|
+| `PORT`                   | `context.server_port` to set the listening port of the server during scenario, default: `8080` |
+| `LLAMA_SERVER_BIN_PATH`  | to change the server binary path, default: `../../../build/bin/server`                         |
+| `DEBUG`                  | "ON" to enable steps and server verbose mode `--verbose`                                       |
+| `SERVER_LOG_FORMAT_JSON` | if set switch server logs to json format                                                       |
+| `N_GPU_LAYERS`           | number of model layers to offload to VRAM `-ngl --n-gpu-layers`                                |
 
 ### Run @bug, @wip or @wrong_usage annotated scenario
 
 Feature or Scenario must be annotated with `@llama.cpp` to be included in the default scope.
+
 - `@bug` annotation aims to link a scenario with a GitHub issue.
 - `@wrong_usage` are meant to show user issue that are actually an expected behavior
 - `@wip` to focus on a scenario working in progress
+- `@slow` heavy test, disabled by default
 
 To run a scenario annotated with `@bug`, start:
-`DEBUG=ON ./tests.sh --no-skipped --tags bug`
+
+```shell
+DEBUG=ON ./tests.sh --no-skipped --tags bug
+```
 
 After changing logic in `steps.py`, ensure that `@bug` and `@wrong_usage` scenario are updated.
+
+```shell
+./tests.sh --no-skipped --tags bug,wrong_usage || echo "should failed but compile"
+```
index 09e8267476135ac8ba680e19d79c6de4c93f24ac..9fd330db6ddc90055163123e02b53c7ef2a0a5b8 100644 (file)
@@ -7,7 +7,10 @@ from signal import SIGKILL
 
 
 def before_scenario(context, scenario):
-    print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m")
+    context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
+    if context.debug:
+        print("DEBUG=ON\n")
+    print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m\n")
     port = 8080
     if 'PORT' in os.environ:
         port = int(os.environ['PORT'])
index bf5a175a357ca6017f21f11b12b9a45387442e1e..7b13e44cad3958898f27f65395818d77320b0888 100644 (file)
@@ -1,4 +1,5 @@
 # List of ongoing issues
+# run with: DEBUG=ON ./tests.sh --no-skipped --tags bug
 @bug
 Feature: Issues
   # No confirmed issue at the moment
index 5f895cf90b9668995edee6da78b51c24ade99927..86cdf72829f8c6818890313b88a1238673a10ead 100644 (file)
@@ -1,11 +1,12 @@
 @llama.cpp
+@parallel
 Feature: Parallel
 
   Background: Server startup
     Given a server listening on localhost:8080
-    And   a model file stories260K.gguf
-    And   a model alias tinyllama-2
+    And   a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     And   42 as server seed
+    And   512 as batch size
     And   64 KV cache size
     And   2 slots
     And   embeddings extraction
diff --git a/examples/server/tests/features/passkey.feature b/examples/server/tests/features/passkey.feature
new file mode 100644 (file)
index 0000000..1bde7aa
--- /dev/null
@@ -0,0 +1,55 @@
+# run with: ./tests.sh --no-skipped --tags passkey
+@passkey
+@slow
+Feature: Passkey / Self-extend with context shift
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+
+  # Generates a long text of junk and inserts a secret passkey number inside it.
+  # Then we query the LLM for the secret passkey.
+  # see #3856 and #4810
+  Scenario Outline: Passkey
+    Given a model file <hf_file> from HF repo <hf_repo>
+    And   <n_batch> as batch size
+    And   <n_junk> as number of junk
+    And   <n_predicted> server max tokens to predict
+    And   42 as seed
+    And   <n_ctx> KV cache size
+    And   1 slots
+    And   <n_ga> group attention factor to extend context size through self-extend
+    And   <n_ga_w> group attention width to extend context size through self-extend
+    # Can be override with N_GPU_LAYERS
+    And   <ngl> GPU offloaded layers
+    Then  the server is starting
+    Then  the server is healthy
+    Given available models
+    Then  model 0 is trained on <n_ctx_train> tokens context
+    Given a prefix prompt:
+    """
+    here is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.
+    """
+    And a passkey prompt template:
+    """
+    The pass key is <passkey> Remember it. <passkey> is the pass key.
+    """
+    And a junk suffix prompt:
+    """
+    The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
+    """
+    And a suffix prompt:
+    """
+    What is the pass key? The pass key is
+    """
+    Given a "<passkey>" passkey challenge prompt with the passkey inserted every <i_pos> junk
+    And  a completion request with no api error
+    Then <n_predicted> tokens are predicted matching <re_content>
+
+    Examples:
+      | hf_repo                         | hf_file                     | n_ctx_train | ngl | n_ctx | n_batch | n_ga | n_ga_w | n_junk | i_pos | passkey | n_predicted | re_content     |
+      | TheBloke/phi-2-GGUF             | phi-2.Q4_K_M.gguf           | 2048        | 5   | 8192  | 512     | 4    | 512    | 250    | 50    | 42      | 1           | 42             |
+      | TheBloke/phi-2-GGUF             | phi-2.Q4_K_M.gguf           | 2048        | 5   | 8192  | 512     | 2    | 512    | 250    | 50    | 42      | 1           | \b((?!42)\w)+\b  |
+      #| TheBloke/Llama-2-7B-GGUF        | llama-2-7b.Q2_K.gguf        | 4096        | 3   | 16384 | 512     | 4    | 512    | 500    | 300   | 1234    | 5           | 1234           |
+      #| TheBloke/Mixtral-8x7B-v0.1-GGUF | mixtral-8x7b-v0.1.Q2_K.gguf | 32768       | 2   | 16384 | 512     | 4    | 512    | 500    | 100   | 0987    | 5           | 0
+      # 987           |
+
index db06d39775c053677a606586ee9f55eac478cf74..42a6709a53380e1c54b95601ba3a9f7daea31d1d 100644 (file)
@@ -1,9 +1,10 @@
 @llama.cpp
+@security
 Feature: Security
 
   Background: Server startup with an api key defined
     Given a server listening on localhost:8080
-    And   a model file stories260K.gguf
+    And   a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     And   a server api key llama.cpp
     Then  the server is starting
     Then  the server is healthy
index b571582a7857e392451ff1f559eb8eaa5fcad9e4..7c977bccecaadf712c6ac5b216a909867be00e55 100644 (file)
@@ -1,15 +1,17 @@
 @llama.cpp
+@server
 Feature: llama.cpp server
 
   Background: Server startup
     Given a server listening on localhost:8080
-    And   a model file stories260K.gguf
+    And   a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     And   a model alias tinyllama-2
     And   42 as server seed
       # KV Cache corresponds to the total amount of tokens
       # that can be stored across all independent sequences: #4130
       # see --ctx-size and #5568
     And   32 KV cache size
+    And   512 as batch size
     And   1 slots
     And   embeddings extraction
     And   32 server max tokens to predict
@@ -29,9 +31,9 @@ Feature: llama.cpp server
     And   prometheus metrics are exposed
 
     Examples: Prompts
-      | prompt                           | n_predict | re_content                             | n_predicted |
-      | I believe the meaning of life is | 8         | (read<or>going)+                       | 8           |
-      | Write a joke about AI            | 64        | (park<or>friends<or>scared<or>always)+ | 32          |
+      | prompt                           | n_predict | re_content                       | n_predicted |
+      | I believe the meaning of life is | 8         | (read\|going)+                   | 8           |
+      | Write a joke about AI            | 64        | (park\|friends\|scared\|always)+ | 32          |
 
   Scenario Outline: OAI Compatibility
     Given a model <model>
@@ -43,9 +45,9 @@ Feature: llama.cpp server
     Then  <n_predicted> tokens are predicted matching <re_content>
 
     Examples: Prompts
-      | model        | system_prompt               | user_prompt                          | max_tokens | re_content                 | n_predicted | enable_streaming |
-      | llama-2      | Book                        | What is the best book                | 8          | (Mom<or>what)+             | 8           | disabled         |
-      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64         | (thanks<or>happy<or>bird)+ | 32          | enabled          |
+      | model        | system_prompt               | user_prompt                          | max_tokens | re_content             | n_predicted | enable_streaming |
+      | llama-2      | Book                        | What is the best book                | 8          | (Mom\|what)+           | 8           | disabled         |
+      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64         | (thanks\|happy\|bird)+ | 32          | enabled          |
 
   Scenario: Embedding
     When embeddings are computed for:
@@ -75,10 +77,15 @@ Feature: llama.cpp server
     When an OAI compatible embeddings computation request for multiple inputs
     Then embeddings are generated
 
-
   Scenario: Tokenize / Detokenize
     When tokenizing:
     """
     What is the capital of France ?
     """
     Then tokens can be detokenize
+
+  Scenario: Models available
+    Given available models
+    Then  1 models are supported
+    Then  model 0 is identified by tinyllama-2
+    Then  model 0 is trained on 128 tokens context
index 381da105e279e5efd1712b4c321fe5eb49ee72df..3195278022ffb625d3c252cd9928b84325514d6b 100644 (file)
@@ -13,6 +13,7 @@ import aiohttp
 import openai
 from behave import step
 from behave.api.async_step import async_run_until_complete
+from huggingface_hub import hf_hub_download
 from prometheus_client import parser
 
 
@@ -26,17 +27,23 @@ def step_server_config(context, server_fqdn, server_port):
 
     context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
 
-    context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
     context.model_alias = None
+    context.n_batch = None
     context.n_ctx = None
+    context.n_ga = None
+    context.n_ga_w = None
+    context.n_gpu_layer = None
     context.n_predict = None
     context.n_server_predict = None
     context.n_slots = None
+    context.prompt_prefix = None
+    context.prompt_suffix = None
     context.server_api_key = None
     context.server_continuous_batching = False
     context.server_embeddings = False
     context.server_metrics = False
     context.server_process = None
+    context.seed = None
     context.server_seed = None
     context.user_api_key = None
 
@@ -45,9 +52,11 @@ def step_server_config(context, server_fqdn, server_port):
     context.prompts = []
 
 
-@step(u'a model file {model_file}')
-def step_model_file(context, model_file):
-    context.model_file = model_file
+@step(u'a model file {hf_file} from HF repo {hf_repo}')
+def step_download_hf_model(context, hf_file, hf_repo):
+    context.model_file = hf_hub_download(repo_id=hf_repo, filename=hf_file)
+    if context.debug:
+        print(f"model file: {context.model_file}\n")
 
 
 @step(u'a model alias {model_alias}')
@@ -55,24 +64,34 @@ def step_model_alias(context, model_alias):
     context.model_alias = model_alias
 
 
-@step(u'{seed} as server seed')
+@step(u'{seed:d} as server seed')
 def step_seed(context, seed):
-    context.server_seed = int(seed)
+    context.server_seed = seed
+
+
+@step(u'{ngl:d} GPU offloaded layers')
+def step_n_gpu_layer(context, ngl):
+    if 'N_GPU_LAYERS' in os.environ:
+        new_ngl = int(os.environ['N_GPU_LAYERS'])
+        if context.debug:
+            print(f"-ngl upgraded from {ngl} to {new_ngl}")
+        ngl = new_ngl
+    context.n_gpu_layer = ngl
 
 
-@step(u'{n_ctx} KV cache size')
+@step(u'{n_ctx:d} KV cache size')
 def step_n_ctx(context, n_ctx):
-    context.n_ctx = int(n_ctx)
+    context.n_ctx = n_ctx
 
 
-@step(u'{n_slots} slots')
+@step(u'{n_slots:d} slots')
 def step_n_slots(context, n_slots):
-    context.n_slots = int(n_slots)
+    context.n_slots = n_slots
 
 
-@step(u'{n_predict} server max tokens to predict')
+@step(u'{n_predict:d} server max tokens to predict')
 def step_server_n_predict(context, n_predict):
-    context.n_server_predict = int(n_predict)
+    context.n_server_predict = n_predict
 
 
 @step(u'continuous batching')
@@ -116,11 +135,13 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status):
 
         case 'ready' | 'idle':
             await wait_for_health_status(context, context.base_url, 200, 'ok',
+                                         timeout=10,
                                          params={'fail_on_no_slot': 0, 'include_slots': 0},
                                          slots_idle=context.n_slots,
                                          slots_processing=0,
                                          expected_slots=[{'id': slot_id, 'state': 0}
-                                                         for slot_id in range(context.n_slots)])
+                                                         for slot_id in
+                                                         range(context.n_slots if context.n_slots else 1)])
         case 'busy':
             await wait_for_health_status(context, context.base_url, 503,
                                          'no slot available',
@@ -128,7 +149,8 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status):
                                          slots_idle=0,
                                          slots_processing=context.n_slots,
                                          expected_slots=[{'id': slot_id, 'state': 1}
-                                                         for slot_id in range(context.n_slots)])
+                                                         for slot_id in
+                                                         range(context.n_slots if context.n_slots else 1)])
         case _:
             assert False, "unknown status"
 
@@ -157,24 +179,24 @@ async def step_request_completion(context, api_error):
                                           context.base_url,
                                           debug=context.debug,
                                           n_predict=context.n_predict,
-                                          server_seed=context.server_seed,
+                                          seed=await completions_seed(context),
                                           expect_api_error=expect_api_error,
                                           user_api_key=context.user_api_key)
     context.tasks_result.append(completion)
     if context.debug:
-        print(f"Completion response: {completion}")
+        print(f"Completion response: {completion}\n")
     if expect_api_error:
         assert completion == 401, f"completion must be an 401 status code: {completion}"
 
 
-@step(u'{predicted_n} tokens are predicted matching {re_content}')
+@step(u'{predicted_n:d} tokens are predicted matching {re_content}')
 def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
-    assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n), re_content)
+    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n, re_content)
 
 
-@step(u'{predicted_n} tokens are predicted')
+@step(u'{predicted_n:d} tokens are predicted')
 def step_n_tokens_predicted(context, predicted_n):
-    assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n))
+    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n)
 
 
 @step(u'a user prompt {user_prompt}')
@@ -192,9 +214,9 @@ def step_model(context, model):
     context.model = model
 
 
-@step(u'{max_tokens} max tokens to predict')
+@step(u'{max_tokens:d} max tokens to predict')
 def step_max_tokens(context, max_tokens):
-    context.n_predict = int(max_tokens)
+    context.n_predict = max_tokens
 
 
 @step(u'streaming is {enable_streaming}')
@@ -222,11 +244,70 @@ def step_server_api_key(context, server_api_key):
     context.server_api_key = server_api_key
 
 
+@step(u'{n_junk:d} as number of junk')
+def step_n_junk(context, n_junk):
+    context.n_junk = n_junk
+
+
+@step(u'{n_batch:d} as batch size')
+def step_n_batch(context, n_batch):
+    context.n_batch = n_batch
+
+
+@step(u'{seed:d} as seed')
+def step_seed(context, seed):
+    context.seed = seed
+
+
+@step(u'a prefix prompt')
+def step_prompt_prefix(context):
+    context.prompt_prefix = context.text
+
+
+@step(u'a junk suffix prompt')
+def step_prompt_junk_suffix(context):
+    context.prompt_junk_suffix = context.text
+
+
+@step(u'a suffix prompt')
+def step_prompt_suffix(context):
+    context.prompt_suffix = context.text
+
+
+@step(u'{n_ga:d} group attention factor'
+      u' to extend context size through self-extend')
+def step_impl(context, n_ga):
+    context.n_ga = n_ga
+
+
+@step(u'{n_ga_w:d} group attention width to extend context size through self-extend')
+def step_impl(context, n_ga_w):
+    context.n_ga_w = n_ga_w
+
+
+@step(u'a passkey prompt template')
+def step_prompt_passkey(context):
+    context.prompt_passkey = context.text
+
+
+@step(u'a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk')
+def step_prompt_passkey(context, passkey, i_pos):
+    prompt = ""
+    for i in range(context.n_junk):
+        if i % context.n_junk == i_pos:
+            prompt += context.prompt_passkey # the passkey is already substituted
+        prompt += context.prompt_junk_suffix
+    if context.debug:
+        passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
+        print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n")
+    context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)
+
+
 @step(u'an OAI compatible chat completions request with {api_error} api error')
 @async_run_until_complete
 async def step_oai_chat_completions(context, api_error):
     if context.debug:
-        print(f"Submitting OAI compatible completions request...")
+        print(f"Submitting OAI compatible completions request...\n")
     expect_api_error = api_error == 'raised'
     completion = await oai_chat_completions(context.prompts.pop(),
                                             context.system_prompt,
@@ -241,8 +322,7 @@ async def step_oai_chat_completions(context, api_error):
                                             enable_streaming=context.enable_streaming
                                             if hasattr(context, 'enable_streaming') else None,
 
-                                            server_seed=context.server_seed
-                                            if hasattr(context, 'server_seed') else None,
+                                            seed=await completions_seed(context),
 
                                             user_api_key=context.user_api_key
                                             if hasattr(context, 'user_api_key') else None,
@@ -276,8 +356,10 @@ async def step_concurrent_completion_requests(context):
                               # prompt is inserted automatically
                               context.base_url,
                               debug=context.debug,
+                              prompt_prefix=context.prompt_prefix,
+                              prompt_suffix=context.prompt_suffix,
                               n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
-                              server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
+                              seed=await completions_seed(context),
                               user_api_key=context.user_api_key if hasattr(context,
                                                                            'user_api_key') else None)
 
@@ -297,8 +379,7 @@ async def step_oai_chat_completions(context):
                               if hasattr(context, 'n_predict') else None,
                               enable_streaming=context.enable_streaming
                               if hasattr(context, 'enable_streaming') else None,
-                              server_seed=context.server_seed
-                              if hasattr(context, 'server_seed') else None,
+                              seed=await completions_seed(context),
                               user_api_key=context.user_api_key
                               if hasattr(context, 'user_api_key') else None)
 
@@ -318,7 +399,9 @@ async def step_oai_chat_completions(context):
                               if hasattr(context, 'n_predict') else None,
                               enable_streaming=context.enable_streaming
                               if hasattr(context, 'enable_streaming') else None,
-                              server_seed=context.server_seed
+                              seed=context.seed
+                              if hasattr(context, 'seed') else
+                              context.server_seed
                               if hasattr(context, 'server_seed') else None,
                               user_api_key=context.user_api_key
                               if hasattr(context, 'user_api_key') else None)
@@ -330,11 +413,10 @@ async def step_all_prompts_are_predicted(context):
     await all_prompts_are_predicted(context)
 
 
-@step(u'all prompts are predicted with {n_predict} tokens')
+@step(u'all prompts are predicted with {n_expected_predicted:d} tokens')
 @async_run_until_complete
-async def step_all_prompts_are_predicted_with_n_tokens(context, n_predict):
-    expected_predicted_n = int(n_predict)
-    await all_prompts_are_predicted(context, expected_predicted_n)
+async def step_all_prompts_are_predicted_with_n_tokens(context, n_expected_predicted):
+    await all_prompts_are_predicted(context, n_expected_predicted)
 
 
 async def all_prompts_are_predicted(context, expected_predicted_n=None):
@@ -464,6 +546,8 @@ async def step_prometheus_metrics_exported(context):
             assert metrics_response.headers['Content-Type'] == "text/plain; version=0.0.4"
             metrics_raw = await metrics_response.text()
             metric_exported = False
+            if context.debug:
+                print(f"/metrics answer:\n{metrics_raw}\n")
             for metric in parser.text_string_to_metric_families(metrics_raw):
                 match metric.name:
                     case "llamacpp:kv_cache_usage_ratio":
@@ -472,6 +556,37 @@ async def step_prometheus_metrics_exported(context):
             assert metric_exported, "No metrics exported"
 
 
+@step(u'available models')
+def step_available_models(context):
+    # openai client always expects an api_key
+    openai.api_key = context.user_api_key if context.user_api_key is not None else 'nope'
+    openai.api_base = f'{context.base_url}/v1'
+    context.models = openai.Model.list().data
+
+
+@step(u'{n_model:d} models are supported')
+def step_supported_models(context, n_model):
+    if context.debug:
+        print("server models available:", context.models)
+    assert len(context.models) == n_model
+
+
+@step(u'model {i_model:d} is {param} {preposition} {param_value}')
+def step_supported_models(context, i_model, param, preposition, param_value):
+    assert i_model < len(context.models)
+    model = context.models[i_model]
+
+    param_value = param_value.split(' ', 1)[0]
+    match param:
+        case 'identified':
+            value = model.id
+        case 'trained':
+            value = str(model.meta.n_ctx_train)
+        case _:
+            assert False, "param {param} not supported"
+    assert param_value == value, f"model param {param} {value} != {param_value}"
+
+
 async def concurrent_requests(context, f_completion, *args, **kwargs):
     n_prompts = len(context.prompts)
     if context.debug:
@@ -486,8 +601,10 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
 async def request_completion(prompt,
                              base_url,
                              debug=False,
+                             prompt_prefix=None,
+                             prompt_suffix=None,
                              n_predict=None,
-                             server_seed=None,
+                             seed=None,
                              expect_api_error=None,
                              user_api_key=None):
     if debug:
@@ -504,11 +621,14 @@ async def request_completion(prompt,
     async with aiohttp.ClientSession() as session:
         async with session.post(f'{base_url}/completion',
                                 json={
+                                    "input_prefix": prompt_prefix,
                                     "prompt": prompt,
-                                    "n_predict": int(n_predict) if n_predict is not None else -1,
-                                    "seed": server_seed if server_seed is not None else 42
+                                    "input_suffix": prompt_suffix,
+                                    "n_predict": n_predict if n_predict is not None else -1,
+                                    "seed": seed if seed is not None else 42
                                 },
-                                headers=headers) as response:
+                                headers=headers,
+                                timeout=3600) as response:
             if expect_api_error is None or not expect_api_error:
                 assert response.status == 200
                 assert response.headers['Access-Control-Allow-Origin'] == origin
@@ -526,14 +646,14 @@ async def oai_chat_completions(user_prompt,
                                model=None,
                                n_predict=None,
                                enable_streaming=None,
-                               server_seed=None,
+                               seed=None,
                                user_api_key=None,
                                expect_api_error=None):
     if debug:
         print(f"Sending OAI Chat completions request: {user_prompt}")
     # openai client always expects an api key
     user_api_key = user_api_key if user_api_key is not None else 'nope'
-    seed = server_seed if server_seed is not None else 42
+    seed = seed if seed is not None else 42
     enable_streaming = enable_streaming if enable_streaming is not None else False
     payload = {
         "messages": [
@@ -692,20 +812,32 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
     content = completion_response['content']
     n_predicted = completion_response['timings']['predicted_n']
     assert len(content) > 0, "no token predicted"
-    if expected_predicted_n is not None:
+    if re_content is not None:
+        p = re.compile(re_content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL)
+        matches = p.finditer(content)
+        last_match = 0
+        highlighted = ''
+        for match in matches:
+            start, end = match.span()
+            highlighted += content[last_match: start]
+            highlighted += '\x1b[33m'
+            highlighted += content[start: end]
+            highlighted += '\x1b[0m'
+            last_match = end
+        highlighted += content[last_match:]
+        if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+          print(f"Checking completion response: {highlighted}\n")
+        assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
+    if expected_predicted_n and expected_predicted_n > 0:
         assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
                                                      f' {n_predicted} <> {expected_predicted_n}')
-    if re_content is not None:
-        re_content = '^.*' + re_content.replace('<or>', '|') + '.*$'
-        assert re.match(re_content, content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL), (
-            f'invalid tokens predicted:'
-            f' ```\n{content}\n``` do not match /{re_content}/')
+
 
 
 async def gather_tasks_results(context):
     n_tasks = len(context.concurrent_tasks)
     if context.debug:
-        print(f"Waiting for all {n_tasks} tasks results...")
+        print(f"Waiting for all {n_tasks} tasks results...\n")
     for task_no in range(n_tasks):
         context.tasks_result.append(await context.concurrent_tasks.pop())
     n_completions = len(context.tasks_result)
@@ -716,15 +848,13 @@ async def wait_for_health_status(context,
                                  base_url,
                                  expected_http_status_code,
                                  expected_health_status,
+                                 timeout=3,
                                  params=None,
                                  slots_idle=None,
                                  slots_processing=None,
                                  expected_slots=None):
     if context.debug:
-        print(f"Starting checking for health for expected_health_status={expected_health_status}")
-    timeout = 3  # seconds
-    if expected_health_status == 'ok':
-        timeout = 10 # CI slow inference
+        print(f"Starting checking for health for expected_health_status={expected_health_status}\n")
     interval = 0.5
     counter = 0
     async with aiohttp.ClientSession() as session:
@@ -734,7 +864,7 @@ async def wait_for_health_status(context,
                 health = await health_response.json()
                 if context.debug:
                     print(f"HEALTH - response for expected health status='{expected_health_status}' on "
-                          f"'{base_url}/health'?{params} is {health}")
+                          f"'{base_url}/health'?{params} is {health}\n")
                 if (status_code == expected_http_status_code
                         and health['status'] == expected_health_status
                         and (slots_idle is None or health['slots_idle'] == slots_idle)
@@ -757,7 +887,7 @@ async def wait_for_health_status(context,
                 if expected_http_status_code == 503:
                     if len(context.tasks_result) == 0:
                         print("\x1b[5;37;43mWARNING: forcing concurrent tasks,"
-                              " busy health check missed, probably too fast inference\x1b[0m")
+                              " busy health check missed, probably too fast inference\x1b[0m\n")
                         n_completions = await gather_tasks_results(context)
                         if n_completions > 0:
                             return
@@ -791,6 +921,11 @@ def assert_slots_status(slots, expected_slots):
                                                 f" = {expected[key]} != {slot[key]}")
 
 
+async def completions_seed(context):
+    return context.seed if hasattr(context, 'seed') and context.seed is not None \
+        else context.server_seed if hasattr(context, 'server_seed') else None
+
+
 def start_server_background(context):
     context.server_path = '../../../build/bin/server'
     if 'LLAMA_SERVER_BIN_PATH' in os.environ:
@@ -800,27 +935,35 @@ def start_server_background(context):
         '--port', context.server_port,
         '--model', context.model_file
     ]
+    if context.n_batch:
+        server_args.extend(['--batch-size', context.n_batch])
+    if context.n_gpu_layer:
+        server_args.extend(['--n-gpu-layers', context.n_gpu_layer])
     if context.server_continuous_batching:
         server_args.append('--cont-batching')
     if context.server_embeddings:
         server_args.append('--embedding')
     if context.server_metrics:
         server_args.append('--metrics')
-    if context.model_alias is not None:
+    if context.model_alias:
         server_args.extend(['--alias', context.model_alias])
-    if context.n_ctx is not None:
+    if context.n_ctx:
         server_args.extend(['--ctx-size', context.n_ctx])
-    if context.n_slots is not None:
+    if context.n_slots:
         server_args.extend(['--parallel', context.n_slots])
-    if context.n_server_predict is not None:
+    if context.n_server_predict:
         server_args.extend(['--n-predict', context.n_server_predict])
-    if context.server_api_key is not None:
+    if context.server_api_key:
         server_args.extend(['--api-key', context.server_api_key])
+    if context.n_ga:
+        server_args.extend(['--grp-attn-n', context.n_ga])
+    if context.n_ga_w:
+        server_args.extend(['--grp-attn-w', context.n_ga_w])
     if context.debug:
         server_args.append('--verbose')
     if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
         server_args.extend(['--log-format', "text"])
-    print(f"starting server with: {context.server_path}", *server_args)
+    print(f"starting server with: {context.server_path} {server_args}\n")
     context.server_process = subprocess.Popen(
         [str(arg) for arg in [context.server_path, *server_args]],
         close_fds=True)
index e228b2371cccecafcbd18e7592aada53ce328c8b..cf14b3b44e03b308f889f794563148ff6a925b32 100644 (file)
@@ -1,4 +1,4 @@
-# run with ./test.sh --tags wrong_usage
+# run with: ./tests.sh --no-skipped --tags wrong_usage
 @wrong_usage
 Feature: Wrong usage of llama.cpp server
 
@@ -7,7 +7,7 @@ Feature: Wrong usage of llama.cpp server
   # or pass n_predict/max_tokens in the request.
   Scenario: Infinite loop
     Given a server listening on localhost:8080
-    And   a model file stories260K.gguf
+    And   a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     # Uncomment below to fix the issue
     #And   64 server max tokens to predict
     Then  the server is starting
@@ -18,4 +18,5 @@ Feature: Wrong usage of llama.cpp server
     # Uncomment below to fix the issue
     #And   128 max tokens to predict
     Given concurrent completion requests
+    Then the server is idle
     Then all prompts are predicted
index 334fa4a70ea72e947e073271c00c4c44498efa21..5d4210164a50ac6c9d90de4501ebffc41bee5a08 100644 (file)
@@ -1,4 +1,5 @@
 aiohttp~=3.9.3
 behave~=1.2.6
+huggingface_hub~=0.20.3
 openai~=0.25.0
 prometheus-client~=0.20.0
index 17a4e6fc64307d794415f7f7fb3932bf0e2fb265..1c6c5695fcf6548b6cc6974f2f91f83de7ffce46 100755 (executable)
@@ -5,7 +5,7 @@ set -eu
 if [ $# -lt 1 ]
 then
   # Start @llama.cpp scenario
-  behave --summary --stop --no-capture --exclude 'issues|wrong_usages' --tags llama.cpp
+  behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey' --tags llama.cpp
 else
   behave "$@"
 fi
index d98541f26d123f3059a220c1b8ff72d15e54967b..b6e49d8b98a2a808aa74243f42a461c5c908c8b3 100644 (file)
@@ -126,8 +126,7 @@ static inline void server_log(const char *level, const char *function, int line,
         for (const auto& el : log.items())
         {
             const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace);
-            snprintf(buf, 1024, " %s=%s", el.key().c_str(), value.c_str());
-            ss << buf;
+            ss << " " << el.key() << "=" << value;
         }
 
         const std::string str = ss.str();