]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : rework embeddings logic (#14208)
authorGeorgi Gerganov <redacted>
Mon, 16 Jun 2025 11:14:00 +0000 (14:14 +0300)
committerGitHub <redacted>
Mon, 16 Jun 2025 11:14:00 +0000 (14:14 +0300)
* llama : rework embeddings logic

ggml-ci

* cont : fix rerank

ggml-ci

* cont : engrish [no ci]

* cont : fix rerank

ggml-ci

* server : support both embeddings and completions with single model

ggml-ci

* cont : avoid embeddings_org

ggml-ci

16 files changed:
common/arg.cpp
common/common.cpp
common/common.h
examples/gritlm/gritlm.cpp
include/llama.h
src/llama-batch.cpp
src/llama-batch.h
src/llama-context.cpp
src/llama-kv-cache-recurrent.cpp
src/llama-kv-cache-recurrent.h
src/llama-kv-cache-unified-iswa.cpp
src/llama-kv-cache-unified-iswa.h
src/llama-kv-cache-unified.cpp
src/llama-kv-cache-unified.h
src/llama-memory.h
tools/server/server.cpp

index 0d0daa361010567986be41d8f001bba179cd4744..231de227a9122815f2e28489dcb407095d6738bb 100644 (file)
@@ -988,10 +988,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         params.tensor_buft_overrides.push_back({nullptr, nullptr});
     }
 
