]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : output embeddings for all tokens when pooling = none (#10861)
authorGeorgi Gerganov <redacted>
Wed, 18 Dec 2024 11:01:41 +0000 (13:01 +0200)
committerGitHub <redacted>
Wed, 18 Dec 2024 11:01:41 +0000 (13:01 +0200)
* server : add "tokens" output

ggml-ci

* server : output embeddings for all tokens when pooling = none

ggml-ci

* server : update readme [no ci]

* server : fix spacing [no ci]

Co-authored-by: Xuan Son Nguyen <redacted>
* server : be explicit about the pooling type in the tests

ggml-ci

* server : update /embeddings and /v1/embeddings endpoints

ggml-ci

* server : do not normalize embeddings when there is no pooling

ggml-ci

* server : update readme

ggml-ci

* server : fixes

* tests : update server tests

ggml-ci

* server : update readme [no ci]

* server : remove rebase artifact

---------

Co-authored-by: Xuan Son Nguyen <redacted>
common/common.cpp
common/common.h
examples/gritlm/gritlm.cpp
examples/retrieval/retrieval.cpp
examples/server/README.md
examples/server/server.cpp
examples/server/tests/unit/test_embedding.py
examples/server/tests/utils.py

index c0c98232ed3bba1720ec76957f3e3b44f9d9b8e9..05d3ba766e38bf2490f00ad65cce157af25265d1 100644 (file)
@@ -1780,7 +1780,9 @@ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm)
             break;
         case 0: // max absolute
             for (int i = 0; i < n; i++) {
-                if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
+                if (sum < std::abs(inp[i])) {
+                    sum = std::abs(inp[i]);
+                }
             }
             sum /= 32760.0; // make an int16 range
             break;
index 5f556c24d933c61e8666207a5ca9c02389be2158..ec0e49f6f1806e3a6da5008a689a850791e73136 100644 (file)
@@ -596,7 +596,8 @@ void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_si
 // Embedding utils
 //
 
-void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
+// TODO: repace embd_norm with an enum
+void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);
 
 float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);
 
index 6e42fa0734ecbe6fa3b7c84b0d7b4d5a0f64c4cf..18a945b33905fa4fc86e1b91d4cc80d373902bac 100644 (file)
@@ -75,7 +75,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
         }
 
         std::vector<float> emb_norm(emb_unorm.size());
-        common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
+        common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd, 2);
         result.push_back(emb_norm);
 
 #ifdef GRIT_DEBUG
index 23ff4db27a4201c4fcc5bd10597df7d822957645..a5c6fe7e58523ecf023b9c648108493648937e5b 100644 (file)
@@ -107,7 +107,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         }
 
         float * out = output + batch.seq_id[i][0] * n_embd;
-        common_embd_normalize(embd, out, n_embd);
+        common_embd_normalize(embd, out, n_embd, 2);
     }
 }
 
index ecd24c899fc86dfaf5932e99eee6e0733fd0afd3..d006a8d37cf6fe02cf9c594d83f591ddd906e595 100644 (file)
@@ -763,6 +763,8 @@ curl http://localhost:8080/v1/chat/completions \
 
 ### POST `/v1/embeddings`: OpenAI-compatible embeddings API
 
+This endpoint requires that the model uses a pooling different than type `none`. The embeddings are normalized using the Eucledian norm.
+
 *Options:*
 
 See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-reference/embeddings).
