]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add initial_prompt param (#645)
authorJhen-Jie Hong <redacted>
Wed, 29 Mar 2023 20:23:23 +0000 (04:23 +0800)
committerGitHub <redacted>
Wed, 29 Mar 2023 20:23:23 +0000 (23:23 +0300)
examples/addon.node/addon.cpp
examples/main/main.cpp
whisper.cpp
whisper.h

index 0fa4a8ca3bc3ffcc239fd7a75afb809fef13c775..52e80ad8528c703ef6da2d979d3cb3b86bafd588 100644 (file)
@@ -160,22 +160,6 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
         return 3;
     }
 
-    // initial prompt
-    std::vector<whisper_token> prompt_tokens;
-
-    if (!params.prompt.empty()) {
-        prompt_tokens.resize(1024);
-        prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
-
-        fprintf(stderr, "\n");
-        fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
-        fprintf(stderr, "initial tokens: [ ");
-        for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
-            fprintf(stderr, "%d ", prompt_tokens[i]);
-        }
-        fprintf(stderr, "]\n");
-    }
-
     for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
         const auto fname_inp = params.fname_inp[f];
         const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
@@ -243,8 +227,7 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
             wparams.greedy.best_of        = params.best_of;
             wparams.beam_search.beam_size = params.beam_size;
 
-            wparams.prompt_tokens     = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
-            wparams.prompt_n_tokens   = prompt_tokens.empty() ? 0       : prompt_tokens.size();
+            wparams.initial_prompt   = params.prompt.c_str();
 
             whisper_print_user_data user_data = { &params, &pcmf32s };
 
index dd30ba4c473766169e10594283c1e33a66483891..7131a937b757c899c46fa5fc2b5a2a545897965e 100644 (file)
@@ -639,22 +639,6 @@ int main(int argc, char ** argv) {
         return 3;
     }
 
-    // initial prompt
-    std::vector<whisper_token> prompt_tokens;
-
-    if (!params.prompt.empty()) {
-        prompt_tokens.resize(1024);
-        prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
-
-        fprintf(stderr, "\n");
-        fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
-        fprintf(stderr, "initial tokens: [ ");
-        for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
-            fprintf(stderr, "%d ", prompt_tokens[i]);
-        }
-        fprintf(stderr, "]\n");
-    }
-
     for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
         const auto fname_inp = params.fname_inp[f];
                const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
@@ -718,8 +702,7 @@ int main(int argc, char ** argv) {
 
             wparams.speed_up         = params.speed_up;
 
-            wparams.prompt_tokens     = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
-            wparams.prompt_n_tokens   = prompt_tokens.empty() ? 0       : prompt_tokens.size();
+            wparams.initial_prompt   = params.prompt.c_str();
 
             wparams.greedy.best_of        = params.best_of;
             wparams.beam_search.beam_size = params.beam_size;
index f44e5034bb804c711bd91b24cee9a46468fa1e65..13e11141cc95e2cc8cc30296ce5dfb88e3607e34 100644 (file)
@@ -3121,6 +3121,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.speed_up         =*/ false,
         /*.audio_ctx        =*/ 0,
 
+        /*.initial_prompt   =*/ nullptr,
         /*.prompt_tokens    =*/ nullptr,
         /*.prompt_n_tokens  =*/ 0,
 
@@ -3793,6 +3794,15 @@ int whisper_full_with_state(
         prompt_past.clear();
     }
 
+    // initial prompt
+    if (!params.prompt_tokens && params.initial_prompt) {
+        std::vector<whisper_token> prompt_tokens;
+        prompt_tokens.resize(1024);
+        prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
+        params.prompt_tokens = prompt_tokens.data();
+        params.prompt_n_tokens = prompt_tokens.size();
+    }
+
     // prepend the prompt tokens to the prompt_past
     if (params.prompt_tokens && params.prompt_n_tokens > 0) {
         // parse tokens from the pointer
index fc107108ad810da73e2fc4d8a0339e73bccf749e..fa6bff4fc8da098eed2ea831b865a29ec7433d48 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -356,6 +356,7 @@ extern "C" {
 
         // tokens to provide to the whisper decoder as initial prompt
         // these are prepended to any existing text context from a previous call
+        const char * initial_prompt;
         const whisper_token * prompt_tokens;
         int prompt_n_tokens;