const json body = json::parse(req.body);
- // TODO: implement
- //int top_n = 1;
- //if (body.count("top_n") != 1) {
- // top_n = body.at("top_n");
- //} else {
- // res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
- // return;
- //}
-
// if true, use TEI API format, otherwise use Jina API format
// Jina: https://jina.ai/reranker/
// TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank
return;
}
+ int top_n = json_value(body, "top_n", (int)documents.size());
+
// create and queue the task
json responses = json::array();
bool error = false;
body,
responses,
is_tei_format,
- documents);
+ documents,
+ top_n);
res_ok(res, root);
};
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == n_tokens
+
+
+@pytest.mark.parametrize("top_n,expected_len", [
+ (None, len(TEST_DOCUMENTS)), # no top_n parameter
+ (2, 2),
+ (4, 4),
+ (99, len(TEST_DOCUMENTS)), # higher than available docs
+])
+def test_rerank_top_n(top_n, expected_len):
+ global server
+ server.start()
+ data = {
+ "query": "Machine learning is",
+ "documents": TEST_DOCUMENTS,
+ }
+ if top_n is not None:
+ data["top_n"] = top_n
+
+ res = server.make_request("POST", "/rerank", data=data)
+ assert res.status_code == 200
+ assert len(res.body["results"]) == expected_len
+
+
+@pytest.mark.parametrize("top_n,expected_len", [
+ (None, len(TEST_DOCUMENTS)), # no top_n parameter
+ (2, 2),
+ (4, 4),
+ (99, len(TEST_DOCUMENTS)), # higher than available docs
+])
+def test_rerank_tei_top_n(top_n, expected_len):
+ global server
+ server.start()
+ data = {
+ "query": "Machine learning is",
+ "texts": TEST_DOCUMENTS,
+ }
+ if top_n is not None:
+ data["top_n"] = top_n
+
+ res = server.make_request("POST", "/rerank", data=data)
+ assert res.status_code == 200
+ assert len(res.body) == expected_len
const json & request,
const json & ranks,
bool is_tei_format,
- std::vector<std::string> & texts) {
- json res;
- if (is_tei_format) {
- // TEI response format
- res = json::array();
- bool return_text = json_value(request, "return_text", false);
- for (const auto & rank : ranks) {
- int index = json_value(rank, "index", 0);
- json elem = json{
- {"index", index},
- {"score", json_value(rank, "score", 0.0)},
- };
- if (return_text) {
- elem["text"] = std::move(texts[index]);
- }
- res.push_back(elem);
- }
- } else {
- // Jina response format
- json results = json::array();
- int32_t n_tokens = 0;
- for (const auto & rank : ranks) {
- results.push_back(json{
- {"index", json_value(rank, "index", 0)},
- {"relevance_score", json_value(rank, "score", 0.0)},
- });
-
- n_tokens += json_value(rank, "tokens_evaluated", 0);
- }
-
- res = json{
- {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
- {"object", "list"},
- {"usage", json{
- {"prompt_tokens", n_tokens},
- {"total_tokens", n_tokens}
- }},
- {"results", results}
+ std::vector<std::string> & texts,
+ int top_n) {
+ int32_t n_tokens = 0;
+ bool return_text = is_tei_format && json_value(request, "return_text", false);
+ std::vector<json> elements; // Temporary vector to hold unsorted elements
+ std::string score_label = is_tei_format ? "score" : "relevance_score";
+ for (const auto & rank : ranks) {
+ int index = json_value(rank, "index", 0);
+ json elem = json{
+ {"index", index},
+ {score_label, json_value(rank, "score", 0.0)},
};
+ n_tokens += json_value(rank, "tokens_evaluated", 0);
+ if (return_text) {
+ elem["text"] = std::move(texts[index]);
+ }
+ elements.push_back(elem);
}
+ std::sort(elements.begin(), elements.end(), [score_label](const json& a, const json& b) {
+ return json_value(a, score_label, 0.0) > json_value(b, score_label, 0.0);
+ });
+
+ elements.resize(std::min(top_n, (int)elements.size()));
+ json results = elements;
+
+ if (is_tei_format) return results;
+
+ json res = json{
+ {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+ {"object", "list"},
+ {"usage", json{
+ {"prompt_tokens", n_tokens},
+ {"total_tokens", n_tokens}
+ }},
+ {"results", results}
+ };
+
return res;
}