server: continue to update other slots on embedding concurrent request (#5699)
author Pierrick Hymbert <redacted>
Sat, 24 Feb 2024 18:16:04 +0000 (19:16 +0100)
committer GitHub <redacted>
Sat, 24 Feb 2024 18:16:04 +0000 (19:16 +0100)
* server: #5655 - continue to update other slots on embedding concurrent request.

* server: tests: add multi users embeddings as fixed

* server: tests: adding OAI compatible embedding concurrent endpoint

* server: tests: adding OAI compatible embedding with multiple inputs

examples/server/server.cpp
examples/server/tests/features/issues.feature
examples/server/tests/features/parallel.feature
examples/server/tests/features/server.feature
examples/server/tests/features/steps/steps.py

index 9fb436c2a18ec134f9e15edaee13a1f1a19140b0..19a8c1067e72ad442000bfd7e6c00b073b0273a6 100644 (file)
@@ -1836,7 +1836,7 @@ struct llama_server_context
                     send_embedding(slot);
                     slot.release();
                     slot.i_batch = -1;
-                    return true;
+                    continue;
                 }
 
                 completion_token_output result;
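
The one-line change above is the whole server fix: update_slots() walks every active slot in the current batch, and returning true as soon as one embedding slot completed meant the remaining slots were silently skipped for that pass. Switching to continue keeps the loop servicing the other slots. A minimal sketch of the control-flow difference, in Python with hypothetical names (the real code is the C++ above):

    # Sketch of the bug pattern; names are hypothetical, not the real server code.
    def update_slots(slots, batch):
        for slot in slots:
            if slot.embedding_ready(batch):
                slot.send_embedding()
                slot.release()
                slot.i_batch = -1
                # Bug: 'return True' here abandoned every remaining slot in the batch.
                continue  # fix: move on to the next slot instead
            slot.sample_next_token(batch)
        return True
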
index 542006d9a8df2600ebad291a660c780656d7c25d..bf5a175a357ca6017f21f11b12b9a45387442e1e 100644 (file)
@@ -1,36 +1,4 @@
 # List of ongoing issues
 @bug
 Feature: Issues
-    # Issue #5655
-  Scenario: Multi users embeddings
-    Given a server listening on localhost:8080
-    And   a model file stories260K.gguf
-    And   a model alias tinyllama-2
-    And   42 as server seed
-    And   64 KV cache size
-    And   2 slots
-    And   continuous batching
-    And   embeddings extraction
-    Then  the server is starting
-    Then  the server is healthy
-
-    Given a prompt:
-      """
-      Write a very long story about AI.
-      """
-    And a prompt:
-      """
-      Write another very long music lyrics.
-      """
-    And a prompt:
-      """
-      Write a very long poem.
-      """
-    And a prompt:
-      """
-      Write a very long joke.
-      """
-    Given concurrent embedding requests
-    Then the server is busy
-    Then the server is idle
-    Then all embeddings are generated
+  # No confirmed issue at the moment
index 802d624ffc9a335c796f5c9e724aa6c68ba03a1a..c85f9de1d9a5259c9fe9cfe5623c83fe06bc6ec0 100644 (file)
@@ -8,6 +8,7 @@ Feature: Parallel
     And   42 as server seed
     And   64 KV cache size
     And   2 slots
+    And   embeddings extraction
     And   continuous batching
     Then  the server is starting
     Then  the server is healthy
@@ -75,3 +76,48 @@ Feature: Parallel
     Then the server is busy
     Then the server is idle
     Then all prompts are predicted
+
+  Scenario: Multi users embeddings
+    Given a prompt:
+      """
+      Write a very long story about AI.
+      """
+    And a prompt:
+      """
+      Write another very long music lyrics.
+      """
+    And a prompt:
+      """
+      Write a very long poem.
+      """
+    And a prompt:
+      """
+      Write a very long joke.
+      """
+    Given concurrent embedding requests
+    Then the server is busy
+    Then the server is idle
+    Then all embeddings are generated
+
+  Scenario: Multi users OAI compatibility embeddings
+    Given a prompt:
+      """
+      In which country Paris is located ?
+      """
+    And a prompt:
+      """
+      Is Madrid the capital of Spain ?
+      """
+    And a prompt:
+      """
+      What is the biggest US city ?
+      """
+    And a prompt:
+      """
+      What is the capital of Bulgaria ?
+      """
+    And   a model tinyllama-2
+    Given concurrent OAI embedding requests
+    Then the server is busy
+    Then the server is idle
+    Then all embeddings are generated
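
The two scenarios above drive the server with several queued prompts at once. The "concurrent embedding requests" step (implemented in steps.py below) issues one /embedding POST per prompt and awaits them together; a self-contained sketch of the same behaviour, assuming a server on localhost:8080:

    # Standalone sketch of concurrent /embedding requests; URL and prompts
    # are assumptions for illustration.
    import asyncio
    import aiohttp

    async def embed(session, prompt):
        async with session.post('http://localhost:8080/embedding',
                                json={'content': prompt}) as response:
            assert response.status == 200
            return (await response.json())['embedding']

    async def main():
        prompts = ['Write a very long story about AI.', 'Write a very long poem.']
        async with aiohttp.ClientSession() as session:
            embeddings = await asyncio.gather(*(embed(session, p) for p in prompts))
        assert len(embeddings) == len(prompts)

    asyncio.run(main())
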
index fedcfe5aef1b372adc354baa45e947d0217904e8..5f81d256a548ccbedb65a9d222210e3570d3a766 100644 (file)
@@ -60,6 +60,19 @@ Feature: llama.cpp server
     """
     Then embeddings are generated
 
+  Scenario: OAI Embeddings compatibility with multiple inputs
+    Given a model tinyllama-2
+    Given a prompt:
+      """
+      In which country Paris is located ?
+      """
+    And a prompt:
+      """
+      Is Madrid the capital of Spain ?
+      """
+    When an OAI compatible embeddings computation request for multiple inputs
+    Then embeddings are generated
+
 
   Scenario: Tokenize / Detokenize
     When tokenizing:
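
The new scenario exercises the fact that the OAI-compatible /v1/embeddings endpoint accepts "input" as either a single string or a list of strings, as OpenAI's API does. A sketch of the multiple-inputs call via the pre-1.0 openai Python client (server URL and model are assumptions):

    # Sketch of a multi-input embeddings request with the pre-1.0 openai client.
    import openai

    openai.api_key = 'nope'                       # the test server needs no real key
    openai.api_base = 'http://localhost:8080/v1'  # assumed local server
    result = openai.Embedding.create(
        model='tinyllama-2',
        input=['In which country Paris is located ?',
               'Is Madrid the capital of Spain ?'],
    )
    assert result['object'] == 'list'
    assert len(result['data']) == 2  # one embedding per input
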
index 50f2b641e764e7ae1e7639c0bfa375e3c99c1909..9c825fdbcd7f52d0b2fc8990dd6ea46e192afe64 100644 (file)
@@ -1,4 +1,5 @@
 import asyncio
+import collections
 import json
 import os
 import re
@@ -261,35 +262,35 @@ def step_a_prompt_prompt(context, prompt):
 @step(u'concurrent completion requests')
 @async_run_until_complete()
 async def step_concurrent_completion_requests(context):
-    await concurrent_completion_requests(context,
-                                         request_completion,
-                                         # prompt is inserted automatically
-                                         context.base_url,
-                                         debug=context.debug,
-                                         n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
-                                         server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
-                                         user_api_key=context.user_api_key if hasattr(context,
-                                                                                      'user_api_key') else None)
+    await concurrent_requests(context,
+                              request_completion,
+                              # prompt is inserted automatically
+                              context.base_url,
+                              debug=context.debug,
+                              n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
+                              server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
+                              user_api_key=context.user_api_key if hasattr(context,
+                                                                           'user_api_key') else None)
 
 
 @step(u'concurrent OAI completions requests')
 @async_run_until_complete
 async def step_oai_chat_completions(context):
-    await concurrent_completion_requests(context, oai_chat_completions,
-                                         # user_prompt is inserted automatically
-                                         context.system_prompt,
-                                         context.base_url,
-                                         True,  # async_client
-                                         model=context.model
-                                         if hasattr(context, 'model') else None,
-                                         n_predict=context.n_predict
-                                         if hasattr(context, 'n_predict') else None,
-                                         enable_streaming=context.enable_streaming
-                                         if hasattr(context, 'enable_streaming') else None,
-                                         server_seed=context.server_seed
-                                         if hasattr(context, 'server_seed') else None,
-                                         user_api_key=context.user_api_key
-                                         if hasattr(context, 'user_api_key') else None)
+    await concurrent_requests(context, oai_chat_completions,
+                              # user_prompt is inserted automatically
+                              context.system_prompt,
+                              context.base_url,
+                              True,  # async_client
+                              model=context.model
+                              if hasattr(context, 'model') else None,
+                              n_predict=context.n_predict
+                              if hasattr(context, 'n_predict') else None,
+                              enable_streaming=context.enable_streaming
+                              if hasattr(context, 'enable_streaming') else None,
+                              server_seed=context.server_seed
+                              if hasattr(context, 'server_seed') else None,
+                              user_api_key=context.user_api_key
+                              if hasattr(context, 'user_api_key') else None)
 
 
 @step(u'all prompts are predicted')
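
Both steps above read optional attributes off the behave context with repeated "context.x if hasattr(context, 'x') else None" expressions. A more compact equivalent, shown only as a possible simplification, is getattr with a default:

    # Equivalent, shorter way to read optional behave context attributes.
    n_predict = getattr(context, 'n_predict', None)
    server_seed = getattr(context, 'server_seed', None)
    user_api_key = getattr(context, 'user_api_key', None)
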
@@ -316,36 +317,58 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
 @step(u'embeddings are computed for')
 @async_run_until_complete
 async def step_compute_embedding(context):
-    content = context.text
-    base_url = context.base_url
-    context.embeddings = await request_embedding(content, base_url)
+    context.embeddings = await request_embedding(context.text, base_url=context.base_url)
 
 
 @step(u'embeddings are generated')
 def step_assert_embeddings(context):
-    assert_embeddings(context.embeddings)
+    if len(context.prompts) == 0:
+        assert_embeddings(context.embeddings)
+    else:
+        assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n"
+                                                                 f"context.prompts={context.prompts}\n"
+                                                                 f"context.embeddings={context.embeddings}")
+        for embedding in context.embeddings:
+            context.prompts.pop()
+            assert_embeddings(embedding)
 
 
 @step(u'an OAI compatible embeddings computation request for')
-def step_oai_compute_embedding(context):
-    openai.api_key = 'nope'  # openai client always expects an api_keu
-    if context.user_api_key is not None:
-        openai.api_key = context.user_api_key
-    openai.api_base = f'{context.base_url}/v1'
-    embeddings = openai.Embedding.create(
-        model=context.model,
-        input=context.text,
-    )
-    context.embeddings = embeddings
+@async_run_until_complete
+async def step_oai_compute_embeddings(context):
+    context.embeddings = await request_oai_embeddings(context.text,
+                                                      base_url=context.base_url,
+                                                      user_api_key=context.user_api_key,
+                                                      model=context.model)
+
+
+@step(u'an OAI compatible embeddings computation request for multiple inputs')
+@async_run_until_complete
+async def step_oai_compute_embeddings_multiple_inputs(context):
+    context.embeddings = await request_oai_embeddings(context.prompts,
+                                                      base_url=context.base_url,
+                                                      user_api_key=context.user_api_key,
+                                                      model=context.model)
 
 
 @step(u'concurrent embedding requests')
 @async_run_until_complete()
 async def step_concurrent_embedding_requests(context):
-    await concurrent_completion_requests(context,
-                                         request_embedding,
-                                         # prompt is inserted automatically
-                                         context.base_url)
+    await concurrent_requests(context,
+                              request_embedding,
+                              # prompt is inserted automatically
+                              base_url=context.base_url)
+
+
+@step(u'concurrent OAI embedding requests')
+@async_run_until_complete()
+async def step_concurrent_oai_embedding_requests(context):
+    await concurrent_requests(context,
+                              request_oai_embeddings,
+                              # prompt is inserted automatically
+                              base_url=context.base_url,
+                              async_client=True,
+                              model=context.model)
 
 
 @step(u'all embeddings are generated')
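
assert_embeddings itself is defined outside this diff. A plausible minimal shape, an assumption rather than the committed implementation, would check that a non-trivial vector came back:

    # Assumed shape of assert_embeddings (not shown in this diff).
    def assert_embeddings(embeddings):
        assert len(embeddings) > 0                      # a vector came back
        assert any(value != 0 for value in embeddings)  # and it is not all zeros
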
@@ -401,7 +424,7 @@ def step_check_options_header_value(context, cors_header, cors_header_value):
     assert context.options_response.headers[cors_header] == cors_header_value
 
 
-async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
+async def concurrent_requests(context, f_completion, *args, **kwargs):
     n_prompts = len(context.prompts)
     if context.debug:
         print(f"starting {n_prompts} concurrent completion requests...")
@@ -565,7 +588,7 @@ async def oai_chat_completions(user_prompt,
     return completion_response
 
 
-async def request_embedding(content, base_url):
+async def request_embedding(content, base_url=None):
     async with aiohttp.ClientSession() as session:
         async with session.post(f'{base_url}/embedding',
                                 json={
@@ -576,6 +599,46 @@ async def request_embedding(content, base_url):
             return response_json['embedding']
 
 
+async def request_oai_embeddings(input,
+                                 base_url=None, user_api_key=None,
+                                 model=None, async_client=False):
+    # openai client always expects an api_key
+    user_api_key = user_api_key if user_api_key is not None else 'nope'
+    if async_client:
+        origin = 'llama.cpp'
+        if user_api_key is not None:
+            headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f'{base_url}/v1/embeddings',
+                                    json={
+                                        "input": input,
+                                        "model": model,
+                                    },
+                                    headers=headers) as response:
+                assert response.status == 200, f"received status code not expected: {response.status}"
+                assert response.headers['Access-Control-Allow-Origin'] == origin
+                assert response.headers['Content-Type'] == "application/json; charset=utf-8"
+                response_json = await response.json()
+                assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
+                assert response_json['object'] == 'list'
+                return response_json['data']
+    else:
+        openai.api_key = user_api_key
+        openai.api_base = f'{base_url}/v1'
+        oai_embeddings = openai.Embedding.create(
+            model=model,
+            input=input,
+        )
+
+        if isinstance(input, collections.abc.Sequence):
+            embeddings = []
+            for an_oai_embeddings in oai_embeddings.data:
+                embeddings.append(an_oai_embeddings.embedding)
+        else:
+            embeddings = oai_embeddings.data.embedding
+        return embeddings
+
+
 def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
     content = completion_response['content']
     n_predicted = completion_response['timings']['predicted_n']
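
Note the asymmetry in request_oai_embeddings above: the raw-aiohttp branch returns the 'data' objects from the JSON response, while the openai-client branch unwraps them to bare embedding vectors. A brief usage sketch (base_url and model are placeholders), to be run inside a coroutine:

    # Usage sketch; base_url and model are placeholders.
    embeddings = await request_oai_embeddings(
        ['first input', 'second input'],
        base_url='http://localhost:8080',
        model='tinyllama-2',
        async_client=True,  # go straight to /v1/embeddings with aiohttp
    )
    assert len(embeddings) == 2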