]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ref #10 : option to keep context in "stream" example
authorGeorgi Gerganov <redacted>
Fri, 7 Oct 2022 19:30:44 +0000 (22:30 +0300)
committerGeorgi Gerganov <redacted>
Fri, 7 Oct 2022 19:30:44 +0000 (22:30 +0300)
Seems the results become worse when we keep the context, so by default
this is not enabled

stream.cpp
whisper.cpp
whisper.h

index 18a3f6f81b6e4087fcffd53eba45276612999a69..f927819f8d499a42f14334a52db7597c0cb238a4 100644 (file)
@@ -40,6 +40,7 @@ struct whisper_params {
 
     bool verbose              = false;
     bool translate            = false;
+    bool no_context           = true;
     bool print_special_tokens = false;
     bool no_timestamps        = true;
 
@@ -64,6 +65,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             params.verbose = true;
         } else if (arg == "--translate") {
             params.translate = true;
+        } else if (arg == "-kc" || arg == "--keep-context") {
+            params.no_context = false;
         } else if (arg == "-l" || arg == "--language") {
             params.language = argv[++i];
             if (whisper_lang_id(params.language.c_str()) == -1) {
@@ -103,6 +106,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "            --step N         audio step size in milliseconds (default: %d)\n", params.step_ms);
     fprintf(stderr, "  -v,       --verbose        verbose output\n");
     fprintf(stderr, "            --translate      translate from source language to english\n");
+    fprintf(stderr, "  -nc,      --no-context     disable context from earlier audio (default: false)\n");
     fprintf(stderr, "  -ps,      --print_special  print special tokens\n");
     fprintf(stderr, "  -nt,      --no_timestamps  do not print timestamps\n");
     fprintf(stderr, "  -l LANG,  --language LANG  spoken language (default: %s)\n", params.language.c_str());
@@ -273,6 +277,7 @@ int main(int argc, char ** argv) {
             wparams.print_realtime       = false;
             wparams.print_timestamps     = !params.no_timestamps;
             wparams.translate            = params.translate;
+            wparams.no_context           = params.no_context;
             wparams.language             = params.language.c_str();
             wparams.n_threads            = params.n_threads;
 
index ca0c6a44a189b7b75e34999c8447ce569088cc84..9913ab6abc2e87118d371cc5ba38c0c567f6650f 100644 (file)
@@ -405,6 +405,8 @@ struct whisper_context {
 
     std::vector<whisper_result>  result_cur;
     std::vector<whisper_segment> result_all;
+
+    std::vector<whisper_token> prompt_past;
 };
 
 // load the model from a ggml file
@@ -1020,8 +1022,6 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
 //   - model:      the model
 //   - n_threads:  number of threads to use
 //   - mel_offset: offset in the mel spectrogram (i.e. audio offset)
-//   - mel_inp:    input mel spectrogram
-//   - features:   output encoded features
 //
 bool whisper_encode(
               whisper_context & wctx,
@@ -1405,10 +1405,9 @@ bool whisper_encode(
 //
 //   - model:      the model
 //   - n_threads:  number of threads to use
-//   - n_past:     prompt length
-//   - prompt:     text prompt
-//   - logits_out: output logits
-//   - probs_out:  output probabilities
+//   - tokens:     text prompt
+//   - n_tokens:   number of tokens in the prompt
+//   - n_past:     number of past tokens to prefix the prompt with
 //
 bool whisper_decode(
               whisper_context & wctx,
@@ -2259,6 +2258,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat
                     .offset_ms = 0,
 
                     .translate            = false,
+                    .no_context           = false,
                     .print_special_tokens = false,
                     .print_progress       = true,
                     .print_realtime       = false,
@@ -2279,6 +2279,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat
                     .offset_ms = 0,
 
                     .translate            = false,
+                    .no_context           = false,
                     .print_special_tokens = false,
                     .print_progress       = true,
                     .print_realtime       = false,
@@ -2297,6 +2298,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat
 
     return result;
 }
+
 int whisper_full(
         struct whisper_context * ctx,
         struct whisper_full_params params,
@@ -2309,7 +2311,10 @@ int whisper_full(
     }
 
     // the accumulated text context so far
-    std::vector<whisper_token> prompt_past = { };
+    auto & prompt_past = ctx->prompt_past;
+    if (params.no_context) {
+        prompt_past.clear();
+    }
 
     // these tokens determine the task that will be performed
     std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
index 78e08b76ae9e4279d64b3e3413ffae8855586491..79df0e04a26376864ca246bf9ad7b03de5b888db 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -105,6 +105,7 @@ extern "C" {
         int offset_ms;
 
         bool translate;
+        bool no_context;
         bool print_special_tokens;
         bool print_progress;
         bool print_realtime;