]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk : improve prompting
authorGeorgi Gerganov <redacted>
Mon, 12 Dec 2022 21:44:36 +0000 (23:44 +0200)
committerGeorgi Gerganov <redacted>
Mon, 12 Dec 2022 21:44:36 +0000 (23:44 +0200)
examples/talk/README.md
examples/talk/gpt-2.cpp
examples/talk/talk.cpp
models/download-ggml-model.cmd

index 316a2ae34b04d50d4c7569a6a81182cb25cc4156..160f0ac68360da02f0056d8aab773c2638a1d36b 100644 (file)
@@ -31,7 +31,7 @@ To run this, you will need a ggml GPT-2 model: [instructions](https://github.com
 Alternatively, you can simply download the smallest ggml GPT-2 117M model (240 MB) like this:\r
 \r
 ```\r
-wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://ggml.ggerganov.com/ggml-model-gpt-2-117M.bin\r
+wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/datasets/ggerganov/ggml/raw/main/ggml-model-gpt-2-117M.bin\r
 ```\r
 \r
 ## TTS\r
index 3c04a9cba5b7c04017ef13a1e5a26e560ff24e29..c67551daef9b46803e0fe4efa02b060545a494c6 100644 (file)
@@ -139,7 +139,7 @@ gpt_vocab::id gpt_sample_top_k_top_p(
     }
 
     //printf("\n");
-    //for (int i = 0; i < (int)logits_id.size(); i++) {
+    //for (int i = 0; i < (int) logits_id.size(); i++) {
     //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
     //}
     //exit(0);
@@ -825,8 +825,8 @@ Me too.
     int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
 
     // sampling parameters
-    int32_t top_k = 20;
-    float   top_p = 0.98f;
+    int32_t top_k = 5;
+    float   top_p = 0.9f;
     float   temp  = 1.0f;
 };
 
@@ -840,7 +840,7 @@ struct gpt2_context * gpt2_init(const char * path_model) {
         const int64_t t_start_us = ggml_time_us();
 
         if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) {
-            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, "gpt-2.bin");
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
             return nullptr;
         }
 
@@ -913,10 +913,7 @@ std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens)
         result += ctx->vocab.id_to_token[embd[0]];
 
         // end of text token
-        if (embd.back() == 50256 ||
-            ctx->vocab.id_to_token[embd.back()] == "." ||
-            ctx->vocab.id_to_token[embd.back()] == "!" ||
-            ctx->vocab.id_to_token[embd.back()] == "?") {
+        if (embd.back() == 50256) {
             break;
         }
     }
index 2b0b2e99fc3a54dbc2fde4278a54e4f6a04bf854..3cd730ce1b6376fe558ec606a17bbd1a809de97b 100644 (file)
@@ -473,56 +473,15 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
     return result;
 }
 
