]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
embedding : show full embedding for single prompt (#6342)
authorhowlger <redacted>
Wed, 27 Mar 2024 11:15:44 +0000 (12:15 +0100)
committerGitHub <redacted>
Wed, 27 Mar 2024 11:15:44 +0000 (13:15 +0200)
* embedding : show full embedding for single prompt

To support the use case of creating an embedding for a given prompt, the entire embedding and not just the first part needed to be printed.

Also, show cosine similarity matrix only if there is more than one prompt, as the cosine similarity matrix for a single prompt is always `1.00`.

* Update examples/embedding/embedding.cpp

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/embedding/embedding.cpp

index 9aede7fadfe31f607516334ab814eb9e9c71cf76..536657526685ca34217b0118988df6f625a5ecdf 100644 (file)
@@ -178,25 +178,27 @@ int main(int argc, char ** argv) {
     float * out = emb + p * n_embd;
     batch_decode(ctx, batch, out, s, n_embd);
 
-    // print the first part of the embeddings
+    // print the first part of the embeddings or for a single prompt, the full embedding
     fprintf(stdout, "\n");
     for (int j = 0; j < n_prompts; j++) {
         fprintf(stdout, "embedding %d: ", j);
-        for (int i = 0; i < std::min(16, n_embd); i++) {
+        for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
             fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
         }
         fprintf(stdout, "\n");
     }
 
     // print cosine similarity matrix
-    fprintf(stdout, "\n");
-    printf("cosine similarity matrix:\n\n");
-    for (int i = 0; i < n_prompts; i++) {
-        for (int j = 0; j < n_prompts; j++) {
-            float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
-            fprintf(stdout, "%6.2f ", sim);
-        }
+    if (n_prompts > 1) {
         fprintf(stdout, "\n");
+        printf("cosine similarity matrix:\n\n");
+        for (int i = 0; i < n_prompts; i++) {
+            for (int j = 0; j < n_prompts; j++) {
+                float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
+                fprintf(stdout, "%6.2f ", sim);
+            }
+            fprintf(stdout, "\n");
+        }
     }
 
     // clean up