]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
main : add --conversation / -cnv flag (#7108)
authorDawid Potocki <redacted>
Wed, 8 May 2024 14:32:32 +0000 (02:32 +1200)
committerGitHub <redacted>
Wed, 8 May 2024 14:32:32 +0000 (17:32 +0300)
common/common.cpp
common/common.h
examples/main/main.cpp

index 467fb014eedb0bd7e6810549c00a05283aaa8145..4a9da284e7ec9d7ecdcd2427d1c90ec17a0805c8 100644 (file)
@@ -911,6 +911,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.instruct = true;
         return true;
     }
+    if (arg == "-cnv" || arg == "--conversation") {
+        params.conversation = true;
+        return true;
+    }
     if (arg == "-cml" || arg == "--chatml") {
         params.chatml = true;
         return true;
@@ -1417,6 +1421,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --version             show version and build info\n");
     printf("  -i, --interactive     run in interactive mode\n");
     printf("  --interactive-first   run in interactive mode and wait for input right away\n");
+    printf("  -cnv, --conversation  run in conversation mode (does not print special tokens and suffix/prefix)\n");
     printf("  -ins, --instruct      run in instruction mode (use with Alpaca models)\n");
     printf("  -cml, --chatml        run in chatml mode (use with ChatML-compatible models)\n");
     printf("  --multiline-input     allows you to write or paste multiple lines without ending each in '\\'\n");
index 9252a4b63889b395d6fa0fe14a358da6045af07e..6f00a2cca888374093224a5163339ed2d35ccd2c 100644 (file)
@@ -140,6 +140,7 @@ struct gpt_params {
     bool random_prompt     = false; // do not randomize prompt if none provided
     bool use_color         = false; // use color to distinguish generations and inputs
     bool interactive       = false; // interactive mode
+    bool conversation      = false; // conversation mode (does not print special tokens and suffix/prefix)
     bool chatml            = false; // chatml mode (used for models trained on chatml syntax)
     bool prompt_cache_all  = false; // save user input and generations to prompt cache
     bool prompt_cache_ro   = false; // open the prompt cache read-only and do not update it
index f676ea1ba8a1372041679d8fe635cbdba057b414..49acd6bab4074ac5bf23e88fa588f6c114dd17e8 100644 (file)
@@ -362,6 +362,9 @@ int main(int argc, char ** argv) {
         params.interactive_first = true;
         params.antiprompt.emplace_back("<|im_start|>user\n");
     }
+    else if (params.conversation) {
+        params.interactive_first = true;
+    }
 
     // enable interactive mode if interactive start is specified
     if (params.interactive_first) {
@@ -733,7 +736,7 @@ int main(int argc, char ** argv) {
         // display text
         if (input_echo && display) {
             for (auto id : embd) {
-                const std::string token_str = llama_token_to_piece(ctx, id);
+                const std::string token_str = llama_token_to_piece(ctx, id, !params.conversation);
                 printf("%s", token_str.c_str());
 
                 if (embd.size() > 1) {
@@ -816,7 +819,7 @@ int main(int argc, char ** argv) {
             if (n_past > 0 && is_interacting) {
                 LOG("waiting for user input\n");
 
-                if (params.instruct || params.chatml) {
+                if (params.conversation || params.instruct || params.chatml) {
                     printf("\n> ");
                 }
 
@@ -826,7 +829,7 @@ int main(int argc, char ** argv) {
                 }
 
                 std::string buffer;
-                if (!params.input_prefix.empty()) {
+                if (!params.input_prefix.empty() && !params.conversation) {
                     LOG("appending input prefix: '%s'\n", params.input_prefix.c_str());
                     printf("%s", params.input_prefix.c_str());
                 }
@@ -850,7 +853,7 @@ int main(int argc, char ** argv) {
                 // Entering a empty line lets the user pass control back
                 if (buffer.length() > 1) {
                     // append input suffix if any
-                    if (!params.input_suffix.empty()) {
+                    if (!params.input_suffix.empty() && !params.conversation) {
                         LOG("appending input suffix: '%s'\n", params.input_suffix.c_str());
                         printf("%s", params.input_suffix.c_str());
                     }