]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
context : allow cache-less context for embeddings (#13108)
authorGeorgi Gerganov <redacted>
Thu, 8 May 2025 11:28:33 +0000 (14:28 +0300)
committerGitHub <redacted>
Thu, 8 May 2025 11:28:33 +0000 (14:28 +0300)
* context : allow cache-less context for embeddings

ggml-ci

* context : enable reranking with encode()

ggml-ci

* context : encode() clears embd_seq

ggml-ci

* examples : use llama_encode() when appropriate

ggml-ci

* models : nomic bert moe does not require KV cache

* llama : update comments for llama_decode/llama_encode

ggml-ci

* context : update warning log [no ci]

examples/embedding/embedding.cpp
include/llama.h
src/llama-context.cpp
src/llama-model.cpp
tools/server/server.cpp

index 06fce236e2b85aa3fe74cbf3442e943ef5d5a5d8..01ff6763fff5e1b80cd4f859f6d7a60f59137140 100644 (file)
@@ -35,23 +35,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
 
 static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
-    const struct llama_model * model = llama_get_model(ctx);
 
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_self_clear(ctx);
 
     // run model
     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
-    if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
-        // encoder-only model
-        if (llama_encode(ctx, batch) < 0) {
-            LOG_ERR("%s : failed to encode\n", __func__);
-        }
-    } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
-        // decoder-only model
-        if (llama_decode(ctx, batch) < 0) {
-            LOG_ERR("%s : failed to decode\n", __func__);
-        }
+    if (llama_encode(ctx, batch) < 0) {
+        LOG_ERR("%s : failed to encode\n", __func__);
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
index e18e9b8da337f042011a49a6018d844da0063136..a18f365bff6f2b17a468c9cb69701b033d2114e5 100644 (file)
@@ -922,14 +922,19 @@ extern "C" {
     // Frees a batch of tokens allocated with llama_batch_init()
     LLAMA_API void llama_batch_free(struct llama_batch batch);
 
-    // Processes a batch of tokens with the ecoder part of the encoder-decoder model.
-    // Stores the encoder output internally for later use by the decoder cross-attention layers.
+    // Process a batch of tokens.
+    // In contrast to llama_decode() - this call does not use KV cache.
+    // For encode-decoder contexts, processes the batch using the encoder.
+    // Can store the encoder output internally for later use by the decoder's cross-attention layers.
     //   0 - success
     // < 0 - error. the KV cache state is restored to the state before this call
     LLAMA_API int32_t llama_encode(
             struct llama_context * ctx,
               struct llama_batch   batch);
 
+    // Process a batch of tokens.
+    // Requires KV cache.
+    // For encode-decoder contexts, processes the batch using the decoder.
     // Positive return values does not mean a fatal error, but rather a warning.
     //   0 - success
     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
index dadb87517f7ef17fb5b7a35bd8f0584a74328093..fd64622b8e02d2e2a789e109e331ecc35906f069 100644 (file)
@@ -251,7 +251,7 @@ llama_context::llama_context(
     }
 
     // reserve worst-case graph
-    if (!hparams.vocab_only) {
+    if (!hparams.vocab_only && memory) {
         const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
@@ -700,6 +700,8 @@ int llama_context::encode(llama_batch & inp_batch) {
         t_compute_start_us = ggml_time_us();
     }
 
+    embd_seq.clear();
+
     n_queued_tokens += n_tokens;
 
     const int64_t n_embd = hparams.n_embd;
@@ -761,12 +763,12 @@ int llama_context::encode(llama_batch & inp_batch) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
         GGML_ASSERT(backend_embd != nullptr);
 
-        GGML_ASSERT(embd != nullptr);
-
         switch (cparams.pooling_type) {
             case LLAMA_POOLING_TYPE_NONE:
                 {
                     // extract token embeddings
+                    GGML_ASSERT(embd != nullptr);
+
                     GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
                     ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
                 } break;
@@ -791,11 +793,18 @@ int llama_context::encode(llama_batch & inp_batch) {
                 } break;
             case LLAMA_POOLING_TYPE_RANK:
                 {
-                    // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
-                    //       wait for an encoder model that requires this pooling type in order to test it
-                    //       https://github.com/ggerganov/llama.cpp/pull/9510
-                    GGML_ABORT("RANK pooling not implemented yet");
-                }
+                    // extract the rerank score - a single float per sequence
+                    auto & embd_seq_out = embd_seq;
+
+                    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                        const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                        if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                            continue;
+                        }
+                        embd_seq_out[seq_id].resize(1);
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                    }
+                } break;
             case LLAMA_POOLING_TYPE_UNSPECIFIED:
                 {
                     GGML_ABORT("unknown pooling type");
@@ -833,6 +842,11 @@ int llama_context::encode(llama_batch & inp_batch) {
 }
 
 int llama_context::decode(llama_batch & inp_batch) {
+    if (!memory) {
+        LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
+        return encode(inp_batch);
+    }
+
     if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
index 1603eae1292c9eea376761ebffdd8f17aff472eb..3ca265be8dca4e2e08eef0353b986fed1446362f 100644 (file)
@@ -12852,6 +12852,13 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        case LLM_ARCH_BERT:
+        case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_NOMIC_BERT:
+        case LLM_ARCH_NOMIC_BERT_MOE:
+            {
+                res = nullptr;
+            } break;
         case LLM_ARCH_MAMBA:
         case LLM_ARCH_RWKV6:
         case LLM_ARCH_RWKV6QWEN2:
index e0e99eafcdf5ce83c91f61bcfee6e5fef3a7dd54..06788bbdc854512716b418d2a459b80d0e211384 100644 (file)
@@ -3214,7 +3214,14 @@ struct server_context {
                 batch.logits   + i,
             };
 
-            const int ret = llama_decode(ctx, batch_view);
+            int ret = 0;
+
+            if (params_base.embedding || params_base.reranking) {
+                ret = llama_encode(ctx, batch_view);
+            } else {
+                ret = llama_decode(ctx, batch_view);
+            }
+
             metrics.on_decoded(slots);
 
             if (ret != 0) {
@@ -3943,7 +3950,7 @@ int main(int argc, char ** argv) {
     const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
             server_task_type type,
             json & data,
-            std::function<bool()> is_connection_closed,
+            const std::function<bool()> & is_connection_closed,
             httplib::Response & res,
             oaicompat_type oaicompat) {
         GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);