]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
main : evaluate tokens in batches after swapping context (#1014)
authorAlex Klinkhamer <redacted>
Fri, 21 Apr 2023 18:18:09 +0000 (11:18 -0700)
committerGitHub <redacted>
Fri, 21 Apr 2023 18:18:09 +0000 (21:18 +0300)
* examples : evaluate tokens in batches after swapping context

* Update examples/main/main.cpp

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/main/main.cpp

index b7b3c419655f663bba703d27b9c0f5295dd2d6aa..65db792631f36ab0b9fc28545d7a437543a1a9aa 100644 (file)
@@ -264,7 +264,7 @@ int main(int argc, char ** argv) {
             // infinite text generation via context swapping
             // if we run out of context:
             // - take the n_keep first tokens from the original prompt (via n_past)
-            // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
+            // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
             if (n_past + (int) embd.size() > n_ctx) {
                 const int n_left = n_past - params.n_keep;
 
@@ -282,13 +282,21 @@ int main(int argc, char ** argv) {
                 //printf("\n---\n");
             }
 
-            if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
-                fprintf(stderr, "%s : failed to eval\n", __func__);
-                return 1;
+            // evaluate tokens in batches
+            // embd is typically prepared beforehand to fit within a batch, but not always
+            for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
+                int n_eval = (int) embd.size() - i;
+                if (n_eval > params.n_batch) {
+                    n_eval = params.n_batch;
+                }
+                if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
+                    fprintf(stderr, "%s : failed to eval\n", __func__);
+                    return 1;
+                }
+                n_past += n_eval;
             }
         }
 
-        n_past += embd.size();
         embd.clear();
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {