]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : (embeddings) using same format for "input" and "content" (#10872)
authorXuan Son Nguyen <redacted>
Wed, 18 Dec 2024 08:55:09 +0000 (09:55 +0100)
committerGitHub <redacted>
Wed, 18 Dec 2024 08:55:09 +0000 (10:55 +0200)
* server : (embeddings) using same format for "input" and "content"

* fix test case

* handle empty input case

* fix test

examples/server/server.cpp
examples/server/tests/unit/test_embedding.py
examples/server/utils.hpp

index 436170a034fde4b215a0527673c8fb2dbe6d0b74..71566b94e61bb050b9ffb0f486f6001bea682599 100644 (file)
@@ -3651,25 +3651,33 @@ int main(int argc, char ** argv) {
         const json body = json::parse(req.body);
         bool oaicompat = false;
 
-        // an input prompt can be a string or a list of tokens (integer)
+        // for the shape of input/content, see tokenize_input_prompts()
         json prompt;
-        if (body.count("input") != 0) {
+        if (body.contains("input")) {
             oaicompat = true;
             prompt = body.at("input");
-        } else if (body.count("content") != 0) {
-            // with "content", we only support single prompt
-            prompt = std::vector<std::string>{body.at("content")};
+        } else if (body.contains("content")) {
+            oaicompat = false;
+            prompt = body.at("content");
         } else {
             res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
 
+        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
+        for (const auto & tokens : tokenized_prompts) {
+            // this check is necessary for models that do not add BOS token to the input
+            if (tokens.empty()) {
+                res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        }
+
         // create and queue the task
         json responses = json::array();
         bool error = false;
         {
             std::vector<server_task> tasks;
-            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                 server_task task   = server_task(SERVER_TASK_TYPE_EMBEDDING);
                 task.id            = ctx_server.queue_tasks.get_new_id();
index fea1d6510c89e7c3dcbb6de65af7aae57dd4baf9..4f4e9dcf087fa9262ca82ecbd6a539e85867035f 100644 (file)
@@ -45,6 +45,35 @@ def test_embedding_multiple():
         assert len(d['embedding']) > 1
 
 
+@pytest.mark.parametrize(
+    "content,is_multi_prompt",
+    [
+        # single prompt
+        ("string", False),
+        ([12, 34, 56], False),
+        ([12, 34, "string", 56, 78], False),
+        # multiple prompts
+        (["string1", "string2"], True),
+        (["string1", [12, 34, 56]], True),
+        ([[12, 34, 56], [12, 34, 56]], True),
+        ([[12, 34, 56], [12, "string", 34, 56]], True),
+    ]
+)
+def test_embedding_mixed_input(content, is_multi_prompt: bool):
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={"content": content})
+    assert res.status_code == 200
+    if is_multi_prompt:
+        assert len(res.body) == len(content)
+        for d in res.body:
+            assert 'embedding' in d
+            assert len(d['embedding']) > 1
+    else:
+        assert 'embedding' in res.body
+        assert len(res.body['embedding']) > 1
+
+
 def test_embedding_openai_library_single():
     global server
     server.start()
@@ -102,8 +131,8 @@ def test_same_prompt_give_same_result():
 @pytest.mark.parametrize(
     "content,n_tokens",
     [
-        ("I believe the meaning of life is", 7),
-        ("This is a test", 4),
+        ("I believe the meaning of life is", 9),
+        ("This is a test", 6),
     ]
 )
 def test_embedding_usage_single(content, n_tokens):
@@ -126,4 +155,4 @@ def test_embedding_usage_multiple():
     })
     assert res.status_code == 200
     assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
-    assert res.body['usage']['prompt_tokens'] == 2 * 7
+    assert res.body['usage']['prompt_tokens'] == 2 * 9
index 8fffe484aec12398e07735e6486ce86e1cebbc35..ffdffe904308f98d11024f60a3f78376de55fdf4 100644 (file)
@@ -138,6 +138,7 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
  * and multiple prompts (multi-tasks):
  * - "prompt": ["string1", "string2"]
  * - "prompt": ["string1", [12, 34, 56]]
+ * - "prompt": [[12, 34, 56], [78, 90, 12]]
  * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
  */
 static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {