break;
case 0: // max absolute
for (int i = 0; i < n; i++) {
- if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
+ if (sum < std::abs(inp[i])) {
+ sum = std::abs(inp[i]);
+ }
}
sum /= 32760.0; // make an int16 range
break;
// Embedding utils
//
-void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
+// TODO: replace embd_norm with an enum
+void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);
float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);
}
std::vector<float> emb_norm(emb_unorm.size());
- common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
+ common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd, 2);
result.push_back(emb_norm);
#ifdef GRIT_DEBUG
}
float * out = output + batch.seq_id[i][0] * n_embd;
- common_embd_normalize(embd, out, n_embd);
+ common_embd_normalize(embd, out, n_embd, 2);
}
}
### POST `/v1/embeddings`: OpenAI-compatible embeddings API
+This endpoint requires that the model uses a pooling type other than `none`. The embeddings are normalized using the Euclidean norm.
+
*Options:*
See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-reference/embeddings).
}'
```
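+
+The endpoint can also be used with the official `openai` Python client by pointing it at the server's `/v1` base URL. A minimal sketch, assuming the default host and port (the `model` value is only a placeholder):
+
+```python
+from openai import OpenAI
+
+# point the client at the llama.cpp server instead of api.openai.com
+client = OpenAI(api_key="dummy", base_url="http://localhost:8080/v1")
+
+res = client.embeddings.create(model="text-embedding-3-small",
+                               input="I believe the meaning of life is")
+print(len(res.data[0].embedding))  # dimensionality of the pooled embedding
+```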
+### POST `/embeddings`: non-OpenAI-compatible embeddings API
+
+This endpoint supports all pooling types, including `--pooling none`. When the pooling is `none`, the responses will contain the *unnormalized* embeddings for *all* input tokens. For all other pooling types, only the pooled embeddings are returned, normalized using the Euclidean norm.
+
+Note that the response format of this endpoint is different from `/v1/embeddings`.
+
+*Options:*
+
+Same as the `/v1/embeddings` endpoint.
+
+*Examples:*
+
+Same as the `/v1/embeddings` endpoint.
+
+**Response format**
+
+```json
+[
+ {
+ "index": 0,
+ "embedding": [
+ [ ... embeddings for token 0 ... ],
+ [ ... embeddings for token 1 ... ],
+      ...
+      [ ... embeddings for token N-1 ... ]
+ ]
+ },
+ ...
+ {
+ "index": P,
+ "embedding": [
+ [ ... embeddings for token 0 ... ],
+ [ ... embeddings for token 1 ... ],
+      ...
+      [ ... embeddings for token N-1 ... ]
+ ]
+ }
+]
+```
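+
+For illustration, a request to a server started with `--pooling none` could look like this (default host and port assumed):
+
+```sh
+curl http://localhost:8080/embeddings \
+    -H "Content-Type: application/json" \
+    -d '{"input": "hello hello hello"}'
+```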
+
### GET `/slots`: Returns the current slots processing state
> [!WARNING]
struct server_task_result_embd : server_task_result {
int index = 0;
- std::vector<float> embedding;
+ std::vector<std::vector<float>> embedding;
int32_t n_tokens;
+ // OAI-compat fields
+ bool oaicompat = false;
+
virtual int get_index() override {
return index;
}
virtual json to_json() override {
+ return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
+ }
+
+ json to_json_non_oaicompat() {
+ return json {
+ {"index", index},
+ {"embedding", embedding},
+ };
+ }
+
+ json to_json_oaicompat() {
return json {
{"index", index},
- {"embedding", embedding},
+ {"embedding", embedding[0]},
{"tokens_evaluated", n_tokens},
};
}
void send_embedding(const server_slot & slot, const llama_batch & batch) {
auto res = std::make_unique<server_task_result_embd>();
- res->id = slot.id_task;
- res->index = slot.index;
- res->n_tokens = slot.n_prompt_tokens;
+ res->id = slot.id_task;
+ res->index = slot.index;
+ res->n_tokens = slot.n_prompt_tokens;
+ res->oaicompat = slot.params.oaicompat;
const int n_embd = llama_n_embd(model);
if (embd == NULL) {
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
- res->embedding = std::vector<float>(n_embd, 0.0f);
+ res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
continue;
}
- common_embd_normalize(embd, embd_res.data(), n_embd);
- res->embedding = embd_res;
+ // normalize only when there is pooling
+ // TODO: configurable
+ if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
+ common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+ res->embedding.push_back(embd_res);
+ } else {
+ res->embedding.push_back({ embd, embd + n_embd });
+ }
}
SLT_DBG(slot, "%s", "sending embeddings\n");
// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
- common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
+ // without pooling, we want to output the embeddings for all the tokens in the batch
+ const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
+
+ common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
res_ok(res, data);
};
- const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+ const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
const json body = json::parse(req.body);
- bool oaicompat = false;
+
+ if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+ res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
// for the shape of input/content, see tokenize_input_prompts()
json prompt;
- if (body.contains("input")) {
- oaicompat = true;
+ if (body.count("input") != 0) {
prompt = body.at("input");
} else if (body.contains("content")) {
oaicompat = false;
{
std::vector<server_task> tasks;
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
- server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
+ server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
+
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
task.prompt_tokens = std::move(tokenized_prompts[i]);
+
+ // OAI-compat
+ task.params.oaicompat = oaicompat;
+
tasks.push_back(task);
}
}
// write JSON response
- json root = oaicompat
- ? format_embeddings_response_oaicompat(body, responses)
- : responses.size() == 1 ? responses[0] : json(responses);
+ json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
res_ok(res, root);
};
+ const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+ handle_embeddings_impl(req, res, false);
+ };
+
+ const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+ handle_embeddings_impl(req, res, true);
+ };
+
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
svr->Post("/infill", handle_infill);
svr->Post("/embedding", handle_embeddings); // legacy
svr->Post("/embeddings", handle_embeddings);
- svr->Post("/v1/embeddings", handle_embeddings);
+ svr->Post("/v1/embeddings", handle_embeddings_oai);
svr->Post("/rerank", handle_rerank);
svr->Post("/reranking", handle_rerank);
svr->Post("/v1/rerank", handle_rerank);
def test_embedding_single():
global server
+ server.pooling = 'last'
server.start()
- res = server.make_request("POST", "/embeddings", data={
+ res = server.make_request("POST", "/v1/embeddings", data={
"input": "I believe the meaning of life is",
})
assert res.status_code == 200
def test_embedding_multiple():
global server
+ server.pooling = 'last'
server.start()
- res = server.make_request("POST", "/embeddings", data={
+ res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"Write a joke about AI from a very long prompt which will not be truncated",
@pytest.mark.parametrize(
- "content,is_multi_prompt",
+ "input,is_multi_prompt",
[
# single prompt
("string", False),
([[12, 34, 56], [12, "string", 34, 56]], True),
]
)
-def test_embedding_mixed_input(content, is_multi_prompt: bool):
+def test_embedding_mixed_input(input, is_multi_prompt: bool):
global server
server.start()
- res = server.make_request("POST", "/embeddings", data={"content": content})
+ res = server.make_request("POST", "/v1/embeddings", data={"input": input})
assert res.status_code == 200
+ data = res.body['data']
if is_multi_prompt:
- assert len(res.body) == len(content)
- for d in res.body:
+ assert len(data) == len(input)
+ for d in data:
assert 'embedding' in d
assert len(d['embedding']) > 1
else:
- assert 'embedding' in res.body
- assert len(res.body['embedding']) > 1
+ assert 'embedding' in data[0]
+ assert len(data[0]['embedding']) > 1
+
+
+def test_embedding_pooling_none():
+ global server
+ server.pooling = 'none'
+ server.start()
+ res = server.make_request("POST", "/embeddings", data={
+ "input": "hello hello hello",
+ })
+ assert res.status_code == 200
+ assert 'embedding' in res.body[0]
+ assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special
+
+    # make sure the per-token embedding vectors are not normalized
+    for emb in res.body[0]['embedding']:
+        assert abs(sum([x ** 2 for x in emb]) - 1) > EPSILON
+
+
+def test_embedding_pooling_none_oai():
+ global server
+ server.pooling = 'none'
+ server.start()
+ res = server.make_request("POST", "/v1/embeddings", data={
+ "input": "hello hello hello",
+ })
+
+ # /v1/embeddings does not support pooling type 'none'
+ assert res.status_code == 400
def test_embedding_openai_library_single():
global server
+ server.pooling = 'last'
server.start()
- client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+ client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
assert len(res.data) == 1
assert len(res.data[0].embedding) > 1
def test_embedding_openai_library_multiple():
global server
+ server.pooling = 'last'
server.start()
- client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+ client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.embeddings.create(model="text-embedding-3-small", input=[
"I believe the meaning of life is",
"Write a joke about AI from a very long prompt which will not be truncated",
def test_embedding_error_prompt_too_long():
global server
+ server.pooling = 'last'
server.start()
- res = server.make_request("POST", "/embeddings", data={
+ res = server.make_request("POST", "/v1/embeddings", data={
"input": "This is a test " * 512,
})
assert res.status_code != 200
def test_same_prompt_give_same_result():
+ server.pooling = 'last'
server.start()
- res = server.make_request("POST", "/embeddings", data={
+ res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"I believe the meaning of life is",
def test_embedding_usage_single(content, n_tokens):
global server
server.start()
- res = server.make_request("POST", "/embeddings", data={"input": content})
+ res = server.make_request("POST", "/v1/embeddings", data={"input": content})
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == n_tokens
def test_embedding_usage_multiple():
global server
server.start()
- res = server.make_request("POST", "/embeddings", data={
+ res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"I believe the meaning of life is",
server_reranking: bool | None = False
server_metrics: bool | None = False
server_slots: bool | None = False
+ pooling: str | None = None
draft: int | None = None
api_key: str | None = None
response_format: str | None = None
server_args.append("--metrics")
if self.server_slots:
server_args.append("--slots")
+ if self.pooling:
+ server_args.extend(["--pooling", self.pooling])
if self.model_alias:
server_args.extend(["--alias", self.model_alias])
if self.n_ctx: