From: Alex Klinkhamer Date: Fri, 21 Apr 2023 18:18:09 +0000 (-0700) Subject: main : evaluate tokens in batches after swapping context (#1014) X-Git-Tag: gguf-v0.4.0~903 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=9411288271ab548216902a029f42a0a38ebcedb7;p=pkg%2Fggml%2Fsources%2Fllama.cpp main : evaluate tokens in batches after swapping context (#1014) * examples : evaluate tokens in batches after swapping context * Update examples/main/main.cpp --------- Co-authored-by: Georgi Gerganov --- diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b7b3c419..65db7926 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -264,7 +264,7 @@ int main(int argc, char ** argv) { // infinite text generation via context swapping // if we run out of context: // - take the n_keep first tokens from the original prompt (via n_past) - // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch + // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches if (n_past + (int) embd.size() > n_ctx) { const int n_left = n_past - params.n_keep; @@ -282,13 +282,21 @@ int main(int argc, char ** argv) { //printf("\n---\n"); } - if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) { - fprintf(stderr, "%s : failed to eval\n", __func__); - return 1; + // evaluate tokens in batches + // embd is typically prepared beforehand to fit within a batch, but not always + for (int i = 0; i < (int) embd.size(); i += params.n_batch) { + int n_eval = (int) embd.size() - i; + if (n_eval > params.n_batch) { + n_eval = params.n_batch; + } + if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return 1; + } + n_past += n_eval; } } - n_past += embd.size(); embd.clear(); if ((int) embd_inp.size() <= n_consumed && !is_interacting) {