-// compute similarity between two strings using Levenshtein distance
-float similarity(const std::string & s0, const std::string & s1) {
-    const size_t len0 = s0.size() + 1;
-    const size_t len1 = s1.size() + 1;
+const std::string k_prompt =
+R"(This is a dialogue between {0} (A) and a person (B). The dialogue so far is:
 
-    std::vector<int> col(len1, 0);
-    std::vector<int> prevCol(len1, 0);
+B: Hello {0}, how are you?
+A: I'm fine, thank you.
+{1}
+Here is how {0} (A) continues the dialogue:
 
-    for (size_t i = 0; i < len1; i++) {
-        prevCol[i] = i;
-    }
-
-    for (size_t i = 0; i < len0; i++) {
-        col[0] = i;
-        for (size_t j = 1; j < len1; j++) {
-            col[j] = std::min(std::min(1 + col[j - 1], 1 + prevCol[j]), prevCol[j - 1] + (s0[i - 1] == s1[j - 1] ? 0 : 1));
-        }
-        col.swap(prevCol);
-    }
-
-    const float dist = prevCol[len1 - 1];
-
-    return 1.0f - (dist / std::max(s0.size(), s1.size()));
-}
-
-// generated with ChatGPT
-std::map<std::string, std::string> k_prompts = {
-    { "Santa",
-R"(Kid: Hi Santa! Are you real?
-Santa: Of course I am, my dear! Ho ho ho!
-Kid: Can you please bring me a new toy for Christmas?
-Santa: I'll see what I can do, but you have to make sure to be a good boy or girl and listen to your parents.
-Kid: I will, Santa! Thank you!
-Santa: You're welcome, little one. Merry Christmas! Ho ho ho!
-Kid: Can you tell me how you deliver all the presents to all the kids in the world in one night?
-Santa: It's a secret, but I have a lot of help from my elves and my magical sleigh. And I have a special route that I follow to make sure I visit every child.
-Kid: Wow, that's amazing! Can I please have a ride in your sleigh sometime?
-Santa: I'm sorry, but only good boys and girls get to ride in my sleigh.
-)" },
-    { "Kid",
-R"(Kid: Hi Santa! Are you real?
-Santa: Of course I am, my dear! Ho ho ho!
-Kid: Can you please bring me a new toy for Christmas?
-Santa: I'll see what I can do, but you have to make sure to be a good boy or girl and listen to your parents.
-Kid: I will, Santa! Thank you!
-Kid: Can you tell me how you deliver all the presents to all the kids in the world in one night?
-Santa: It's a secret, but I have a lot of help from my elves and my magical sleigh. And I have a special route that I follow to make sure I visit every child.
-Kid: Wow, that's amazing! Can I please have a ride in your sleigh sometime?
-)" },
-};
+A:)";
 
 int main(int argc, char ** argv) {
     whisper_params params;
@@ -579,7 +538,7 @@ int main(int argc, char ** argv) {
     int n_iter = 0;
 
     bool is_running  = true;
-    bool force_speak = params.person == "Kid";
+    bool force_speak = false;
 
     float prob0 = 0.0f;
     float prob  = 0.0f;
@@ -587,19 +546,13 @@ int main(int argc, char ** argv) {
     std::vector<float> pcmf32_cur;
     std::vector<float> pcmf32_prompt;
 
-    if (k_prompts.find(params.person) == k_prompts.end()) {
-        fprintf(stderr, "%s: unknown person '%s'\n", __func__, params.person.c_str());
-        return 1;
-    }
-
-    gpt2_set_prompt(ctx_gpt, k_prompts.at(params.person).c_str());
+    gpt2_set_prompt(ctx_gpt, "");
 
-    const std::string person_other = params.person == "Santa" ? "Kid" : "Santa";
-    const int voice_id = params.person == "Santa" ? 5 : 2;
+    const int voice_id = rand()%6;
 
-    fprintf(stderr, "gpt-2: prompt_base:\n");
+    fprintf(stderr, "gpt-2: prompt:\n");
     fprintf(stderr, "========================\n\n");
-    fprintf(stderr, "%s\n", gpt2_get_prompt(ctx_gpt));
+    fprintf(stderr, "%s\n", ::replace(k_prompt, "{0}", params.person).c_str());
     fprintf(stderr, "========================\n\n");
 
     // main loop
@@ -636,13 +589,12 @@ int main(int argc, char ** argv) {
 
                 audio.get(params.voice_ms, pcmf32_cur);
 
-                std::string text_heard = "Hey little one, what do you want for Christmas?";
+                std::string text_heard = "";
+
                 if (!force_speak) {
                     text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms));
                 }
 
-                force_speak = false;
-
                 // remove text between brackets using regex
                 {
                     std::regex re("\\[.*?\\]");
@@ -667,13 +619,15 @@ int main(int argc, char ** argv) {
 
                 const std::vector<gpt_vocab::id> tokens = gpt2_tokenize(ctx_gpt, text_heard.c_str());
 
-                if (text_heard.empty() || tokens.empty()) {
+                if (text_heard.empty() || tokens.empty() || force_speak) {
                     fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
                     audio.clear();
 
                     continue;
                 }
 
+                force_speak = false;
+
                 fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", text_heard.c_str(), "\033[0m", (int) t_ms);
 
                 std::string prompt_base = gpt2_get_prompt(ctx_gpt);
@@ -681,9 +635,11 @@ int main(int argc, char ** argv) {
                 std::string text_to_speak;
 
                 {
-                    text_heard = person_other + ": " + text_heard;
+                    prompt_base += "B: " + text_heard + "\n";
 
-                    text_to_speak = gpt2_gen_text(ctx_gpt, (prompt_base + text_heard + "\n").c_str(), params.max_tokens);
+                    std::string prompt = ::replace(::replace(k_prompt, "{0}", params.person), "{1}", prompt_base);
+
+                    text_to_speak = gpt2_gen_text(ctx_gpt, prompt.c_str(), params.max_tokens);
                     text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
                     text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
 
@@ -703,13 +659,20 @@ int main(int argc, char ** argv) {
                         }
                     }
 
-                    prompt_base += text_heard + "\n" + text_to_speak + "\n";
-                }
+                    prompt_base += "A:" + text_to_speak + "\n";
+
+                    {
+                        prompt = ::replace(::replace(k_prompt, "{0}", params.person), "{1}", prompt_base);
 
-                printf("%s\n", text_to_speak.c_str());
+                        printf("===============\n");
+                        printf("prompt:\n");
+                        printf("%s\n", prompt.c_str());
+                        printf("===============\n");
+                    }
+                }
 
                 //printf("========================\n");
-                //printf("gpt-2: prompt_base:\n'%s'\n", prompt_base.c_str());
+                //printf("gpt-2: prompt_base:\n%s\n", prompt_base.c_str());
                 //printf("========================\n");
 
                 gpt2_set_prompt(ctx_gpt, prompt_base.c_str());
index 4d91187dd172aa86f82ee55dbc3992a23317a0aa..9fe56ccdbeddf5ae549972321adb2a68cccbfd9b 100644 (file)
@@ -40,7 +40,7 @@ if exist "ggml-%model%.bin" (
   goto :eof
 )
 
-PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://ggml.ggerganov.com/ggml-model-whisper-%model%.bin -OutFile ggml-%model%.bin"
+PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/datasets/ggerganov/whisper.cpp/raw/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
 
 if %ERRORLEVEL% neq 0 (
   echo Failed to download ggml model %model%