]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Prompt previous tokens for streaming (#163)
authorM. Eren Akbiyik <redacted>
Tue, 22 Nov 2022 16:10:35 +0000 (17:10 +0100)
committerGitHub <redacted>
Tue, 22 Nov 2022 16:10:35 +0000 (18:10 +0200)
* feat: prompt previous tokens for streaming

I used a vector pointer instead of vector itself because it gave weird errors, and why not

* convert vector to use with C api

* feat: remove old refs, check for prompt size

* feat: use better way of getting the pointer

examples/stream/stream.cpp
whisper.cpp
whisper.h

index 6f3634b79244da783153617d737007583adbae49..32f93d6fcbd4b9b7b31eb0c21b1e9d662d54378d 100644 (file)
@@ -234,6 +234,7 @@ int main(int argc, char ** argv) {
     std::vector<float> pcmf32(n_samples_30s, 0.0f);
     std::vector<float> pcmf32_old;
 
+    std::vector<whisper_token> prompt_tokens;
     const int n_new_line = params.length_ms / params.step_ms - 1;
 
     // print some info about the processing
@@ -344,6 +345,9 @@ int main(int argc, char ** argv) {
             wparams.audio_ctx            = params.audio_ctx;
             wparams.speed_up             = params.speed_up;
 
+            wparams.prompt_tokens        = prompt_tokens.data();
+            wparams.prompt_n_tokens      = prompt_tokens.size();
+            
             if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
                 fprintf(stderr, "%s: failed to process audio\n", argv[0]);
                 return 6;
@@ -393,6 +397,16 @@ int main(int argc, char ** argv) {
 
                 // keep part of the audio for next iteration to try to mitigate word boundary issues
                 pcmf32_old = std::vector<float>(pcmf32.end() - n_samples_keep, pcmf32.end());
+
+                // Add tokens of the last full length segment as the prompt
+                prompt_tokens.clear();
+                const int n_segments = whisper_full_n_segments(ctx);
+                for (int i = 0; i < n_segments; ++i) {
+                    const int token_count = whisper_full_n_tokens(ctx, i);
+                    for (int j = 0; j < token_count; ++j) {
+                        prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j));
+                    }
+                }
             }
         }
     }
index 7052355bc8149c44cd7c117ac0fec0ded4984e0d..28c5d26a02ce98df839a964bb75e883b785ab849 100644 (file)
@@ -2412,6 +2412,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                     /*.speed_up             =*/ false,
                     /*.audio_ctx            =*/ 0,
 
+                    /*.prompt_tokens        =*/ nullptr,
+                    /*.prompt_n_tokens      =*/ 0,
+
                     /*.language             =*/ "en",
 
                     /*.greedy               =*/ {
@@ -2455,6 +2458,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                     /*.speed_up             =*/ false,
                     /*.audio_ctx            =*/ 0,
 
+                    /*.prompt_tokens        =*/ nullptr,
+                    /*.prompt_n_tokens      =*/ 0,
+
                     /*.language             =*/ "en",
 
                     /*.greedy               =*/ {
@@ -2584,6 +2590,15 @@ int whisper_full(
         prompt_past.clear();
     }
 
+    // Prepend the prompt tokens to the prompt_past
+    if (params.prompt_tokens && params.prompt_n_tokens > 0) {
+        // Parse tokens from the pointer (it points to an std::vector)
+        for (int i = 0; i < params.prompt_n_tokens; i++) {
+            prompt_past.push_back(params.prompt_tokens[i]);
+        }
+        std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
+    }
+
     // overwrite audio_ctx
     ctx->exp_n_audio_ctx = params.audio_ctx;
 
index 88cc71131442de8000d221b8a8ec4c1f07147001..1b2a042bbe5e3834f0bf39003403311a10c70559 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -208,6 +208,10 @@ extern "C" {
         bool speed_up;  // speed-up the audio by 2x using Phase Vocoder
         int  audio_ctx; // overwrite the audio context size (0 = use default)
 
+        // std::vector<whisper_token>: tokens to provide the whisper model as initial prompt
+        const whisper_token * prompt_tokens;
+        int prompt_n_tokens;
+
         const char * language;
 
         struct {