From: Georgi Gerganov
Date: Mon, 16 Jun 2025 19:33:27 +0000 (+0300)
Subject: server : fix incorrect usage of llama_get_embeddings() (#14225)
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=89fea80d298184d1cd93564f48e060d9f541f4b4;p=pkg%2Fggml%2Fsources%2Fllama.cpp

server : fix incorrect usage of llama_get_embeddings() (#14225)

* server : fix incorrect usage of llama_get_embeddings()

ggml-ci

* cont : fix the fix

ggml-ci
---

diff --git a/include/llama.h b/include/llama.h
index b086b68e..635508b1 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -965,6 +965,7 @@ extern "C" {
     LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
 
     // Set whether the context outputs embeddings or not
+    // TODO: rename to avoid confusion with llama_get_embeddings()
     LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
 
     // Set whether to use causal attention or not
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index c08e4212..721d0918 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1358,6 +1358,14 @@ struct server_slot {
         return server_task_type_need_logits(task_type);
     }
 
+    // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
+    // also we cannot split if the pooling would require any past tokens
+    bool can_split() const {
+        return
+            !need_embd() ||
+            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
+    }
+
     bool can_batch_with(server_slot & other_slot) const {
         return task_type == other_slot.task_type && are_lora_equal(lora, other_slot.lora);
     }
@@ -1929,14 +1937,6 @@ struct server_context {
         llama_batch_free(batch);
     }
 
-    // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
-    // also we cannot split if the pooling would require any past tokens
-    bool can_split() const {
-        return
-            !llama_get_embeddings(ctx) ||
-            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
-    }
-
     bool load_model(const common_params & params) {
         SRV_INF("loading model '%s'\n", params.model.path.c_str());
 
@@ -3130,7 +3130,7 @@ struct server_context {
                             continue;
                         }
 
-                        if (!can_split()) {
+                        if (!slot.can_split()) {
                             if (slot.n_prompt_tokens > n_ubatch) {
                                 slot.release();
                                 send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
@@ -3273,7 +3273,7 @@ struct server_context {
                         slot.n_prompt_tokens_processed = 0;
                     }
 
-                    if (!can_split()) {
+                    if (!slot.can_split()) {
                         // cannot fit the prompt in the current batch - will try next iter
                        if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                             continue;
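
For reference, a minimal sketch of the corrected check written as a free function, assuming an initialized llama_context. The name can_split_prompt and the need_embd parameter are illustrative stand-ins for server_slot::can_split() and the slot's task-type check; llama_get_embeddings() only returns the embeddings output buffer of the last decode and says nothing about the current task, which is why the split decision now keys off whether the slot actually needs embeddings.

#include "llama.h"

// Sketch only (assumed helper, not part of the commit): decide whether a
// prompt may be split across multiple ubatches.
static bool can_split_prompt(llama_context * ctx, bool need_embd) {
    // generation-only work can always be split across ubatches
    if (!need_embd) {
        return true;
    }

    // embeddings can only be split when the context has a memory module and
    // uses last-token pooling, i.e. the pooling needs no past tokens
    return llama_get_memory(ctx) != nullptr &&
           llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST;
}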