]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : add --no-context-shift option (#9607)
authorXuan Son Nguyen <redacted>
Mon, 23 Sep 2024 20:23:54 +0000 (22:23 +0200)
committerGitHub <redacted>
Mon, 23 Sep 2024 20:23:54 +0000 (22:23 +0200)
* server : add --no-context-shift option

* small fix

* Update examples/server/tests/features/embeddings.feature

Co-authored-by: Georgi Gerganov <redacted>
* tests : minor fix

* revert usage of GGML_ASSERT

* update server documentation

---------

Co-authored-by: Georgi Gerganov <redacted>
common/arg.cpp
examples/server/README.md
examples/server/server.cpp
examples/server/tests/features/ctx_shift.feature [new file with mode: 0644]
examples/server/tests/features/embeddings.feature
examples/server/tests/features/steps/steps.py

index 922391069d32aa56646d3754e7d309809d498627..c1ec3c4f99c379bb8167be49752b3750ffc4b1b4 100644 (file)
@@ -691,7 +691,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params) {
             params.ctx_shift = false;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
     add_opt(llama_arg(
         {"--chunks"}, "N",
         format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
index 326e05e1e3ea1078b8102d0a243fb93cc0c12dda..741950c8a51939e783050f973d83e97873bfa117 100644 (file)
@@ -21,8 +21,6 @@ The project is under active development, and we are [looking for feedback and co
 | -------- | ----------- |
 | `-h, --help, --usage` | print usage and exit |
 | `--version` | show version and build info |
-| `-v, --verbose` | print verbose information |
-| `--verbosity N` | set specific verbosity level (default: 0) |
 | `-t, --threads N` | number of threads to use during generation (default: -1)<br/>(env: LLAMA_ARG_THREADS) |
 | `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) |
 | `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") |
@@ -40,15 +38,18 @@ The project is under active development, and we are [looking for feedback and co
 | `-b, --batch-size N` | logical maximum batch size (default: 2048)<br/>(env: LLAMA_ARG_BATCH) |
 | `-ub, --ubatch-size N` | physical maximum batch size (default: 512)<br/>(env: LLAMA_ARG_UBATCH) |
 | `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) |
+| `--no-context-shift` | disables context shift on inifinite text generation (default: disabled) |
 | `-fa, --flash-attn` | enable Flash Attention (default: disabled)<br/>(env: LLAMA_ARG_FLASH_ATTN) |
 | `-p, --prompt PROMPT` | prompt to start generation with |
+| `--no-perf` | disable internal libllama performance timings (default: false)<br/>(env: LLAMA_ARG_NO_PERF) |
 | `-f, --file FNAME` | a file containing the prompt (default: none) |
 | `-bf, --binary-file FNAME` | binary file containing the prompt (default: none) |
 | `-e, --escape` | process escapes sequences (\n, \r, \t, \', \", \\) (default: true) |
 | `--no-escape` | do not process escape sequences |
+| `-sp, --special` | special tokens output enabled (default: false) |
 | `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) |
 | `--samplers SAMPLERS` | samplers that will be used for generation in the order, separated by ';'<br/>(default: top_k;tfs_z;typ_p;top_p;min_p;temperature) |
-| `-s, --seed SEED` | RNG seed (default: -1, use random seed for < 0) |
+| `-s, --seed SEED` | RNG seed (default: 4294967295, use random seed for 4294967295) |
 | `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: kfypmt) |
 | `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) |
 | `--penalize-nl` | penalize newline tokens (default: false) |
@@ -87,7 +88,7 @@ The project is under active development, and we are [looking for feedback and co
 | `-ctk, --cache-type-k TYPE` | KV cache data type for K (default: f16) |
 | `-ctv, --cache-type-v TYPE` | KV cache data type for V (default: f16) |
 | `-dt, --defrag-thold N` | KV cache defragmentation threshold (default: -1.0, < 0 - disabled)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
-| `-np, --parallel N` | number of parallel sequences to decode (default: 1)<br/>(env:  LLAMA_ARG_N_PARALLEL) |
+| `-np, --parallel N` | number of parallel sequences to decode (default: 1)<br/>(env: LLAMA_ARG_N_PARALLEL) |
 | `-cb, --cont-batching` | enable continuous batching (a.k.a dynamic batching) (default: enabled)<br/>(env: LLAMA_ARG_CONT_BATCHING) |
 | `-nocb, --no-cont-batching` | disable continuous batching<br/>(env: LLAMA_ARG_NO_CONT_BATCHING) |
 | `--mlock` | force system to keep model in RAM rather than swapping or compressing |
@@ -128,12 +129,13 @@ The project is under active development, and we are [looking for feedback and co
 | `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)<br/> |
 | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |
 | `-ld, --logdir LOGDIR` | path under which to save YAML logs (no logging if unset) |
-| `--log-test` | Log test |
 | `--log-disable` | Log disable |
-| `--log-enable` | Log enable |
-| `--log-new` | Log new |
-| `--log-append` | Log append |
-| `--log-file FNAME` | Log file |
+| `--log-file FNAME` | Log to file |
+| `--log-colors` | Enable colored logging<br/>(env: LLAMA_LOG_COLORS) |
+| `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) |
+| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored.<br/>(env: LLAMA_LOG_VERBOSITY) |
+| `--log-prefix` | Enable prefx in log messages<br/>(env: LLAMA_LOG_PREFIX) |
+| `--log-timestamps` | Enable timestamps in log messages<br/>(env: LLAMA_LOG_TIMESTAMPS) |
 
 Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var.
 
index 0ca9999940606d91f91bfb1353afc67630bdc8d2..8655c097aa51b1477da22b2e4cde9d8cbccf9b78 100644 (file)
@@ -1180,6 +1180,15 @@ struct server_context {
             SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
         }
 
+        // if context shift is disabled, we stop when it reaches the context limit
+        if (slot.n_decoded >= slot.n_ctx) {
+            slot.truncated      = true;
+            slot.stopped_limit  = true;
+            slot.has_next_token = false;
+
+            SLT_DBG(slot, "stopped due to running out of context capacity, n_decoded = %d, n_ctx = %d\n", slot.n_decoded, slot.n_ctx);
+        }
+
         if (llama_token_is_eog(model, result.tok)) {
             slot.stopped_eos    = true;
             slot.has_next_token = false;
@@ -1480,7 +1489,7 @@ struct server_context {
             if (result.error) {
                 error_handler(result.data);
                 cancel_tasks(id_tasks);
-                break;
+                return;
             }
 
             size_t idx = result.data["index"];
@@ -1827,6 +1836,14 @@ struct server_context {
         for (server_slot & slot : slots) {
             if (slot.ga_n == 1) {
                 if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
+                    if (!params.ctx_shift) {
+                        // this check is redundant (for good)
+                        // we should never get here, because generation should already stopped in process_token()
+                        slot.release();
+                        send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
+                        continue;
+                    }
+
                     // Shift context
                     const int n_keep    = slot.params.n_keep + add_bos_token;
                     const int n_left    = (int) system_tokens.size() + slot.n_past - n_keep;
@@ -1961,6 +1978,14 @@ struct server_context {
                                 continue;
                             }
                         } else {
+                            if (!params.ctx_shift) {
+                                // if context shift is disabled, we make sure prompt size is smaller than KV size
+                                if ((int) system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
+                                    slot.release();
+                                    send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
+                                    continue;
+                                }
+                            }
                             if (slot.params.n_keep < 0) {
                                 slot.params.n_keep = slot.n_prompt_tokens;
                             }
diff --git a/examples/server/tests/features/ctx_shift.feature b/examples/server/tests/features/ctx_shift.feature
new file mode 100644 (file)
index 0000000..ba3afcf
--- /dev/null
@@ -0,0 +1,62 @@
+@llama.cpp
+@ctx_shift
+Feature: llama.cpp server
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And   a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
+    And   a model file test-model.gguf
+    And   a model alias tinyllama-2
+    And   BOS token is 1
+    And   42 as server seed
+    And   256 KV cache size
+    And   32 as batch size
+    And   2 slots
+
+  Scenario: Inference with context shift
+    And   64 server max tokens to predict
+    Then  the server is starting
+    Then  the server is healthy
+    Given a prompt:
+    """
+    Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+    Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+    Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+    Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+    """
+    And   a completion request with no api error
+    Then  64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl
+    And   the completion is  truncated
+    And   109 prompt tokens are processed
+
+  Scenario Outline: Inference without context shift
+    And   <n_predict> server max tokens to predict
+    And   disable context shifting
+    Then  the server is starting
+    Then  the server is healthy
+    Given a prompt:
+    """
+    Hi how are you
+    """
+    And   a completion request with no api error
+    Then  <n_token_output> tokens are predicted matching twind|Anna
+    And   the completion is <truncated> truncated
+    And   8 prompt tokens are processed
+    Examples:
+      | n_predict | n_token_output | truncated |
+      | 64        | 64             | not       |
+      | -1        | 120            |           |
+
+  Scenario: Inference without context shift (expected error: prompt too long)
+    And   disable context shifting
+    Then  the server is starting
+    Then  the server is healthy
+    Given a prompt:
+    """
+    Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+    Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+    Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+    Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+    """
+    And   a completion request with 400 api error
+
index e1eade6cdbc9b04615e09f8b76fc80b16c3499c8..818ea3beb90cd3131f54270e499de9dd8b37b27d 100644 (file)
@@ -10,11 +10,11 @@ Feature: llama.cpp server
     And   42 as server seed
     And   2 slots
     # the bert-bge-small model has context size of 512
-    # since the generated prompts are as big as the batch size, we need to set the batch size to 512
+    # since the generated prompts are as big as the batch size, we need to set the batch size to <= 512
     # ref: https://huggingface.co/BAAI/bge-small-en-v1.5/blob/5c38ec7c405ec4b44b94cc5a9bb96e735b38267a/config.json#L20
-    And   512 as batch size
-    And   512 as ubatch size
-    And   2048 KV cache size
+    And   128 as batch size
+    And   128 as ubatch size
+    And   512 KV cache size
     And   embeddings extraction
     Then  the server is starting
     Then  the server is healthy
@@ -26,6 +26,20 @@ Feature: llama.cpp server
     """
     Then embeddings are generated
 
+  Scenario: Embedding (error: prompt too long)
+    When embeddings are computed for:
+    """
+    Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+    Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+    Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+    Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+    Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+    Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+    Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+    Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+    """
+    And  embeddings request with 500 api error
+
   Scenario: OAI Embeddings compatibility
     Given a model bert-bge-small
     When an OAI compatible embeddings computation request for:
index 062f084be42d4a671d4ca35bd70f5f0debcc37dd..0fea0fe87b79950bee4f7cac5cbf7fea31955247 100644 (file)
@@ -77,6 +77,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.response_format = None
     context.temperature = None
     context.lora_file = None
+    context.disable_ctx_shift = False
 
     context.tasks_result = []
     context.concurrent_tasks = []
@@ -148,7 +149,7 @@ def step_n_slots(context, n_slots: int):
 
 @step('{n_predict:d} server max tokens to predict')
 def step_server_n_predict(context, n_predict: int):
-    context.n_server_predict = n_predict
+    context.n_server_predict = n_predict if n_predict > 0 else None
 
 
 @step('{slot_save_path} as slot save path')
@@ -180,6 +181,9 @@ def step_server_embeddings(context):
 def step_server_metrics(context):
     context.server_metrics = True
 
+@step('disable context shifting')
+def step_server_disable_ctx_shift(context):
+    context.disable_ctx_shift = True
 
 @step("the server is starting")
 def step_start_server(context):
@@ -257,7 +261,7 @@ async def step_all_slots_status(context, expected_slot_status_string: Literal['i
 @step('a completion request with {api_error} api error')
 @async_run_until_complete
 async def step_request_completion(context, api_error: Literal['raised'] | str):
-    expect_api_error = api_error == 'raised'
+    expect_api_error = api_error == 'raised' or api_error != 'no'
     seeds = await completions_seed(context, num_seeds=1)
     completion = await request_completion(context.prompts.pop(),
                                           seeds[0] if seeds is not None else seeds,
@@ -272,8 +276,11 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
     context.tasks_result.append(completion)
     if context.debug:
         print(f"Completion response: {completion}")
-    if expect_api_error:
+    if api_error == 'raised':
         assert completion == 401, f"completion must be an 401 status code: {completion}"
+    elif api_error.isdigit():
+        api_error_code = int(api_error)
+        assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
 
 
 @step('{predicted_n:d} tokens are predicted matching {re_content}')
@@ -645,6 +652,9 @@ def step_assert_embeddings(context):
     for embedding in context.embeddings:
         assert_embeddings(embedding)
 
+@step('embeddings request with {api_error_code:d} api error')
+def step_assert_embeddings(context, api_error_code: int):
+    assert context.embeddings == api_error_code, f"embeddings request must return code {api_error_code}, but got {context.embeddings}"
 
 @step('an OAI compatible embeddings computation request for')
 @async_run_until_complete
@@ -1089,15 +1099,17 @@ async def oai_chat_completions(user_prompt,
     return completion_response
 
 
-async def request_embedding(content, seed, base_url=None) -> list[list[float]]:
+async def request_embedding(content, seed, base_url=None) -> list[list[float]] | int:
     async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
         async with session.post(f'{base_url}/embedding',
                                 json={
                                     "content": content,
                                 }) as response:
-            assert response.status == 200
-            response_json = await response.json()
-            return [response_json['embedding']]
+            if response.status == 200:
+                response_json = await response.json()
+                return [response_json['embedding']]
+            else:
+                return response.status
 
 
 async def request_oai_embeddings(input, seed,
@@ -1372,6 +1384,8 @@ def start_server_background(context):
         server_args.append('--verbose')
     if context.lora_file:
         server_args.extend(['--lora', context.lora_file])
+    if context.disable_ctx_shift:
+        server_args.extend(['--no-context-shift'])
 
     args = [str(arg) for arg in [context.server_path, *server_args]]
     print(f"bench: starting server with: {' '.join(args)}")