server: tests: add truncated prompt tests, better kv cache size (#5933)
author    Pierrick Hymbert <redacted>
          Sat, 9 Mar 2024 09:30:04 +0000 (10:30 +0100)
committer GitHub <redacted>
          Sat, 9 Mar 2024 09:30:04 +0000 (11:30 +0200)
* server: tests: add truncated prompt tests, better size

* server, tests : update regex

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/server/server.cpp
examples/server/tests/features/parallel.feature
examples/server/tests/features/server.feature
examples/server/tests/features/steps/steps.py

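The commit title and the feature-file comment carried in the diffs below both rest on the same point: the KV cache set with `--ctx-size` is shared across all independent sequences, so each parallel slot only gets a fraction of it. The following is a minimal sketch of that sizing arithmetic, not code from this commit; the helper name and the simple integer division are illustrative assumptions.

```python
# Sketch only (not part of this commit): the KV cache set with --ctx-size is
# shared across all independent sequences, so each parallel slot gets roughly
# kv_cache_size // n_slots tokens of context.

def ctx_per_slot(kv_cache_size: int, n_slots: int) -> int:
    """Approximate context window available to each parallel slot."""
    return kv_cache_size // n_slots

# With the values the tests move to below: a 256-token KV cache and 2 slots
# leave about 128 tokens per request -- enough for the 46-token prompt plus
# 64 predicted tokens, but not for the long lorem-ipsum prompt, which the new
# scenario expects to be truncated down to 109 processed prompt tokens.
assert ctx_per_slot(256, 2) == 128
```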
examples/server/server.cpp
index 59a59d56b60235125d972e8ad3ddb73e2489d659..6f44499843a633b914858864a2fcdc5b516cbbd0 100644
@@ -1128,6 +1128,7 @@ struct server_context {
 
             LOG_VERBOSE("stopped by limit", {
                 {"id_slot",   slot.id},
+                {"id_task",   slot.id_task},
                 {"n_decoded", slot.n_decoded},
                 {"n_predict", slot.params.n_predict},
             });
@@ -1141,6 +1142,8 @@ struct server_context {
         }
 
         LOG_VERBOSE("next token", {
+            {"id_slot",        slot.id},
+            {"id_task",        slot.id_task},
             {"token",          result.tok},
             {"token_text",     tokens_to_output_formatted_string(ctx, result.tok)},
             {"has_next_token", slot.has_next_token},
@@ -1750,6 +1753,15 @@ struct server_context {
                         slot.n_past = 0;
                         slot.n_prompt_tokens = prompt_tokens.size();
 
+                        LOG_VERBOSE("prompt tokenized", {
+                            {"id_slot",         slot.id},
+                            {"id_task",         slot.id_task},
+                            {"n_ctx",           slot.n_ctx},
+                            {"n_keep",          slot.params.n_keep},
+                            {"n_prompt_tokens", slot.n_prompt_tokens},
+                            {"prompt_tokens",   tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
+                        });
+
                         if (slot.embedding) {
                             // this prompt is too large to process - discard it
                             if (slot.n_prompt_tokens > n_batch) {
@@ -1788,10 +1800,13 @@ struct server_context {
                                 slot.n_prompt_tokens = prompt_tokens.size();
 
                                 LOG_VERBOSE("input truncated", {
-                                    {"n_ctx",         slot.n_ctx},
-                                    {"n_keep",        slot.params.n_keep},
-                                    {"n_left",        n_left},
-                                    {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
+                                    {"id_slot",         slot.id},
+                                    {"id_task",         slot.id_task},
+                                    {"n_ctx",           slot.n_ctx},
+                                    {"n_keep",          slot.params.n_keep},
+                                    {"n_left",          n_left},
+                                    {"n_prompt_tokens", slot.n_prompt_tokens},
+                                    {"prompt_tokens",   tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
                                 });
 
                                 GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
examples/server/tests/features/parallel.feature
index 066698c8e3c1ef1e1b579ebd8755b2eb28fee812..a66fed626619dacc197b22a1ebb49e5c38934b4f 100644
@@ -6,8 +6,8 @@ Feature: Parallel
     Given a server listening on localhost:8080
     And   a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     And   42 as server seed
-    And   512 as batch size
-    And   64 KV cache size
+    And   128 as batch size
+    And   256 KV cache size
     And   2 slots
     And   continuous batching
     Then  the server is starting
@@ -76,6 +76,7 @@ Feature: Parallel
       | disabled  | 128       |
       | enabled   | 64        |
 
+
   Scenario:  Multi users with total number of tokens to predict exceeds the KV Cache size #3969
     Given a prompt:
       """
examples/server/tests/features/server.feature
index 878ac1363c419b675391ad07bf9401a0ed3df53b..aa132fa3472ef99030f7f690c9a3efe6b4e55750 100644
@@ -10,11 +10,10 @@ Feature: llama.cpp server
       # KV Cache corresponds to the total amount of tokens
       # that can be stored across all independent sequences: #4130
       # see --ctx-size and #5568
-    And   32 KV cache size
-    And   512 as batch size
-    And   1 slots
-    And   embeddings extraction
-    And   32 server max tokens to predict
+    And   256 KV cache size
+    And   32 as batch size
+    And   2 slots
+    And   64 server max tokens to predict
     And   prometheus compatible metrics exposed
     Then  the server is starting
     Then  the server is healthy
@@ -23,18 +22,35 @@ Feature: llama.cpp server
     Then the server is ready
     And  all slots are idle
 
+
   Scenario Outline: Completion
     Given a prompt <prompt>
     And   <n_predict> max tokens to predict
     And   a completion request with no api error
     Then  <n_predicted> tokens are predicted matching <re_content>
+    And   the completion is <truncated> truncated
+    And   <n_prompt> prompt tokens are processed
     And   prometheus metrics are exposed
     And   metric llamacpp:tokens_predicted is <n_predicted>
 
     Examples: Prompts
-      | prompt                           | n_predict | re_content                       | n_predicted |
-      | I believe the meaning of life is | 8         | (read\|going)+                   | 8           |
-      | Write a joke about AI            | 64        | (park\|friends\|scared\|always)+ | 32          |
+      | prompt                                                                    | n_predict | re_content                    | n_prompt | n_predicted | truncated |
+      | I believe the meaning of life is                                          | 8         | (read\|going)+                | 18       | 8           | not       |
+      | Write a joke about AI from a very long prompt which will not be truncated | 256       | (princesses\|everyone\|kids)+ | 46       | 64          | not       |
+
+  Scenario: Completion prompt truncated
+    Given a prompt:
+    """
+    Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+    Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+    Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+    Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+    """
+    And   a completion request with no api error
+    Then  64 tokens are predicted matching fun|Annaks|popcorns
+    And   the completion is  truncated
+    And   109 prompt tokens are processed
+
 
   Scenario Outline: OAI Compatibility
     Given a model <model>
@@ -44,11 +60,14 @@ Feature: llama.cpp server
     And   streaming is <enable_streaming>
     Given an OAI compatible chat completions request with no api error
     Then  <n_predicted> tokens are predicted matching <re_content>
+    And   <n_prompt> prompt tokens are processed
+    And   the completion is <truncated> truncated
 
     Examples: Prompts
-      | model        | system_prompt               | user_prompt                          | max_tokens | re_content             | n_predicted | enable_streaming |
-      | llama-2      | Book                        | What is the best book                | 8          | (Mom\|what)+           | 8           | disabled         |
-      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64         | (thanks\|happy\|bird)+ | 32          | enabled          |
+      | model        | system_prompt               | user_prompt                          | max_tokens | re_content             | n_prompt | n_predicted | enable_streaming | truncated |
+      | llama-2      | Book                        | What is the best book                | 8          | (Here\|what)+          | 77       | 8           | disabled         | not       |
+      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128        | (thanks\|happy\|bird)+ | -1       | 64          | enabled          |           |
+
 
   Scenario: Tokenize / Detokenize
     When tokenizing:
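The new scenarios above assert on two fields of the completion response that the step definitions in the next file read back: the boolean `truncated` flag and `timings.prompt_n`, the number of prompt tokens the server actually processed. A rough standalone sketch of what the "Completion prompt truncated" scenario exercises against a locally running server (started as in the feature background: 256-token KV cache, 2 slots, 64 max tokens to predict); the `requests` client and the shortened prompt string are assumptions for illustration, the real tests use aiohttp.

```python
# Illustrative only: the "Completion prompt truncated" check expressed as a
# plain client call against a locally running llama.cpp server on port 8080.
import requests  # assumption: requests is available; the test suite uses aiohttp

prompt = (
    "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod "
    "tempor incididunt ut labore et dolore magna aliqua. ..."  # full text in the scenario above
)

resp = requests.post(
    "http://localhost:8080/completion",
    json={"prompt": prompt, "n_predict": 64},
).json()

# The prompt does not fit in the slot's context, so the server truncates it:
assert resp["truncated"] is True
# timings.prompt_n reports how many prompt tokens were actually processed
# (109 in the scenario above, after truncation).
print(resp["timings"]["prompt_n"], "prompt tokens processed")
```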
examples/server/tests/features/steps/steps.py
index d7f0058360c2355646031e458940219d13c7819f..0076f805be4d32696335b06e7a84a127e677cb47 100644
@@ -196,12 +196,30 @@ async def step_request_completion(context, api_error):
 
 @step(u'{predicted_n:d} tokens are predicted matching {re_content}')
 def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
-    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n, re_content)
+    context.completion = context.tasks_result.pop()
+    assert_n_tokens_predicted(context.completion, predicted_n, re_content)
 
 
 @step(u'{predicted_n:d} tokens are predicted')
 def step_n_tokens_predicted(context, predicted_n):
-    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n)
+    context.completion = context.tasks_result.pop()
+    assert_n_tokens_predicted(context.completion, predicted_n)
+
+
+@step(u'the completion is  truncated')
+def step_assert_completion_truncated(context):
+    step_assert_completion_truncated(context, '')
+
+
+@step(u'the completion is {truncated} truncated')
+def step_assert_completion_truncated(context, truncated):
+    truncated = truncated != "not"
+    assert context.completion['truncated'] == truncated, f'{context.completion}'
+
+
+@step(u'{n_prompt:d} prompt tokens are processed')
+def step_impl(context, n_prompt):
+    assert n_prompt < 0 or n_prompt == context.completion['timings']['prompt_n'], f"n_prompt={context.completion['timings']['prompt_n']}"
 
 
 @step(u'a user prompt {user_prompt}')
@@ -722,7 +740,8 @@ async def oai_chat_completions(user_prompt,
     completion_response = {
         'content': '',
         'timings': {
-            'predicted_n': 0
+            'predicted_n': 0,
+            'prompt_n': 0
         }
     }
     if async_client:
@@ -763,7 +782,8 @@ async def oai_chat_completions(user_prompt,
                         completion_response = {
                             'content': chat_completion_raw['choices'][0]['message'],
                             'timings': {
-                                'predicted_n': chat_completion_raw['usage']['completion_tokens']
+                                'predicted_n': chat_completion_raw['usage']['completion_tokens'],
+                                'prompt_n': chat_completion_raw['usage']['prompt_tokens']
                             }
                         }
                     else:
@@ -792,13 +812,16 @@ async def oai_chat_completions(user_prompt,
                 if 'content' in delta:
                     completion_response['content'] += delta['content']
                     completion_response['timings']['predicted_n'] += 1
+                completion_response['truncated'] = chunk.choices[0].finish_reason != 'stop'
         else:
             assert len(chat_completion.choices) == 1
             completion_response = {
                 'content': chat_completion.choices[0].message.content,
                 'timings': {
-                    'predicted_n': chat_completion.usage.completion_tokens
-                }
+                    'predicted_n': chat_completion.usage.completion_tokens,
+                    'prompt_n': chat_completion.usage.prompt_tokens
+                    },
+                'truncated': chat_completion.choices[0].finish_reason != 'stop'
             }
     if debug:
         print("OAI response formatted to llama.cpp:", completion_response)
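For the OpenAI-compatible path, the helper above derives `truncated` from `finish_reason` (`'stop'` means the model finished on its own; anything else, typically `'length'`, means a token limit was hit) and `prompt_n` from `usage.prompt_tokens`. A hedged sketch of the same mapping using the `openai` client pointed at the server's `/v1` endpoint; the base URL, API key placeholder, and model name are assumptions taken from the test setup, not prescribed by this commit.

```python
# Illustrative only: the same truncation/prompt-token mapping through the
# OpenAI-compatible endpoint, mirroring how oai_chat_completions shapes the
# response. base_url, api_key and model are assumptions for the sketch.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="sk-no-key-required")

chat = client.chat.completions.create(
    model="llama-2",  # model name as used in the test table
    messages=[
        {"role": "system", "content": "Book"},
        {"role": "user", "content": "What is the best book"},
    ],
    max_tokens=8,
)

completion = {
    "content": chat.choices[0].message.content,
    "timings": {
        "predicted_n": chat.usage.completion_tokens,
        "prompt_n": chat.usage.prompt_tokens,
    },
    # 'stop' means the generation finished naturally; any other finish_reason
    # (e.g. 'length') is treated by the tests as a truncated completion.
    "truncated": chat.choices[0].finish_reason != "stop",
}
print(completion)
```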