const json body = json::parse(req.body);
bool oaicompat = false;
- // an input prompt can be a string or a list of tokens (integer)
+ // for the supported shapes of "input"/"content", see tokenize_input_prompts()
json prompt;
- if (body.count("input") != 0) {
+ if (body.contains("input")) {
oaicompat = true;
prompt = body.at("input");
- } else if (body.count("content") != 0) {
- // with "content", we only support single prompt
- prompt = std::vector<std::string>{body.at("content")};
+ } else if (body.contains("content")) {
+ oaicompat = false;
+ prompt = body.at("content");
} else {
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
return;
}
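+ // tokenize up front so that an empty result can be rejected before any task is queued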
+ std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ true, /* parse_special */ true);
+ for (const auto & tokens : tokenized_prompts) {
+ // this check is necessary for models that do not add a BOS token to the input
+ if (tokens.empty()) {
+ res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+ }
+
// create and queue the task(s)
json responses = json::array();
bool error = false;
{
std::vector<server_task> tasks;
- std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
task.id = ctx_server.queue_tasks.get_new_id();
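With this, "input" and "content" accept the same prompt shapes; "input" additionally switches the response to the OpenAI-compatible format. A minimal sketch of the two request forms, assuming a local server at the default address (the endpoint path and field names are taken from the handler above):

```python
import requests

# assumed local address; the test suite normally starts the server itself
BASE_URL = "http://127.0.0.1:8080"

# OpenAI-compatible field: "input"
r1 = requests.post(f"{BASE_URL}/embeddings", json={"input": "hello world"})

# native field: "content", now accepting the same shapes as "input",
# including a list that mixes strings and token ids
r2 = requests.post(f"{BASE_URL}/embeddings",
                   json={"content": ["hello", [12, 34, 56]]})

for r in (r1, r2):
    r.raise_for_status()
```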
assert len(d['embedding']) > 1
+@pytest.mark.parametrize(
+ "content,is_multi_prompt",
+ [
+ # single prompt
+ ("string", False),
+ ([12, 34, 56], False),
+ ([12, 34, "string", 56, 78], False),
+ # multiple prompts
+ (["string1", "string2"], True),
+ (["string1", [12, 34, 56]], True),
+ ([[12, 34, 56], [12, 34, 56]], True),
+ ([[12, 34, 56], [12, "string", 34, 56]], True),
+ ]
+)
+def test_embedding_mixed_input(content, is_multi_prompt: bool):
+ global server
+ server.start()
+ res = server.make_request("POST", "/embeddings", data={"content": content})
+ assert res.status_code == 200
+ if is_multi_prompt:
+ assert len(res.body) == len(content)
+ for d in res.body:
+ assert 'embedding' in d
+ assert len(d['embedding']) > 1
+ else:
+ assert 'embedding' in res.body
+ assert len(res.body['embedding']) > 1
+
+
def test_embedding_openai_library_single():
global server
server.start()
@pytest.mark.parametrize(
"content,n_tokens",
[
- ("I believe the meaning of life is", 7),
- ("This is a test", 4),
+ ("I believe the meaning of life is", 9),
+ ("This is a test", 6),
]
)
def test_embedding_usage_single(content, n_tokens):
})
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
- assert res.body['usage']['prompt_tokens'] == 2 * 7
+ assert res.body['usage']['prompt_tokens'] == 2 * 9
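The expected counts above grow by two per prompt because the handler now tokenizes with add_special enabled, so the model's special tokens (e.g. BOS/EOS, or CLS/SEP for BERT-style embedding models) are counted in usage.prompt_tokens. A quick cross-check, sketched against the server's /tokenize endpoint (the content/add_special fields are from the server docs; the base URL is an assumption):

```python
import requests

def count_tokens(base_url: str, text: str) -> int:
    # /tokenize exposes the same add_special switch the handler now uses
    res = requests.post(f"{base_url}/tokenize",
                        json={"content": text, "add_special": True})
    res.raise_for_status()
    return len(res.json()["tokens"])

# with the test model, count_tokens(base_url, "This is a test") should
# return 6: the 4 text tokens plus the 2 special tokens
```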
* and multiple prompts (multi-tasks):
* - "prompt": ["string1", "string2"]
* - "prompt": ["string1", [12, 34, 56]]
+ * - "prompt": [[12, 34, 56], [78, 90, 12]]
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
*/
static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
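As a rough illustration of the shapes listed above, a Python sketch of how one mixed entry could be resolved into token ids (a model of the documented behavior, not the C++ implementation; tokenize_text is a hypothetical stand-in for the tokenizer, and special-token handling is omitted). The top-level split of a multi-prompt request into one task per entry happens before this step:

```python
def tokenize_mixed_entry(entry, tokenize_text):
    """Resolve one prompt entry into a flat list of token ids.

    An entry is either a string, or a flat list that mixes strings
    and token ids, e.g. [12, 34, "string", 56, 78].
    """
    if isinstance(entry, str):
        return tokenize_text(entry)
    tokens = []
    for part in entry:
        if isinstance(part, str):
            tokens.extend(tokenize_text(part))
        else:
            tokens.append(part)  # already a token id
    return tokens
```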