]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server / ranking : add sorting and management of top_n (#16403)
authorYann Follet <redacted>
Sat, 11 Oct 2025 13:39:04 +0000 (21:39 +0800)
committerGitHub <redacted>
Sat, 11 Oct 2025 13:39:04 +0000 (16:39 +0300)
* server / ranking : add sorting and management of top_n

* Make the retro compatible if no top_n will return
all results

here is a script to make some test

```script

URL=${1:-http://127.0.0.1:8181}

curl "$URL/v1/rerank" -H "Content-Type: application/json" \
 -d '{ "model": "M", "query": "What is the recipe to make bread ?",
 "return_text" : true,
 "texts" : true,
 "top_n": 6,
 "documents": [
 "voici la recette pour faire du pain, il faut de la farine de l eau et du levain et du sel",
 "it is a bear",
 "bread recipe : floor, water, yest, salt",
 "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.",
 "here is the ingedients to bake bread : 500g floor, 350g water, 120g fresh refresh yest, 15g salt",
 "recipe to make cookies : floor, eggs, water, chocolat",
 "here is the recipe to make bread : 500g floor, 350g water, 120g fresh refresh yest, 15g salt",
 "il fait tres beau aujourd hui",
 "je n ai pas faim, je ne veux pas manger",
 "je suis a paris"
 ] }' | jq
```

* use resize() instead for(...)

* simplify top_n init since no need to return error

result to test :

./tests.sh unit/test_rerank.py -v -x
==================================================== test session starts =====================================================
platform linux -- Python 3.12.3, pytest-8.3.5, pluggy-1.6.0 -- /home/yann/dev/yann/llama.cpp/tools/server/tests/test/bin/python3
cachedir: .pytest_cache
rootdir: /home/yann/dev/yann/llama.cpp/tools/server/tests
configfile: pytest.ini
plugins: anyio-4.11.0
collected 8 items

unit/test_rerank.py::test_rerank PASSED                                                                                [ 12%]
unit/test_rerank.py::test_rerank_tei_format PASSED                                                                     [ 25%]
unit/test_rerank.py::test_invalid_rerank_req[documents0] PASSED                                                        [ 37%]
unit/test_rerank.py::test_invalid_rerank_req[None] PASSED                                                              [ 50%]
unit/test_rerank.py::test_invalid_rerank_req[123] PASSED                                                               [ 62%]
unit/test_rerank.py::test_invalid_rerank_req[documents3] PASSED                                                        [ 75%]
unit/test_rerank.py::test_rerank_usage[Machine learning is-A machine-Learning is-19] PASSED                            [ 87%]
unit/test_rerank.py::test_rerank_usage[Which city?-Machine learning is -Paris, capitale de la-26] PASSED               [100%]

===================================================== 8 passed in 4.31s ======================================================

* add rerank top_n unit test

here is the result :

./tests.sh unit/test_rerank.py -v -x
=================================================================== test session starts ===================================================================
platform linux -- Python 3.12.3, pytest-8.3.5, pluggy-1.6.0 -- /home/yann/dev/yann/llama.cpp/tools/server/tests/test/bin/python3
cachedir: .pytest_cache
rootdir: /home/yann/dev/yann/llama.cpp/tools/server/tests
configfile: pytest.ini
plugins: anyio-4.11.0
collected 16 items

unit/test_rerank.py::test_rerank PASSED                                                                                                             [  6%]
unit/test_rerank.py::test_rerank_tei_format PASSED                                                                                                  [ 12%]
unit/test_rerank.py::test_invalid_rerank_req[documents0] PASSED                                                                                     [ 18%]
unit/test_rerank.py::test_invalid_rerank_req[None] PASSED                                                                                           [ 25%]
unit/test_rerank.py::test_invalid_rerank_req[123] PASSED                                                                                            [ 31%]
unit/test_rerank.py::test_invalid_rerank_req[documents3] PASSED                                                                                     [ 37%]
unit/test_rerank.py::test_rerank_usage[Machine learning is-A machine-Learning is-19] PASSED                                                         [ 43%]
unit/test_rerank.py::test_rerank_usage[Which city?-Machine learning is -Paris, capitale de la-26] PASSED                                            [ 50%]
unit/test_rerank.py::test_rerank_top_n[None-4] PASSED                                                                                               [ 56%]
unit/test_rerank.py::test_rerank_top_n[2-2] PASSED                                                                                                  [ 62%]
unit/test_rerank.py::test_rerank_top_n[4-4] PASSED                                                                                                  [ 68%]
unit/test_rerank.py::test_rerank_top_n[99-4] PASSED                                                                                                 [ 75%]
unit/test_rerank.py::test_rerank_tei_top_n[None-4] PASSED                                                                                           [ 81%]
unit/test_rerank.py::test_rerank_tei_top_n[2-2] PASSED                                                                                              [ 87%]
unit/test_rerank.py::test_rerank_tei_top_n[4-4] PASSED                                                                                              [ 93%]
unit/test_rerank.py::test_rerank_tei_top_n[99-4] PASSED                                                                                             [100%]

=================================================================== 16 passed in 8.84s ===================================================================

* editor config check fix

tools/server/server.cpp
tools/server/tests/unit/test_rerank.py
tools/server/utils.hpp

index 60326e8e50efe9dc6360c9499bdaf356a1e3e8b4..cf12805b4998a9193f840748b78c934786d46359 100644 (file)
@@ -5401,15 +5401,6 @@ int main(int argc, char ** argv) {
 
         const json body = json::parse(req.body);
 
-        // TODO: implement
-        //int top_n = 1;
-        //if (body.count("top_n") != 1) {
-        //    top_n = body.at("top_n");
-        //} else {
-        //    res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
-        //    return;
-        //}
-
         // if true, use TEI API format, otherwise use Jina API format
         // Jina: https://jina.ai/reranker/
         // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank
@@ -5434,6 +5425,8 @@ int main(int argc, char ** argv) {
             return;
         }
 
+        int top_n = json_value(body, "top_n", (int)documents.size());
+
         // create and queue the task
         json responses = json::array();
         bool error = false;
@@ -5474,7 +5467,8 @@ int main(int argc, char ** argv) {
             body,
             responses,
             is_tei_format,
-            documents);
+            documents,
+            top_n);
 
         res_ok(res, root);
     };