@@ -795,6 +797,46 @@ See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-r
   }'
   ```
 
+### POST `/embeddings`: non-OpenAI-compatible embeddings API
+
+This endpoint supports all poolings, including `--pooling none`. When the pooling is `none`, the responses will contain the *unnormalized* embeddings for *all* input tokens. For all other pooling types, only the pooled embeddings are returned, normalized using Euclidian norm.
+
+Note that the response format of this endpoint is different from `/v1/embeddings`.
+
+*Options:*
+
+Same as the `/v1/embeddings` endpoint.
+
+*Examples:*
+
+Same as the `/v1/embeddings` endpoint.
+
+**Response format**
+
+```json
+[
+  {
+    "index": 0,
+    "embedding": [
+      [ ... embeddings for token 0   ... ],
+      [ ... embeddings for token 1   ... ],
+      [ ... ]
+      [ ... embeddings for token N-1 ... ],
+    ]
+  },
+  ...
+  {
+    "index": P,
+    "embedding": [
+      [ ... embeddings for token 0   ... ],
+      [ ... embeddings for token 1   ... ],
+      [ ... ]
+      [ ... embeddings for token N-1 ... ],
+    ]
+  }
+]
+```
+
 ### GET `/slots`: Returns the current slots processing state
 
 > [!WARNING]
index 40aac33f0bf135f2b1871af5cf1a1e9bd812ee7f..5ed4e8d2744285c39272a2c7ec6621470e67c986 100644 (file)
@@ -726,18 +726,32 @@ struct server_task_result_cmpl_partial : server_task_result {
 
 struct server_task_result_embd : server_task_result {
     int index = 0;
-    std::vector<float> embedding;
+    std::vector<std::vector<float>> embedding;
 
     int32_t n_tokens;
 
+    // OAI-compat fields
+    bool oaicompat = false;
+
     virtual int get_index() override {
         return index;
     }
 
     virtual json to_json() override {
+        return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
+    }
+
+    json to_json_non_oaicompat() {
+        return json {
+            {"index",     index},
+            {"embedding", embedding},
+        };
+    }
+
+    json to_json_oaicompat() {
         return json {
             {"index",            index},
-            {"embedding",        embedding},
+            {"embedding",        embedding[0]},
             {"tokens_evaluated", n_tokens},
         };
     }
@@ -2017,9 +2031,10 @@ struct server_context {
 
     void send_embedding(const server_slot & slot, const llama_batch & batch) {
         auto res = std::make_unique<server_task_result_embd>();
-        res->id    = slot.id_task;
-        res->index = slot.index;
-        res->n_tokens = slot.n_prompt_tokens;
+        res->id        = slot.id_task;
+        res->index     = slot.index;
+        res->n_tokens  = slot.n_prompt_tokens;
+        res->oaicompat = slot.params.oaicompat;
 
         const int n_embd = llama_n_embd(model);
 
@@ -2038,12 +2053,18 @@ struct server_context {
             if (embd == NULL) {
                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
 
-                res->embedding = std::vector<float>(n_embd, 0.0f);
+                res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
                 continue;
             }
 
-            common_embd_normalize(embd, embd_res.data(), n_embd);
-            res->embedding = embd_res;
+            // normalize only when there is pooling
+            // TODO: configurable
+            if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
+                common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+                res->embedding.push_back(embd_res);
+            } else {
+                res->embedding.push_back({ embd, embd + n_embd });
+            }
         }
 
         SLT_DBG(slot, "%s", "sending embeddings\n");
@@ -2657,7 +2678,10 @@ struct server_context {
 
                     // add prompt tokens for processing in the current batch
                     while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
-                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
+                        // without pooling, we want to output the embeddings for all the tokens in the batch
+                        const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
+
+                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
 
                         if (slot.params.cache_prompt) {
                             slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3665,14 +3689,17 @@ int main(int argc, char ** argv) {
         res_ok(res, data);
     };
 
-    const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
         const json body = json::parse(req.body);
-        bool oaicompat = false;
+
+        if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+            res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
 
         // for the shape of input/content, see tokenize_input_prompts()
         json prompt;
-        if (body.contains("input")) {
-            oaicompat = true;
+        if (body.count("input") != 0) {
             prompt = body.at("input");
         } else if (body.contains("content")) {
             oaicompat = false;
@@ -3697,10 +3724,15 @@ int main(int argc, char ** argv) {
         {
             std::vector<server_task> tasks;
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
-                server_task task   = server_task(SERVER_TASK_TYPE_EMBEDDING);
+                server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
+
                 task.id            = ctx_server.queue_tasks.get_new_id();
                 task.index         = i;
                 task.prompt_tokens = std::move(tokenized_prompts[i]);
+
+                // OAI-compat
+                task.params.oaicompat = oaicompat;
+
                 tasks.push_back(task);
             }
 
@@ -3728,12 +3760,18 @@ int main(int argc, char ** argv) {
         }
 
         // write JSON response
-        json root = oaicompat
-            ? format_embeddings_response_oaicompat(body, responses)
-            : responses.size() == 1 ? responses[0] : json(responses);
+        json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
         res_ok(res, root);
     };
 
+    const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, false);
+    };
+
+    const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, true);
+    };
+
     const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
         if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
@@ -3907,7 +3945,7 @@ int main(int argc, char ** argv) {
     svr->Post("/infill",              handle_infill);
     svr->Post("/embedding",           handle_embeddings); // legacy
     svr->Post("/embeddings",          handle_embeddings);
-    svr->Post("/v1/embeddings",       handle_embeddings);
+    svr->Post("/v1/embeddings",       handle_embeddings_oai);
     svr->Post("/rerank",              handle_rerank);
     svr->Post("/reranking",           handle_rerank);
     svr->Post("/v1/rerank",           handle_rerank);
index 4f4e9dcf087fa9262ca82ecbd6a539e85867035f..e32d7458296051fe4be0ade439167232f82e10a1 100644 (file)
@@ -14,8 +14,9 @@ def create_server():
 
 def test_embedding_single():
     global server
+    server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": "I believe the meaning of life is",
     })
     assert res.status_code == 200
@@ -29,8 +30,9 @@ def test_embedding_single():
 
 def test_embedding_multiple():
     global server
+    server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": [
             "I believe the meaning of life is",
             "Write a joke about AI from a very long prompt which will not be truncated",
@@ -46,7 +48,7 @@ def test_embedding_multiple():
 
 
 @pytest.mark.parametrize(
-    "content,is_multi_prompt",
+    "input,is_multi_prompt",
     [
         # single prompt
         ("string", False),
@@ -59,25 +61,55 @@ def test_embedding_multiple():
         ([[12, 34, 56], [12, "string", 34, 56]], True),
     ]
 )
-def test_embedding_mixed_input(content, is_multi_prompt: bool):
+def test_embedding_mixed_input(input, is_multi_prompt: bool):
     global server
     server.start()
-    res = server.make_request("POST", "/embeddings", data={"content": content})
+    res = server.make_request("POST", "/v1/embeddings", data={"input": input})
     assert res.status_code == 200
+    data = res.body['data']
     if is_multi_prompt:
-        assert len(res.body) == len(content)
-        for d in res.body:
+        assert len(data) == len(input)
+        for d in data:
             assert 'embedding' in d
             assert len(d['embedding']) > 1
     else:
-        assert 'embedding' in res.body
-        assert len(res.body['embedding']) > 1
+        assert 'embedding' in data[0]
+        assert len(data[0]['embedding']) > 1
+
+
+def test_embedding_pooling_none():
+    global server
+    server.pooling = 'none'
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": "hello hello hello",
+    })
+    assert res.status_code == 200
+    assert 'embedding' in res.body[0]
+    assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special
+
+    # make sure embedding vector is not normalized
+    for x in res.body[0]['embedding']:
+        assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON
+
+
+def test_embedding_pooling_none_oai():
+    global server
+    server.pooling = 'none'
+    server.start()
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": "hello hello hello",
+    })
+
+    # /v1/embeddings does not support pooling type 'none'
+    assert res.status_code == 400
 
 
 def test_embedding_openai_library_single():
     global server
+    server.pooling = 'last'
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
     assert len(res.data) == 1
     assert len(res.data[0].embedding) > 1
@@ -85,8 +117,9 @@ def test_embedding_openai_library_single():
 
 def test_embedding_openai_library_multiple():
     global server
+    server.pooling = 'last'
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.embeddings.create(model="text-embedding-3-small", input=[
         "I believe the meaning of life is",
         "Write a joke about AI from a very long prompt which will not be truncated",
@@ -100,8 +133,9 @@ def test_embedding_openai_library_multiple():
 
 def test_embedding_error_prompt_too_long():
     global server
+    server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": "This is a test " * 512,
     })
     assert res.status_code != 200
@@ -109,8 +143,9 @@ def test_embedding_error_prompt_too_long():
 
 
 def test_same_prompt_give_same_result():
+    server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": [
             "I believe the meaning of life is",
             "I believe the meaning of life is",
@@ -138,7 +173,7 @@ def test_same_prompt_give_same_result():
 def test_embedding_usage_single(content, n_tokens):
     global server
     server.start()
-    res = server.make_request("POST", "/embeddings", data={"input": content})
+    res = server.make_request("POST", "/v1/embeddings", data={"input": content})
     assert res.status_code == 200
     assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
     assert res.body['usage']['prompt_tokens'] == n_tokens
@@ -147,7 +182,7 @@ def test_embedding_usage_single(content, n_tokens):
 def test_embedding_usage_multiple():
     global server
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": [
             "I believe the meaning of life is",
             "I believe the meaning of life is",
index d988ccf5e3061d7989a4dec893bbcc1ef8d487f7..277125e88b53421466e543712e8d961acd45825b 100644 (file)
@@ -65,6 +65,7 @@ class ServerProcess:
     server_reranking: bool | None = False
     server_metrics: bool | None = False
     server_slots: bool | None = False
+    pooling: str | None = None
     draft: int | None = None
     api_key: str | None = None
     response_format: str | None = None
@@ -132,6 +133,8 @@ class ServerProcess:
             server_args.append("--metrics")
         if self.server_slots:
             server_args.append("--slots")
+        if self.pooling:
+            server_args.extend(["--pooling", self.pooling])
         if self.model_alias:
             server_args.extend(["--alias", self.model_alias])
         if self.n_ctx: