server : fix pooled embedding output (#14645)
author    Douglas Hanley <redacted>
Sat, 12 Jul 2025 10:21:02 +0000 (06:21 -0400)
committer GitHub <redacted>
Sat, 12 Jul 2025 10:21:02 +0000 (13:21 +0300)
tools/server/server.cpp

index 57b917f2f97b382d92081ef8c9012d105d18feed..d4dffb39c8d16a333dcc75e3657639d08e49c893 100644
@@ -2581,12 +2581,14 @@ struct server_context {
                 continue;
             }
 
-            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-            if (embd == NULL) {
+            const float * embd = nullptr;
+            if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
                 embd = llama_get_embeddings_ith(ctx, i);
+            } else {
+                embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
             }
 
-            if (embd == NULL) {
+            if (embd == nullptr) {
                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
 
                 res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
@@ -2594,12 +2596,12 @@ struct server_context {
             }
 
             // normalize only when there is pooling
-            // TODO: configurable
             if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
                 common_embd_normalize(embd, embd_res.data(), n_embd, 2);
                 res->embedding.push_back(embd_res);
+                break;
             } else {
-                res->embedding.push_back({ embd, embd + n_embd });
+                res->embedding.emplace_back(embd, embd + n_embd);
             }
         }
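
The change replaces NULL-probing with an explicit check of the pooling mode: when llama_pooling_type() reports LLAMA_POOLING_TYPE_NONE, the server reads a per-token embedding via llama_get_embeddings_ith(); otherwise it reads the pooled per-sequence embedding via llama_get_embeddings_seq() and L2-normalizes it. The break added in the pooled branch appears to stop the loop after the first hit, since pooled modes yield a single embedding per sequence rather than one per token. Below is a minimal standalone sketch of the post-fix retrieval logic; the helper name get_embedding_for_pos is hypothetical, and the inline L2 normalization stands in for common_embd_normalize(..., 2) so the snippet only depends on llama.h.

// Sketch of the post-fix embedding retrieval (illustrative, not the commit's code).
// Assumes a valid llama_context created with embeddings enabled and an already
// decoded batch; only llama.h API calls that appear in the diff above are used.
#include <cmath>
#include <vector>

#include "llama.h"

static std::vector<float> get_embedding_for_pos(llama_context * ctx, int32_t i, llama_seq_id seq_id, int32_t n_embd) {
    const float * embd = nullptr;
    if (llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE) {
        embd = llama_get_embeddings_ith(ctx, i);      // one embedding per token
    } else {
        embd = llama_get_embeddings_seq(ctx, seq_id); // one pooled embedding per sequence
    }

    if (embd == nullptr) {
        // mirror the server's fallback: emit a zero vector on failure
        return std::vector<float>(n_embd, 0.0f);
    }

    std::vector<float> out(embd, embd + n_embd);

    // normalize only when there is pooling (inline stand-in for
    // common_embd_normalize(embd, out.data(), n_embd, 2))
    if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
        double sum = 0.0;
        for (float v : out) {
            sum += (double) v * v;
        }
        const float norm = sum > 0.0 ? (float) std::sqrt(sum) : 1.0f;
        for (float & v : out) {
            v /= norm;
        }
    }
    return out;
}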