index 0b63c7821eb98b70d6fe5f252a1432474f9a97d3..ded8267109682492b5bcc820e9f5653116725fbd 100644 (file)
@@ -102,3 +102,45 @@ def test_rerank_usage(query, doc1, doc2, n_tokens):
     assert res.status_code == 200
     assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
     assert res.body['usage']['prompt_tokens'] == n_tokens
+
+
+@pytest.mark.parametrize("top_n,expected_len", [
+    (None, len(TEST_DOCUMENTS)),  # no top_n parameter
+    (2, 2),
+    (4, 4),
+    (99, len(TEST_DOCUMENTS)),    # higher than available docs
+])
+def test_rerank_top_n(top_n, expected_len):
+    global server
+    server.start()
+    data = {
+        "query": "Machine learning is",
+        "documents": TEST_DOCUMENTS,
+    }
+    if top_n is not None:
+        data["top_n"] = top_n
+
+    res = server.make_request("POST", "/rerank", data=data)
+    assert res.status_code == 200
+    assert len(res.body["results"]) == expected_len
+
+
+@pytest.mark.parametrize("top_n,expected_len", [
+    (None, len(TEST_DOCUMENTS)),  # no top_n parameter
+    (2, 2),
+    (4, 4),
+    (99, len(TEST_DOCUMENTS)),    # higher than available docs
+])
+def test_rerank_tei_top_n(top_n, expected_len):
+    global server
+    server.start()
+    data = {
+        "query": "Machine learning is",
+        "texts": TEST_DOCUMENTS,
+    }
+    if top_n is not None:
+        data["top_n"] = top_n
+
+    res = server.make_request("POST", "/rerank", data=data)
+    assert res.status_code == 200
+    assert len(res.body) == expected_len
index f175115f4fd6aaa85f462e471c528fbe6e931e69..fd0bc8de533cf707d30da563078320208f844e91 100644 (file)
@@ -849,47 +849,44 @@ static json format_response_rerank(
         const json & request,
         const json & ranks,
         bool is_tei_format,
-        std::vector<std::string> & texts) {
-    json res;
-    if (is_tei_format) {
-        // TEI response format
-        res = json::array();
-        bool return_text = json_value(request, "return_text", false);
-        for (const auto & rank : ranks) {
-            int index = json_value(rank, "index", 0);
-            json elem = json{
-                {"index", index},
-                {"score", json_value(rank, "score", 0.0)},
-            };
-            if (return_text) {
-                elem["text"] = std::move(texts[index]);
-            }
-            res.push_back(elem);
-        }
-    } else {
-        // Jina response format
-        json results = json::array();
-        int32_t n_tokens = 0;
-        for (const auto & rank : ranks) {
-            results.push_back(json{
-                {"index",           json_value(rank, "index", 0)},
-                {"relevance_score", json_value(rank, "score", 0.0)},
-            });
-
-            n_tokens += json_value(rank, "tokens_evaluated", 0);
-        }
-
-        res = json{
-            {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
-            {"object", "list"},
-            {"usage", json{
-                {"prompt_tokens", n_tokens},
-                {"total_tokens", n_tokens}
-            }},
-            {"results", results}
+        std::vector<std::string> & texts,
+        int top_n) {
+    int32_t n_tokens = 0;
+    bool return_text = is_tei_format && json_value(request, "return_text", false);
+    std::vector<json> elements; // Temporary vector to hold unsorted elements
+    std::string score_label = is_tei_format ? "score" : "relevance_score";
+    for (const auto & rank : ranks) {
+        int index = json_value(rank, "index", 0);
+        json elem = json{
+            {"index", index},
+            {score_label, json_value(rank, "score", 0.0)},
         };
+        n_tokens += json_value(rank, "tokens_evaluated", 0);
+        if (return_text) {
+            elem["text"] = std::move(texts[index]);
+        }
+        elements.push_back(elem);
     }
 
+    std::sort(elements.begin(), elements.end(), [score_label](const json& a, const json& b) {
+        return json_value(a, score_label, 0.0) > json_value(b, score_label, 0.0);
+    });
+
+    elements.resize(std::min(top_n, (int)elements.size()));
+    json results = elements;
+
+    if (is_tei_format) return results;
+
+    json res = json{
+        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+        {"object", "list"},
+        {"usage", json{
+            {"prompt_tokens", n_tokens},
+            {"total_tokens", n_tokens}
+        }},
+        {"results", results}
+    };
+
     return res;
 }