]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
examples : allow extracting embeddings from decoder contexts (#13797)
authorGeorgi Gerganov <redacted>
Mon, 26 May 2025 11:03:54 +0000 (14:03 +0300)
committerGitHub <redacted>
Mon, 26 May 2025 11:03:54 +0000 (14:03 +0300)
ggml-ci

examples/embedding/embedding.cpp
examples/retrieval/retrieval.cpp
src/llama-context.cpp
tools/server/server.cpp

index 01ff6763fff5e1b80cd4f859f6d7a60f59137140..71f700877a3b9faca3247231153b631e3f135174 100644 (file)
@@ -41,8 +41,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
 
     // run model
     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
-    if (llama_encode(ctx, batch) < 0) {
-        LOG_ERR("%s : failed to encode\n", __func__);
+    if (llama_decode(ctx, batch) < 0) {
+        LOG_ERR("%s : failed to process\n", __func__);
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
index e3d0c9542ed8d1d83a5a9e0afc8d75632c941ae0..754da1411bcc14fd8e65438eefd715189450581a 100644 (file)
@@ -81,14 +81,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
     }
 }
 
-static void batch_encode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
+static void batch_process(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_self_clear(ctx);
 
     // run model
     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
-    if (llama_encode(ctx, batch) < 0) {
-        LOG_ERR("%s : failed to encode\n", __func__);
+    if (llama_decode(ctx, batch) < 0) {
+        LOG_ERR("%s : failed to process\n", __func__);
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
@@ -233,7 +233,7 @@ int main(int argc, char ** argv) {
         // encode if at capacity
         if (batch.n_tokens + n_toks > n_batch) {
             float * out = emb + p * n_embd;
-            batch_encode(ctx, batch, out, s, n_embd);
+            batch_process(ctx, batch, out, s, n_embd);
             common_batch_clear(batch);
             p += s;
             s = 0;
@@ -246,7 +246,7 @@ int main(int argc, char ** argv) {
 
     // final batch
     float * out = emb + p * n_embd;
-    batch_encode(ctx, batch, out, s, n_embd);
+    batch_process(ctx, batch, out, s, n_embd);
 
     // save embeddings to chunks
     for (int i = 0; i < n_chunks; i++) {
@@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
         batch_add_seq(query_batch, query_tokens, 0);
 
         std::vector<float> query_emb(n_embd, 0);
-        batch_encode(ctx, query_batch, query_emb.data(), 1, n_embd);
+        batch_process(ctx, query_batch, query_emb.data(), 1, n_embd);
 
         common_batch_clear(query_batch);
 
index 98ecb7c8249ceaa484eb04ea4e12da120912867b..ad77cae20eb504fc9bea8d21a1abee597587e27e 100644 (file)
@@ -852,7 +852,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
 int llama_context::decode(llama_batch & inp_batch) {
     if (!memory) {
-        LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
+        LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
         return encode(inp_batch);
     }
 
index 07b613122e321b44fc641263a0aefa206fc93e0c..fcab1dfaafe6e6ab4f6a66c9016adf66a187afec 100644 (file)
@@ -3394,13 +3394,7 @@ struct server_context {
                 batch.logits   + i,
             };
 
-            int ret = 0;
-
-            if (do_encode) {
-                ret = llama_encode(ctx, batch_view);
-            } else {
-                ret = llama_decode(ctx, batch_view);
-            }
+            const int ret = llama_decode(ctx, batch_view);
 
             metrics.on_decoded(slots);