]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
embedding : evaluate prompt in batches (#2713)
authorslaren <redacted>
Tue, 22 Aug 2023 14:03:12 +0000 (16:03 +0200)
committerGitHub <redacted>
Tue, 22 Aug 2023 14:03:12 +0000 (16:03 +0200)
examples/embedding/embedding.cpp

index 8788571cbf9d42603eea40f28de7f0ab6cc82f41..38395c75b0b5bc3ecbde8fd9413082476a7d4859 100644 (file)
@@ -72,22 +72,29 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "\n");
     }
 
-    if (params.embedding){
-        if (embd_inp.size() > 0) {
-            if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads)) {
-                fprintf(stderr, "%s : failed to eval\n", __func__);
-                return 1;
-            }
+    if (embd_inp.size() > (size_t)params.n_ctx) {
+        fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
+                __func__, embd_inp.size(), params.n_ctx);
+        return 1;
+    }
+
+    while (!embd_inp.empty()) {
+        int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
+        if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return 1;
         }
+        n_past += n_tokens;
+        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
+    }
 
-        const int n_embd = llama_n_embd(ctx);
-        const auto embeddings = llama_get_embeddings(ctx);
+    const int n_embd = llama_n_embd(ctx);
+    const auto embeddings = llama_get_embeddings(ctx);
 
-        for (int i = 0; i < n_embd; i++) {
-            printf("%f ", embeddings[i]);
-        }
-        printf("\n");
+    for (int i = 0; i < n_embd; i++) {
+        printf("%f ", embeddings[i]);
     }
+    printf("\n");
 
     llama_print_timings(ctx);
     llama_free(ctx);