]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
main : support special tokens as reverse/anti prompt (#5847)
authorDAN™ <redacted>
Mon, 4 Mar 2024 07:57:20 +0000 (02:57 -0500)
committerGitHub <redacted>
Mon, 4 Mar 2024 07:57:20 +0000 (09:57 +0200)
* Support special tokens as reverse/anti prompt.

* Tokenize antiprompts only once.

* main : minor

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/main/main.cpp

index 34e84d0d42f87870838338bdad10e6d1cbbd1ced..47059e582a0d499493e7e8d4d73534d39ba88830 100644 (file)
@@ -511,6 +511,14 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd;
     std::vector<llama_token> embd_guidance;
 
+    // tokenized antiprompts
+    std::vector<std::vector<llama_token>> antiprompt_ids;
+
+    antiprompt_ids.reserve(params.antiprompt.size());
+    for (const std::string & antiprompt : params.antiprompt) {
+        antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
+    }
+
     struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
 
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
@@ -769,6 +777,18 @@ int main(int argc, char ** argv) {
                     }
                 }
 
+                // check for reverse prompt using special tokens
+                llama_token last_token = llama_sampling_last(ctx_sampling);
+                for (std::vector<llama_token> ids : antiprompt_ids) {
+                    if (ids.size() == 1 && last_token == ids[0]) {
+                        if (params.interactive) {
+                            is_interacting = true;
+                        }
+                        is_antiprompt = true;
+                        break;
+                    }
+                }
+
                 if (is_antiprompt) {
                     LOG("found antiprompt: %s\n", last_output.c_str());
                 }