-    if (params.reranking && params.embedding) {
-        throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
-    }
-
     if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
         throw std::runtime_error(string_format(
             "error: the supplied chat template is not supported: %s%s\n",
@@ -2747,9 +2743,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
     add_opt(common_arg(
         {"--reranking", "--rerank"},
-        string_format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"),
+        string_format("enable reranking endpoint on server (default: %s)", "disabled"),
         [](common_params & params) {
-            params.reranking = true;
+            params.embedding = true;
+            params.pooling_type = LLAMA_POOLING_TYPE_RANK;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
     add_opt(common_arg(
index e23887c70770c4d2e2de6ff1243f427b8592a777..5b465150f05337e2d645ca9ef29650133d4fca42 100644 (file)
@@ -897,34 +897,6 @@ struct common_init_result common_init_from_params(common_params & params) {
 
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
-    if (params.reranking) {
-        bool ok = true;
-
-        if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
-            LOG_WRN("%s: warning: vocab does not have a  BOS token, reranking will not work\n", __func__);
-            ok = false;
-        }
-
-        bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
-        bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
-
-        if (!has_eos && !has_sep) {
-            LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
-            ok = false;
-        } else if (!has_eos) {
-            LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
-        } else if (!has_sep) {
-            LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
-            ok = false;
-        }
-
-        if (!ok) {
-            llama_model_free(model);
-
-            return iparams;
-        }
-    }
-
     auto cparams = common_context_params_to_llama(params);
 
     llama_context * lctx = llama_init_from_model(model, cparams);
@@ -966,6 +938,35 @@ struct common_init_result common_init_from_params(common_params & params) {
         }
     }
 
+    if (llama_pooling_type(lctx) == LLAMA_POOLING_TYPE_RANK) {
+        bool ok = true;
+
+        if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: vocab does not have a  BOS token, reranking will not work\n", __func__);
+            ok = false;
+        }
+
+        bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
+        bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
+
+        if (!has_eos && !has_sep) {
+            LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
+            ok = false;
+        } else if (!has_eos) {
+            LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
+        } else if (!has_sep) {
+            LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
+            ok = false;
+        }
+
+        if (!ok) {
+            llama_free(lctx);
+            llama_model_free(model);
+
+            return iparams;
+        }
+    }
+
     // load and optionally apply lora adapters
     for (auto & la : params.lora_adapters) {
         llama_adapter_lora_ptr lora;
@@ -1143,11 +1144,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.op_offload        = !params.no_op_offload;
     cparams.swa_full          = params.swa_full;
 
-    if (params.reranking) {
-        cparams.embeddings    = true;
-        cparams.pooling_type  = LLAMA_POOLING_TYPE_RANK;
-    }
-
     cparams.type_k = params.cache_type_k;
     cparams.type_v = params.cache_type_v;
 
index f26724b6e14951893712c9d1626bbd51cdb0d2a0..00b6ca03a20b4b42634f8ef67e041f992694089b 100644 (file)
@@ -355,7 +355,6 @@ struct common_params {
     int32_t embd_normalize = 2;     // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
     std::string embd_out   = "";    // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
     std::string embd_sep   = "\n";  // separator of embeddings
-    bool reranking         = false; // enable reranking support on server
 
     // server params
     int32_t port           = 8080;         // server listens on this network port
index 041da61c743c12cca43c088acb211b16871e382e..bdab052c3390ff7b440275237b46fbf13612da80 100644 (file)
@@ -41,12 +41,11 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
 
         // add input to batch (this increments n_tokens)
         for (int32_t j = 0; j < n_toks; j++) {
-            common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
+            common_batch_add(batch, inputs[j], j, { 0 }, true);
         }
 
         // clear previous kv_cache values (irrelevant for embeddings)
         llama_memory_clear(llama_get_memory(ctx), true);
-        llama_set_embeddings(ctx, true);
         llama_set_causal_attn(ctx, false);
 
         // run model
@@ -103,7 +102,6 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
     llama_token eos_token = llama_vocab_eos(vocab);
 
     llama_memory_clear(llama_get_memory(ctx), true);
-    llama_set_embeddings(ctx, false);
     llama_set_causal_attn(ctx, true);
 
     llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
@@ -166,6 +164,8 @@ int main(int argc, char * argv[]) {
     llama_model_params mparams = common_model_params_to_llama(params);
     llama_context_params cparams = common_context_params_to_llama(params);
 
+    cparams.embeddings = true;
+
     llama_backend_init();
 
     llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
@@ -213,6 +213,8 @@ int main(int argc, char * argv[]) {
         std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
     }
 
+    llama_set_embeddings(ctx, false);
+
     // ### Generation ###
     // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
     {
index d5e4cef68c213f0e2bb7fdb54ee84ed58acea6f4..b086b68e6d4ea40ca6af8143f68425342be0a87c 100644 (file)
@@ -254,7 +254,10 @@ extern "C" {
     // - seq_id : the sequence to which the respective token belongs
     //            (if set to NULL, the sequence ID will be assumed to be 0)
     // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
-    //            (if set to NULL, only the logits for last token will be returned)
+    //            (if set to NULL:
+    //               - if embeddings: all tokens are output
+    //               - if not:        only the last token is output
+    //            )
     //
     typedef struct llama_batch {
         int32_t n_tokens;
@@ -262,8 +265,8 @@ extern "C" {
         llama_token  *  token;
         float        *  embd;
         llama_pos    *  pos;
-        int32_t      *  n_seq_id; // TODO: remove, should belong to only 1 sequence
-        llama_seq_id ** seq_id;   // TODO: become llama_seq_id * seq_id;
+        int32_t      *  n_seq_id;
+        llama_seq_id ** seq_id;
         int8_t       *  logits;   // TODO: rename this to "output"
     } llama_batch;
 
@@ -961,8 +964,7 @@ extern "C" {
     // Get the number of threads used for prompt and batch processing (multiple token).
     LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
 
-    // Set whether the model is in embeddings mode or not
-    // If true, embeddings will be returned but logits will not
+    // Set whether the context outputs embeddings or not
     LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
 
     // Set whether to use causal attention or not
index a9f4a3d4c45c5da6fe8ada2d164d424eea7774e7..8b6d14fe8813c3d0874a4ad2aaca71c9e6f0aa4b 100644 (file)
@@ -299,7 +299,8 @@ llama_batch_allocr::llama_batch_allocr() {
 bool llama_batch_allocr::init(
         const llama_batch & batch_inp,
         const llama_vocab & vocab,
-        const llama_memory_i * memory) {
+        const llama_memory_i * memory,
+        bool embd_all) {
     clear();
 
     batch = batch_inp;
@@ -378,10 +379,31 @@ bool llama_batch_allocr::init(
     }
 
     if (!batch.logits) {
-        // by default return the output only for the last token
-        output.resize(batch.n_tokens);
-        output[output.size() - 1] = true;
+        if (embd_all) {
+            // return the output for all tokens
+            output.resize(batch.n_tokens, true);
+        } else {
+            // return the output only for the last token
+            output.resize(batch.n_tokens, false);
+            output[output.size() - 1] = true;
+        }
+
         batch.logits = output.data();
+    } else if (embd_all) {
+        bool warn = false;
+
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            if (batch.logits[i] == 0) {
+                warn = true;
+            }
+        }
+
+        if (warn) {
+            LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
+
+            output.resize(batch.n_tokens, true);
+            batch.logits = output.data();
+        }
     }
 
     //
index 04501ce5d424c729749a4ef4696453f1f98a5239..a555c157234be82933b0bf7b35b43c4fed2593e4 100644 (file)
@@ -88,7 +88,8 @@ public:
     bool init(
             const llama_batch & batch_inp,
             const llama_vocab & vocab,
-            const llama_memory_i * memory);
+            const llama_memory_i * memory,
+            bool embd_all);
 
     const llama_batch & get_batch() const;
 
index 3a113d1bcfb2a48df98b2c1150dc717bcdc3ad8c..f56a58e9b6ec6208c92e0e38c4a879c5af44937c 100644 (file)
@@ -728,7 +728,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     }
 
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -894,7 +894,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return -1;
     }
 
-    if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
+    // when computing embeddings, all tokens are output
+    const bool embd_all = cparams.embeddings;
+
+    if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -911,12 +914,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    // this indicates we are doing pooled embedding
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
     const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
 
-    if (embd_pooled) {
+    if (embd_all) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -945,7 +945,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
         if (!mstate) {
             return -2;
         }
@@ -1058,7 +1058,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
         //}
 
-        auto * t_logits = cparams.embeddings ? nullptr         : res->get_logits();
+        auto * t_logits = res->get_logits();
         auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
 
         if (t_embd && res->get_embd_pooled()) {
@@ -1222,9 +1222,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto n_vocab = vocab.n_tokens();
     const auto n_embd  = hparams.n_embd;
 
-    // TODO: use a per-batch flag for logits presence instead
-    bool has_logits = !cparams.embeddings;
-    bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    bool has_logits = true;
+    bool has_embd   = cparams.embeddings;
 
     // TODO: hacky enc-dec support
     if (model.arch == LLM_ARCH_T5) {
@@ -2044,14 +2043,11 @@ void llama_context::opt_epoch_iter(
 
         n_queued_tokens += n_tokens_all;
 
-        // this indicates we are doing pooled embedding
-        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
         embd_seq.clear();
 
         uint32_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
index de23b4ad23bcec457538630a4ff4340b78d8df99..8f6f120f682b769730113d3c4bf835284209a239 100644 (file)
@@ -359,9 +359,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
-    GGML_UNUSED(embd_pooled);
-
+llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
     auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
 
     std::vector<llama_ubatch> ubatches;
@@ -369,8 +367,8 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch &
     while (sbatch.n_tokens > 0) {
         llama_ubatch ubatch;
 
-        if (embd_pooled) {
-            // Pooled embeddings cannot be split across ubatches (yet)
+        if (embd_all) {
+            // if all tokens are output, split by sequence
             ubatch = sbatch.split_seq(n_ubatch);
         } else {
             ubatch = sbatch.split_equal(n_ubatch);
index d7c02ea8721609150e3a346d9b9df688d9883d0a..f9b01a6513393fa4fea94eeeda9f51cbc89e5377 100644 (file)
@@ -32,7 +32,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled) override;
+            bool embd_all) override;
 
     llama_memory_state_ptr init_full() override;
 
index 9814f766312036399c51623275f9cb81923c48a0..a4a4c2b1b859de2c2b4be3904f8ecdf0c2e27d9c 100644 (file)
@@ -95,8 +95,8 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
-    GGML_UNUSED(embd_pooled);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
+    GGML_UNUSED(embd_all);
 
     // first try simple split
     do {
index d114c7378fbe94f4c3b57dca729d7fb9ccebc1a0..6e941e1a41b88b31f0838c4834ae6eda830a3455 100644 (file)
@@ -34,7 +34,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled) override;
+            bool embd_all) override;
 
     llama_memory_state_ptr init_full() override;
 
index b17936abdb4c6a433dfd403458d2248756dbe4d1..3b37679859d392481c4e56bb1004acbb4004ee76 100644 (file)
@@ -310,8 +310,8 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified::init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled) {
-    GGML_UNUSED(embd_pooled);
+            bool embd_all) {
+    GGML_UNUSED(embd_all);
 
     do {
         auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
index d6dcd19f2507e04b37525df8d5b1d0ebac87d9e6..d96571d952b81db3e47b8653018b79b5e3235325 100644 (file)
@@ -59,7 +59,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled) override;
+            bool embd_all) override;
 
     llama_memory_state_ptr init_full() override;
 
index 42e226dc0ed61ebf3a8d4a45a46211392fc79a53..24668f861b976243bb8bb59fd149e615c71cb87a 100644 (file)
@@ -73,7 +73,7 @@ struct llama_memory_i {
     virtual llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled) = 0;
+            bool embd_all) = 0;
 
     // simulate full cache, used for allocating worst-case compute buffers
     virtual llama_memory_state_ptr init_full() = 0;
index 626c58bd304ffe9e0006766e1a775b349f17b62b..c08e421255fcee5d25ab1f483404a873f79b3ace 100644 (file)
@@ -88,6 +88,26 @@ enum error_type {
     ERROR_TYPE_NOT_SUPPORTED, // custom error
 };
 
+static bool server_task_type_need_embd(server_task_type task_type) {
+    switch (task_type) {
+        case SERVER_TASK_TYPE_EMBEDDING:
+        case SERVER_TASK_TYPE_RERANK:
+            return true;
+        default:
+            return false;
+    }
+}
+
+static bool server_task_type_need_logits(server_task_type task_type) {
+    switch (task_type) {
+        case SERVER_TASK_TYPE_COMPLETION:
+        case SERVER_TASK_TYPE_INFILL:
+            return true;
+        default:
+            return false;
+    }
+}
+
 struct slot_params {
     bool stream        = true;
     bool cache_prompt  = true; // remember the prompt to avoid reprocessing all prompt
@@ -1330,13 +1350,16 @@ struct server_slot {
         n_draft_accepted = 0;
     }
 
-    bool is_non_causal() const {
-        return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
+    bool need_embd() const {
+        return server_task_type_need_embd(task_type);
+    }
+
+    bool need_logits() const {
+        return server_task_type_need_logits(task_type);
     }
 
     bool can_batch_with(server_slot & other_slot) const {
-        return is_non_causal() == other_slot.is_non_causal()
-            && are_lora_equal(lora, other_slot.lora);
+        return task_type == other_slot.task_type && are_lora_equal(lora, other_slot.lora);
     }
 
     bool has_budget(const common_params & global_params) {
@@ -1480,7 +1503,6 @@ struct server_slot {
             {"n_ctx",         n_ctx},
             {"speculative",   can_speculate()},
             {"is_processing", is_processing()},
-            {"non_causal",    is_non_causal()},
             {"params",        params.to_json()},
             {"prompt",        prompt_tokens.detokenize(ctx, true)},
             {"next_token",
@@ -1907,6 +1929,14 @@ struct server_context {
         llama_batch_free(batch);
     }
 
+    // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
+    // also we cannot split if the pooling would require any past tokens
+    bool can_split() const {
+        return
+            !llama_get_embeddings(ctx) ||
+            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
+    }
+
     bool load_model(const common_params & params) {
         SRV_INF("loading model '%s'\n", params.model.path.c_str());
 
@@ -2730,6 +2760,7 @@ struct server_context {
                         queue_tasks.defer(std::move(task));
                         break;
                     }
+
                     if (slot->is_processing()) {
                         // if requested slot is unavailable, we defer this task for processing later
                         SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
@@ -3092,7 +3123,14 @@ struct server_context {
                             continue;
                         }
 
-                        if (slot.is_non_causal()) {
+                        // TODO: support memory-less logits computation
+                        if (slot.need_logits() && !llama_get_memory(ctx)) {
+                            slot.release();
+                            send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER);
+                            continue;
+                        }
+
+                        if (!can_split()) {
                             if (slot.n_prompt_tokens > n_ubatch) {
                                 slot.release();
                                 send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
@@ -3227,8 +3265,7 @@ struct server_context {
                         }
 
                         if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
-                            // we have to evaluate at least 1 token to generate logits.
-                            SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
+                            SLT_WRN(slot, "need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
 
                             slot.n_past--;
                         }
@@ -3236,8 +3273,7 @@ struct server_context {
                         slot.n_prompt_tokens_processed = 0;
                     }
 
-                    // non-causal tasks require to fit the entire prompt in the physical batch
-                    if (slot.is_non_causal()) {
+                    if (!can_split()) {
                         // cannot fit the prompt in the current batch - will try next iter
                         if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                             continue;
@@ -3259,8 +3295,7 @@ struct server_context {
                     slot.cache_tokens.keep_first(slot.n_past);
 
                     // check if we should process the image
-                    if (slot.n_past < slot.n_prompt_tokens
-                            && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
+                    if (slot.n_past < slot.n_prompt_tokens && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
                         // process the image
                         int32_t new_n_past;
                         int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
@@ -3291,8 +3326,8 @@ struct server_context {
                             break; // end of text chunk
                         }
 
-                        // without pooling, we want to output the embeddings for all the tokens in the batch
-                        const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
+                        // embedding requires all tokens in the batch to be output
+                        const bool need_embd = server_task_type_need_embd(slot.task_type);
 
                         common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
                         slot.cache_tokens.push_back(cur_tok);
@@ -3346,17 +3381,15 @@ struct server_context {
         SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
 
         if (slot_batched) {
-            // make sure we're in the right embedding mode
-            llama_set_embeddings(ctx, slot_batched->is_non_causal());
             // apply lora, only need to do it once per batch
             common_set_adapter_lora(ctx, slot_batched->lora);
-        }
 
-        const bool do_encode = (params_base.embedding || params_base.reranking);
+            llama_set_embeddings(ctx, slot_batched->need_embd());
+        }
 
         // pad the batch so that batch.n_tokens >= n_slots
         // TODO: temporary workaround for https://github.com/ggml-org/llama.cpp/issues/13689
-        if (do_encode) {
+        if (slot_batched->need_embd()) {
             const int n_slots = slots.size();
 
             if (batch.n_tokens < n_slots) {
@@ -3378,8 +3411,11 @@ struct server_context {
                 SRV_WRN("adding %d dummy tokens to the batch, seq_id = %d\n", n_add, seq_id);
 
                 for (int j = 0; j < n_add; ++j) {
-                    common_batch_add(batch, 0, j, { seq_id }, false);
+                    common_batch_add(batch, 0, j, { seq_id }, true);
                 }
+
+                slots[seq_id].cache_tokens.clear();
+                llama_memory_seq_rm(llama_get_memory(ctx), seq_id, -1, -1);
             }
         }
 
@@ -4174,11 +4210,6 @@ int main(int argc, char ** argv) {
             oaicompat_type oaicompat) -> void {
         GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
 
-        if (ctx_server.params_base.embedding) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
-            return;
-        }
-
         auto completion_id = gen_chatcmplid();
         std::unordered_set<int> task_ids;
         try {
@@ -4433,12 +4464,8 @@ int main(int argc, char ** argv) {
             OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
     };
 
-    const auto handle_chat_completions = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_chat_completions = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
         LOG_DBG("request: %s\n", req.body.c_str());
-        if (ctx_server.params_base.embedding) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
-            return;
-        }
 
         auto body = json::parse(req.body);
         std::vector<raw_buffer> files;
@@ -4566,13 +4593,18 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
-        const json body = json::parse(req.body);
+        if (!ctx_server.params_base.embedding) {
+            res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
 
         if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
             res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
 
+        const json body = json::parse(req.body);
+
         // for the shape of input/content, see tokenize_input_prompts()
         json prompt;
         if (body.count("input") != 0) {
@@ -4662,8 +4694,8 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
-        if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
-            res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
+        if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
+            res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }