]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : fill usage info in embeddings and rerank responses (#10852)
authorkrystiancha <redacted>
Tue, 17 Dec 2024 16:00:24 +0000 (16:00 +0000)
committerGitHub <redacted>
Tue, 17 Dec 2024 16:00:24 +0000 (18:00 +0200)
* server : fill usage info in embeddings response

* server : fill usage info in reranking response

examples/server/server.cpp
examples/server/tests/unit/test_embedding.py
examples/server/tests/unit/test_rerank.py
examples/server/utils.hpp

index bc0d042ae924735c6e620ad86d3993813bc5b5a8..436170a034fde4b215a0527673c8fb2dbe6d0b74 100644 (file)
@@ -719,14 +719,17 @@ struct server_task_result_embd : server_task_result {
     int index = 0;
     std::vector<float> embedding;
 
+    int32_t n_tokens;
+
     virtual int get_index() override {
         return index;
     }
 
     virtual json to_json() override {
         return json {
-            {"index",     index},
-            {"embedding", embedding},
+            {"index",            index},
+            {"embedding",        embedding},
+            {"tokens_evaluated", n_tokens},
         };
     }
 };
@@ -735,14 +738,17 @@ struct server_task_result_rerank : server_task_result {
     int index = 0;
     float score = -1e6;
 
+    int32_t n_tokens;
+
     virtual int get_index() override {
         return index;
     }
 
     virtual json to_json() override {
         return json {
-            {"index", index},
-            {"score", score},
+            {"index",            index},
+            {"score",            score},
+            {"tokens_evaluated", n_tokens},
         };
     }
 };
@@ -1995,6 +2001,7 @@ struct server_context {
         auto res = std::make_unique<server_task_result_embd>();
         res->id    = slot.id_task;
         res->index = slot.index;
+        res->n_tokens = slot.n_prompt_tokens;
 
         const int n_embd = llama_n_embd(model);
 
@@ -2030,6 +2037,7 @@ struct server_context {
         auto res = std::make_unique<server_task_result_rerank>();
         res->id    = slot.id_task;
         res->index = slot.index;
+        res->n_tokens = slot.n_prompt_tokens;
 
         for (int i = 0; i < batch.n_tokens; ++i) {
             if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
index fc7c20064ddfcee5e02ca91ccf457dca6cdccc6f..fea1d6510c89e7c3dcbb6de65af7aae57dd4baf9 100644 (file)
@@ -97,3 +97,33 @@ def test_same_prompt_give_same_result():
         vi = res.body['data'][i]['embedding']
         for x, y in zip(v0, vi):
             assert abs(x - y) < EPSILON
+
+
+@pytest.mark.parametrize(
+    "content,n_tokens",
+    [
+        ("I believe the meaning of life is", 7),
+        ("This is a test", 4),
+    ]
+)
+def test_embedding_usage_single(content, n_tokens):
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={"input": content})
+    assert res.status_code == 200
+    assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
+    assert res.body['usage']['prompt_tokens'] == n_tokens
+
+
+def test_embedding_usage_multiple():
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": [
+            "I believe the meaning of life is",
+            "I believe the meaning of life is",
+        ],
+    })
+    assert res.status_code == 200
+    assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
+    assert res.body['usage']['prompt_tokens'] == 2 * 7
index 189bc4c9623298500b14a57bf15e670a9a2dd2fd..7203d79435702c08840500f1917eebe359b86b81 100644 (file)
@@ -53,3 +53,26 @@ def test_invalid_rerank_req(documents):
     })
     assert res.status_code == 400
     assert "error" in res.body
+
+
+@pytest.mark.parametrize(
+    "query,doc1,doc2,n_tokens",
+    [
+        ("Machine learning is", "A machine", "Learning is", 19),
+        ("Which city?", "Machine learning is ", "Paris, capitale de la", 26),
+    ]
+)
+def test_rerank_usage(query, doc1, doc2, n_tokens):
+    global server
+    server.start()
+
+    res = server.make_request("POST", "/rerank", data={
+        "query": query,
+        "documents": [
+            doc1,
+            doc2,
+        ]
+    })
+    assert res.status_code == 200
+    assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
+    assert res.body['usage']['prompt_tokens'] == n_tokens
index c6f08bf21071aacdaddae5b79c7ad6e68f558b55..8fffe484aec12398e07735e6486ce86e1cebbc35 100644 (file)
@@ -560,6 +560,7 @@ static json oaicompat_completion_params_parse(
 
 static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
     json data = json::array();
+    int32_t n_tokens = 0;
     int i = 0;
     for (const auto & elem : embeddings) {
         data.push_back(json{
@@ -567,14 +568,16 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
             {"index",     i++},
             {"object",    "embedding"}
         });
+
+        n_tokens += json_value(elem, "tokens_evaluated", 0);
     }
 
     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
-        {"usage", json { // TODO: fill
-            {"prompt_tokens", 0},
-            {"total_tokens", 0}
+        {"usage", json {
+            {"prompt_tokens", n_tokens},
+            {"total_tokens", n_tokens}
         }},
         {"data", data}
     };
@@ -584,20 +587,23 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
 
 static json format_response_rerank(const json & request, const json & ranks) {
     json data = json::array();
+    int32_t n_tokens = 0;
     int i = 0;
     for (const auto & rank : ranks) {
         data.push_back(json{
             {"index",    i++},
             {"relevance_score", json_value(rank, "score", 0.0)},
         });
+
+        n_tokens += json_value(rank, "tokens_evaluated", 0);
     }
 
     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
-        {"usage", json { // TODO: fill
-            {"prompt_tokens", 0},
-            {"total_tokens", 0}
+        {"usage", json {
+            {"prompt_tokens", n_tokens},
+            {"total_tokens", n_tokens}
         }},
         {"results", data}
     };