]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Check for reverse prompt by characters instead of tokens (#292) (#330)
authortjohnman <redacted>
Tue, 21 Mar 2023 16:04:43 +0000 (17:04 +0100)
committerGitHub <redacted>
Tue, 21 Mar 2023 16:04:43 +0000 (18:04 +0200)
* Check for reverse prompt by characters instead of tokens (#292)

* Update main.cpp

Wording.

* Cleanup.

* Remove unnecessary use of std::stringstream.

---------

Co-authored-by: Johnman <redacted>
Co-authored-by: Georgi Gerganov <redacted>
main.cpp

index 6bae80cdf5876759ac0f24a533a711ba369f74ab..bda824ff1d1fac3b1cbc4081cb60cb4ef5cf0def 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -885,15 +885,8 @@ int main(int argc, char ** argv) {
         params.antiprompt.push_back("### Instruction:\n\n");
     }
 
-    // tokenize the reverse prompt
-    std::vector<std::vector<llama_vocab::id>> antipromptv_inp;
-
-    for (auto antiprompt : params.antiprompt) {
-        antipromptv_inp.push_back(::llama_tokenize(vocab, antiprompt, false));
-    }
-
     // enable interactive mode if reverse prompt is specified
-    if (antipromptv_inp.size() != 0) {
+    if (params.antiprompt.size() != 0) {
         params.interactive = true;
     }
 
@@ -917,15 +910,9 @@ int main(int argc, char ** argv) {
 
         fprintf(stderr, "%s: interactive mode on.\n", __func__);
 
-        if(antipromptv_inp.size()) {
-            for (size_t apindex = 0; apindex < antipromptv_inp.size(); ++apindex) {
-                auto antiprompt_inp = antipromptv_inp.at(apindex);
-                fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str());
-                fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
-                for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
-                    fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
-                }
-                fprintf(stderr, "\n");
+        if(params.antiprompt.size()) {
+            for (auto antiprompt : params.antiprompt) {
+                fprintf(stderr, "Reverse prompt: '%s'\n", antiprompt.c_str());
             }
         }
     }
@@ -1042,9 +1029,14 @@ int main(int argc, char ** argv) {
         // check if we should prompt the user for more
         if (params.interactive && (int) embd_inp.size() <= input_consumed) {
             // check for reverse prompt
-            for (auto antiprompt_inp : antipromptv_inp) {
-                if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
-                    // reverse prompt found
+            std::string last_output;
+            for (auto id : last_n_tokens) {
+                last_output += vocab.id_to_token[id];
+            }
+
+            // Check if each of the reverse prompts appears at the end of the output.
+            for (std::string antiprompt : params.antiprompt) {
+                if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
                     is_interacting = true;
                     break;
                 }