]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
stream : "-kc" now enables context keeping from previous segment (#90)
authorGeorgi Gerganov <redacted>
Tue, 22 Nov 2022 16:20:05 +0000 (18:20 +0200)
committerGeorgi Gerganov <redacted>
Tue, 22 Nov 2022 16:21:15 +0000 (18:21 +0200)
By default, the context keeping is disabled

examples/stream/stream.cpp
whisper.cpp
whisper.h

index 32f93d6fcbd4b9b7b31eb0c21b1e9d662d54378d..9efc83cdb6edea816b1b55cab465351e9fba95bc 100644 (file)
@@ -336,7 +336,7 @@ int main(int argc, char ** argv) {
             wparams.print_realtime       = false;
             wparams.print_timestamps     = !params.no_timestamps;
             wparams.translate            = params.translate;
-            wparams.no_context           = params.no_context;
+            wparams.no_context           = true;
             wparams.single_segment       = true;
             wparams.max_tokens           = params.max_tokens;
             wparams.language             = params.language.c_str();
@@ -345,9 +345,9 @@ int main(int argc, char ** argv) {
             wparams.audio_ctx            = params.audio_ctx;
             wparams.speed_up             = params.speed_up;
 
-            wparams.prompt_tokens        = prompt_tokens.data();
-            wparams.prompt_n_tokens      = prompt_tokens.size();
-            
+            wparams.prompt_tokens        = params.no_context ? nullptr : prompt_tokens.data();
+            wparams.prompt_n_tokens      = params.no_context ? 0       : prompt_tokens.size();
+
             if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
                 fprintf(stderr, "%s: failed to process audio\n", argv[0]);
                 return 6;
@@ -399,12 +399,15 @@ int main(int argc, char ** argv) {
                 pcmf32_old = std::vector<float>(pcmf32.end() - n_samples_keep, pcmf32.end());
 
                 // Add tokens of the last full length segment as the prompt
-                prompt_tokens.clear();
-                const int n_segments = whisper_full_n_segments(ctx);
-                for (int i = 0; i < n_segments; ++i) {
-                    const int token_count = whisper_full_n_tokens(ctx, i);
-                    for (int j = 0; j < token_count; ++j) {
-                        prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j));
+                if (!params.no_context) {
+                    prompt_tokens.clear();
+
+                    const int n_segments = whisper_full_n_segments(ctx);
+                    for (int i = 0; i < n_segments; ++i) {
+                        const int token_count = whisper_full_n_tokens(ctx, i);
+                        for (int j = 0; j < token_count; ++j) {
+                            prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j));
+                        }
                     }
                 }
             }
index 28c5d26a02ce98df839a964bb75e883b785ab849..6c2e0e0e65e91aae4ea842ef5958f8ad3101acee 100644 (file)
@@ -2590,9 +2590,9 @@ int whisper_full(
         prompt_past.clear();
     }
 
-    // Prepend the prompt tokens to the prompt_past
+    // prepend the prompt tokens to the prompt_past
     if (params.prompt_tokens && params.prompt_n_tokens > 0) {
-        // Parse tokens from the pointer (it points to an std::vector)
+        // parse tokens from the pointer
         for (int i = 0; i < params.prompt_n_tokens; i++) {
             prompt_past.push_back(params.prompt_tokens[i]);
         }
index 1b2a042bbe5e3834f0bf39003403311a10c70559..58a88726d0c07ddc4ffdc649b1ea335bd02b6902 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -208,7 +208,8 @@ extern "C" {
         bool speed_up;  // speed-up the audio by 2x using Phase Vocoder
         int  audio_ctx; // overwrite the audio context size (0 = use default)
 
-        // std::vector<whisper_token>: tokens to provide the whisper model as initial prompt
+        // tokens to provide the whisper model as initial prompt
+        // these are prepended to any existing text context from a previous call
         const whisper_token * prompt_tokens;
         int prompt_n_tokens;