]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : use llama_decode instead of llama_eval
authorGeorgi Gerganov <redacted>
Fri, 8 Mar 2024 10:04:43 +0000 (12:04 +0200)
committerGeorgi Gerganov <redacted>
Fri, 8 Mar 2024 10:04:43 +0000 (12:04 +0200)
examples/talk-llama/talk-llama.cpp

index ddc9e765c2da30a3b99245c46adfe8b8fbb3f384..4e1c1755f1cd05668232da945568513f7e952fc9 100644 (file)
@@ -391,6 +391,8 @@ int main(int argc, char ** argv) {
 
     prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
 
+    llama_batch batch = llama_batch_init(llama_n_ctx(ctx_llama), 0, 1);
+
     // init session
     std::string path_session = params.path_session;
     std::vector<llama_token> session_tokens;
@@ -426,8 +428,21 @@ int main(int argc, char ** argv) {
     printf("\n");
     printf("%s : initializing - please wait ...\n", __func__);
 
-    if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0)) {
-        fprintf(stderr, "%s : failed to eval\n", __func__);
+    // prepare batch
+    {
+        batch.n_tokens = embd_inp.size();
+
+        for (int i = 0; i < batch.n_tokens; i++) {
+            batch.token[i]     = embd_inp[i];
+            batch.pos[i]       = i;
+            batch.n_seq_id[i]  = 1;
+            batch.seq_id[i][0] = 0;
+            batch.logits[i]    = i == batch.n_tokens - 1;
+        }
+    }
+
+    if (llama_decode(ctx_llama, batch)) {
+        fprintf(stderr, "%s : failed to decode\n", __func__);
         return 1;
     }
 
@@ -647,8 +662,21 @@ int main(int argc, char ** argv) {
                             n_session_consumed = session_tokens.size();
                         }
 
-                        if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past)) {
-                            fprintf(stderr, "%s : failed to eval\n", __func__);
+                        // prepare batch
+                        {
+                            batch.n_tokens = embd.size();
+
+                            for (int i = 0; i < batch.n_tokens; i++) {
+                                batch.token[i]     = embd[i];
+                                batch.pos[i]       = n_past + i;
+                                batch.n_seq_id[i]  = 1;
+                                batch.seq_id[i][0] = 0;
+                                batch.logits[i]    = i == batch.n_tokens - 1;
+                            }
+                        }
+
+                        if (llama_decode(ctx_llama, batch)) {
+                            fprintf(stderr, "%s : failed to decode\n", __func__);
                             return 1;
                         }
                     }