]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : fix session prompt load (#854)
authorLuis Herrera <redacted>
Tue, 2 May 2023 17:05:27 +0000 (12:05 -0500)
committerGitHub <redacted>
Tue, 2 May 2023 17:05:27 +0000 (20:05 +0300)
examples/talk-llama/talk-llama.cpp

index 36f9bf4c6f25ba4c961ae7cdc04169065581a3a9..7960ab7df19fa372b071bad4984bafdc17cf3eeb 100644 (file)
@@ -333,27 +333,10 @@ int main(int argc, char ** argv) {
 
     prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
 
-    // evaluate the initial prompt
-
-    auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
-
-    printf("\n");
-    printf("%s : initializing - please wait ...\n", __func__);
-
-    if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
-        fprintf(stderr, "%s : failed to eval\n", __func__);
-        return 1;
-    }
-
-    if (params.verbose_prompt) {
-        fprintf(stdout, "\n");
-        fprintf(stdout, "%s", prompt_llama.c_str());
-        fflush(stdout);
-    }
-
     // init session
     std::string path_session = params.path_session;
     std::vector<llama_token> session_tokens;
+    auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
 
     if (!path_session.empty()) {
         fprintf(stderr, "%s: attempting to load saved session from %s\n", __func__, path_session.c_str());
@@ -370,6 +353,9 @@ int main(int argc, char ** argv) {
                 return 1;
             }
             session_tokens.resize(n_token_count_out);
+            for (size_t i = 0; i < session_tokens.size(); i++) {
+                embd_inp[i] = session_tokens[i];
+            }
 
             fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
         } else {
@@ -377,6 +363,22 @@ int main(int argc, char ** argv) {
         }
     }
 
+    // evaluate the initial prompt
+
+    printf("\n");
+    printf("%s : initializing - please wait ...\n", __func__);
+
+    if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
+        fprintf(stderr, "%s : failed to eval\n", __func__);
+        return 1;
+    }
+
+    if (params.verbose_prompt) {
+        fprintf(stdout, "\n");
+        fprintf(stdout, "%s", prompt_llama.c_str());
+        fflush(stdout);
+    }
+
      // debug message about similarity of saved session, if applicable
     size_t n_matching_session_tokens = 0;
     if (session_tokens.size()) {
@@ -417,7 +419,7 @@ int main(int argc, char ** argv) {
 
     int n_past = n_keep;
     int n_prev = 64; // TODO arg
-    int n_session_consumed = 0;
+    int n_session_consumed = !path_session.empty() && session_tokens.size() > 0 ? session_tokens.size() : 0;
 
     std::vector<llama_token> embd;
 
@@ -494,6 +496,11 @@ int main(int argc, char ** argv) {
 
                 embd = ::llama_tokenize(ctx_llama, text_heard, false);
 
+                // Append the new input tokens to the session_tokens vector
+                if (!path_session.empty()) {
+                    session_tokens.insert(session_tokens.end(), tokens.begin(), tokens.end());
+                }
+
                 // text inference
                 bool done = false;
                 std::string text_to_speak;
@@ -539,20 +546,21 @@ int main(int argc, char ** argv) {
                             }
                         }
 
+                        if (embd.size() > 0 && !path_session.empty()) {
+                            session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
+                            n_session_consumed = session_tokens.size();
+                        }
+
                         if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
                             fprintf(stderr, "%s : failed to eval\n", __func__);
                             return 1;
                         }
                     }
 
-                    //printf("n_iter = %d, n_past = %d, n_ctx = %d, n_keep = %d, n_prev = %d, embd.size() = %d\n", n_iter, n_past, n_ctx, n_keep, n_prev, (int) embd.size());
 
                     embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
                     n_past += embd.size();
-                    if (embd.size() > 0 && !path_session.empty()) {
-                        session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
-                        n_session_consumed = session_tokens.size();
-                    }
+                    
                     embd.clear();
 
                     if (done) break;