]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : normalize embeddings (#5956)
authorSeungWon Jeong <redacted>
Sat, 9 Mar 2024 12:27:58 +0000 (21:27 +0900)
committerGitHub <redacted>
Sat, 9 Mar 2024 12:27:58 +0000 (14:27 +0200)
* output normalize embedding in '/v1/embeddings'

* common : reuse llama_embd_normalize

* common : better normalize impl

---------

Co-authored-by: Georgi Gerganov <redacted>
common/common.cpp
common/common.h
examples/embedding/embedding.cpp
examples/server/server.cpp

index d7f650ef486e93b8f141f20e242a331d554f4433..16ef4d7f74dd99979a525350226bef0a14abebdc 100644 (file)
@@ -1852,3 +1852,18 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
 
     printf("\n=== Done dumping\n");
 }
+
+void llama_embd_normalize(const float * inp, float * out, int n) {
+    double sum = 0.0;
+    for (int i = 0; i < n; i++) {
+        sum += inp[i] * inp[i];
+    }
+    sum = sqrt(sum);
+
+    const float norm = sum > 0.0 ? 1.0f / sum : 0.0f;
+
+    for (int i = 0; i < n; i++) {
+        out[i] = inp[i] * norm;
+    }
+}
+
index 977ce419ff93bc6d9096fdb29e18f8c8ea97ab5e..f8d82b8713c8715329e90333c61040f4adb327bb 100644 (file)
@@ -260,3 +260,10 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
 
 // Dump the KV cache view showing individual sequences in each cell (long output).
 void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
+
+//
+// Embedding utils
+//
+
+void llama_embd_normalize(const float * inp, float * out, int n);
+
index ff5883da6ba27a7b5bbd3326ee7b53f7b50665c1..a553ae1c3f35d39314784084aeb25be67abfa838 100644 (file)
@@ -23,17 +23,6 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
     }
 }
 
-static void normalize(const float * vec, float * out, int n) {
-    float norm = 0;
-    for (int i = 0; i < n; i++) {
-        norm += vec[i] * vec[i];
-    }
-    norm = sqrt(norm);
-    for (int i = 0; i < n; i++) {
-        out[i] = vec[i] / norm;
-    }
-}
-
 static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_cache_clear(ctx);
@@ -44,7 +33,6 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         fprintf(stderr, "%s : failed to decode\n", __func__);
     }
 
-    // normalize on copy
     for (int i = 0; i < batch.n_tokens; i++) {
         if (!batch.logits[i]) {
             continue;
@@ -61,7 +49,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         }
 
         float * out = output + batch.seq_id[i][0] * n_embd;
-        normalize(embd, out, n_embd);
+        llama_embd_normalize(embd, out, n_embd);
     }
 }
 
index 8cff514f2a98aa3734146ddf2ad8bb09a0adfb5c..796f3499c9877ac98225843fae791f1ce8e2edc3 100644 (file)
@@ -1327,6 +1327,8 @@ struct server_context {
 
         const int n_embd = llama_n_embd(model);
 
+        std::vector<float> embd_res(n_embd, 0.0f);
+
         for (int i = 0; i < batch.n_tokens; ++i) {
             if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
                 continue;
@@ -1350,8 +1352,10 @@ struct server_context {
                 continue;
             }
 
+            llama_embd_normalize(embd, embd_res.data(), n_embd);
+
             res.data = json {
-                {"embedding", std::vector<float>(embd, embd + n_embd)},
+                {"embedding", embd_res},
             };
         }
 
@@ -3354,6 +3358,8 @@ int main(int argc, char ** argv) {
             // get the result
             server_task_result result = ctx_server.queue_results.recv(id_task);
             ctx_server.queue_results.remove_waiting_task_id(id_task);
+
+            // append to the responses
             responses.push_back(result.data);
         }