]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add support for --carry-initial-prompt (#3395)
authorAndreas Lubbe <redacted>
Fri, 10 Oct 2025 16:51:15 +0000 (18:51 +0200)
committerGitHub <redacted>
Fri, 10 Oct 2025 16:51:15 +0000 (19:51 +0300)
* Add support for --carry-initial-prompt

* PR fixes for ruby and go

* Refactoring for readability

* WIP 1

* WIP 2

* PR fixes

* More PR fixes

* PR fix

* Further simplification

* d'oh

* One more logic fix

* Update src/whisper.cpp

Co-authored-by: Georgi Gerganov <redacted>
* Truncate prompt_past0 upon initialization

* Slight simplification

---------

Co-authored-by: Georgi Gerganov <redacted>
bindings/go/params.go
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java
bindings/ruby/ext/ruby_whisper_params.c
bindings/ruby/sig/whisper.rbs
bindings/ruby/test/test_params.rb
examples/cli/cli.cpp
include/whisper.h
src/whisper.cpp

index 95c5bfaf9345d0177667d8754e090a659b335ee0..d8dee57e331101520256f32ea319dc5be52778b0 100644 (file)
@@ -47,6 +47,7 @@ func (p *Params) SetPrintTimestamps(v bool) {
        p.print_timestamps = toBool(v)
 }
 
+
 // Set language id
 func (p *Params) SetLanguage(lang int) error {
        if lang == -1 {
@@ -146,6 +147,10 @@ func (p *Params) SetInitialPrompt(prompt string) {
        p.initial_prompt = C.CString(prompt)
 }
 
+func (p *Params) SetCarryInitialPrompt(v bool) {
+       p.carry_initial_prompt = toBool(v)
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // PRIVATE METHODS
 
@@ -199,6 +204,9 @@ func (p *Params) String() string {
        if p.token_timestamps {
                str += " token_timestamps"
        }
+       if p.carry_initial_prompt {
+               str += " carry_initial_prompt"
+       }
 
        return str + ">"
 }
index 498ff126037946d4c43cad33a71b997bbffa3a87..76ce80fb4ccf93861371149d9d9c0e43caab0115 100644 (file)
@@ -157,6 +157,8 @@ public class WhisperFullParams extends Structure {
     /** Tokens to provide to the whisper decoder as an initial prompt.\r
      * These are prepended to any existing text context from a previous call. */\r
     public String initial_prompt;\r
+    /** Always prepend initial_prompt for every decode chunk. */\r
+    public CBool carry_initial_prompt;\r
 \r
     /** Prompt tokens. (int*) */\r
     public Pointer prompt_tokens;\r
@@ -336,8 +338,8 @@ public class WhisperFullParams extends Structure {
                 "no_timestamps", "single_segment", "print_special",\r
                 "print_progress", "print_realtime", "print_timestamps",\r
                 "token_timestamps", "thold_pt", "thold_ptsum", "max_len",\r
-                "split_on_word", "max_tokens", "debug_mode", "audio_ctx", \r
-                "tdrz_enable", "suppress_regex", "initial_prompt",\r
+                "split_on_word", "max_tokens", "debug_mode", "audio_ctx",\r
+                "tdrz_enable", "suppress_regex", "initial_prompt", "carry_initial_prompt",\r
                 "prompt_tokens", "prompt_n_tokens", "language", "detect_language",\r
                 "suppress_blank", "suppress_nst", "temperature",\r
                 "max_initial_ts", "length_penalty", "temperature_inc",\r
index 882c68d042f72be76b4aebe6fea9598c694aa322..70417cb166449e9e844a6a629e206e7be9da452c 100644 (file)
@@ -26,7 +26,7 @@
   rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
   rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
 
-#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 36
+#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 37
 
 extern VALUE cParams;
 extern VALUE cVADParams;
@@ -46,6 +46,7 @@ static ID id_print_special;
 static ID id_print_progress;
 static ID id_print_realtime;
 static ID id_print_timestamps;
+static ID id_carry_initial_prompt;
 static ID id_suppress_blank;
 static ID id_suppress_nst;
 static ID id_token_timestamps;
@@ -455,6 +456,26 @@ ruby_whisper_params_get_print_timestamps(VALUE self)
 {
   BOOL_PARAMS_GETTER(self, print_timestamps)
 }
+
+/*
+ *  call-seq:
+ *    carry_initial_prompt -> true or false
+ */
+static VALUE
+ruby_whisper_params_get_carry_initial_prompt(VALUE self)
+{
+  BOOL_PARAMS_GETTER(self, carry_initial_prompt)
+}
+
+/*
+ *  call-seq:
+ *    carry_initial_prompt = bool -> bool
+ */
+static VALUE
+ruby_whisper_params_set_carry_initial_prompt(VALUE self, VALUE value)
+{
+  BOOL_PARAMS_SETTER(self, carry_initial_prompt, value)
+}
 /*
  * call-seq:
  *   suppress_blank = force_suppress -> force_suppress
@@ -1168,6 +1189,7 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
       SET_PARAM_IF_SAME(max_len)
       SET_PARAM_IF_SAME(split_on_word)
       SET_PARAM_IF_SAME(initial_prompt)
+      SET_PARAM_IF_SAME(carry_initial_prompt)
       SET_PARAM_IF_SAME(offset)
       SET_PARAM_IF_SAME(duration)
       SET_PARAM_IF_SAME(max_text_tokens)
@@ -1303,28 +1325,29 @@ init_ruby_whisper_params(VALUE *mWhisper)
   DEFINE_PARAM(max_len, 11)
   DEFINE_PARAM(split_on_word, 12)
   DEFINE_PARAM(initial_prompt, 13)
-  DEFINE_PARAM(diarize, 14)
-  DEFINE_PARAM(offset, 15)
-  DEFINE_PARAM(duration, 16)
-  DEFINE_PARAM(max_text_tokens, 17)
-  DEFINE_PARAM(temperature, 18)
-  DEFINE_PARAM(max_initial_ts, 19)
-  DEFINE_PARAM(length_penalty, 20)
-  DEFINE_PARAM(temperature_inc, 21)
-  DEFINE_PARAM(entropy_thold, 22)
-  DEFINE_PARAM(logprob_thold, 23)
-  DEFINE_PARAM(no_speech_thold, 24)
-  DEFINE_PARAM(new_segment_callback, 25)
-  DEFINE_PARAM(new_segment_callback_user_data, 26)
-  DEFINE_PARAM(progress_callback, 27)
-  DEFINE_PARAM(progress_callback_user_data, 28)
-  DEFINE_PARAM(encoder_begin_callback, 29)
-  DEFINE_PARAM(encoder_begin_callback_user_data, 30)
-  DEFINE_PARAM(abort_callback, 31)
-  DEFINE_PARAM(abort_callback_user_data, 32)
-  DEFINE_PARAM(vad, 33)
-  DEFINE_PARAM(vad_model_path, 34)
-  DEFINE_PARAM(vad_params, 35)
+  DEFINE_PARAM(carry_initial_prompt, 14)
+  DEFINE_PARAM(diarize, 15)
+  DEFINE_PARAM(offset, 16)
+  DEFINE_PARAM(duration, 17)
+  DEFINE_PARAM(max_text_tokens, 18)
+  DEFINE_PARAM(temperature, 19)
+  DEFINE_PARAM(max_initial_ts, 20)
+  DEFINE_PARAM(length_penalty, 21)
+  DEFINE_PARAM(temperature_inc, 22)
+  DEFINE_PARAM(entropy_thold, 23)
+  DEFINE_PARAM(logprob_thold, 24)
+  DEFINE_PARAM(no_speech_thold, 25)
+  DEFINE_PARAM(new_segment_callback, 26)
+  DEFINE_PARAM(new_segment_callback_user_data, 27)
+  DEFINE_PARAM(progress_callback, 28)
+  DEFINE_PARAM(progress_callback_user_data, 29)
+  DEFINE_PARAM(encoder_begin_callback, 30)
+  DEFINE_PARAM(encoder_begin_callback_user_data, 31)
+  DEFINE_PARAM(abort_callback, 32)
+  DEFINE_PARAM(abort_callback_user_data, 33)
+  DEFINE_PARAM(vad, 34)
+  DEFINE_PARAM(vad_model_path, 35)
+  DEFINE_PARAM(vad_params, 36)
 
   rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
   rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
index 0489432a24903d8cbd16b407adee1d9ac258948d..d5905dd7037addb9ed4583964e3613bab51a2ceb 100644 (file)
@@ -138,6 +138,7 @@ module Whisper
       ?max_len: Integer,
       ?split_on_word: boolish,
       ?initial_prompt: string | nil,
+      ?carry_initial_prompt: boolish,
       ?diarize: boolish,
       ?offset: Integer,
       ?duration: Integer,
@@ -236,6 +237,7 @@ module Whisper
     def split_on_word: () -> (true | false)
 
     def initial_prompt=: (_ToS) -> _ToS
+    def carry_initial_prompt=: (boolish) -> boolish
 
     # Tokens to provide to the whisper decoder as initial prompt
     # these are prepended to any existing text context from a previous call
@@ -243,6 +245,7 @@ module Whisper
     # Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224).
     #
     def initial_prompt: () -> (String | nil)
+    def carry_initial_prompt: () -> (true | false)
 
     def diarize=: (boolish) -> boolish
 
index d5c5d140e8c5c5566b980c15b0ccf9bbc68d136c..4dd9780de7d28918a3af0483a6f92cad951228dc 100644 (file)
@@ -16,6 +16,7 @@ class TestParams < TestBase
     :max_len,
     :split_on_word,
     :initial_prompt,
+    :carry_initial_prompt,
     :diarize,
     :offset,
     :duration,
@@ -119,6 +120,13 @@ class TestParams < TestBase
     assert !@params.print_timestamps
   end
 
+  def test_carry_initial_prompt
+    @params.carry_initial_prompt = true
+    assert @params.carry_initial_prompt
+    @params.carry_initial_prompt = false
+    assert !@params.carry_initial_prompt
+  end
+
   def test_suppress_blank
     @params.suppress_blank = true
     assert @params.suppress_blank
index 0739cacfd1b004b771a02b53d5f46a3631c704a7..9a54742fe1d7db5e1371bd4a98cefae1dd517e27 100644 (file)
@@ -5,6 +5,7 @@
 #include "grammar-parser.h"
 
 #include <cmath>
+#include <algorithm>
 #include <fstream>
 #include <cstdio>
 #include <string>
@@ -77,6 +78,7 @@ struct whisper_params {
     bool use_gpu         = true;
     bool flash_attn      = true;
     bool suppress_nst    = false;
+    bool carry_initial_prompt = false;
 
     std::string language  = "en";
     std::string prompt;
@@ -145,60 +147,61 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
             exit(0);
         }
         #define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg))
-        else if (arg == "-t"    || arg == "--threads")         { params.n_threads       = std::stoi(ARGV_NEXT); }
-        else if (arg == "-p"    || arg == "--processors")      { params.n_processors    = std::stoi(ARGV_NEXT); }
-        else if (arg == "-ot"   || arg == "--offset-t")        { params.offset_t_ms     = std::stoi(ARGV_NEXT); }
-        else if (arg == "-on"   || arg == "--offset-n")        { params.offset_n        = std::stoi(ARGV_NEXT); }
-        else if (arg == "-d"    || arg == "--duration")        { params.duration_ms     = std::stoi(ARGV_NEXT); }
-        else if (arg == "-mc"   || arg == "--max-context")     { params.max_context     = std::stoi(ARGV_NEXT); }
-        else if (arg == "-ml"   || arg == "--max-len")         { params.max_len         = std::stoi(ARGV_NEXT); }
-        else if (arg == "-bo"   || arg == "--best-of")         { params.best_of         = std::stoi(ARGV_NEXT); }
-        else if (arg == "-bs"   || arg == "--beam-size")       { params.beam_size       = std::stoi(ARGV_NEXT); }
-        else if (arg == "-ac"   || arg == "--audio-ctx")       { params.audio_ctx       = std::stoi(ARGV_NEXT); }
-        else if (arg == "-wt"   || arg == "--word-thold")      { params.word_thold      = std::stof(ARGV_NEXT); }
-        else if (arg == "-et"   || arg == "--entropy-thold")   { params.entropy_thold   = std::stof(ARGV_NEXT); }
-        else if (arg == "-lpt"  || arg == "--logprob-thold")   { params.logprob_thold   = std::stof(ARGV_NEXT); }
-        else if (arg == "-nth"  || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(ARGV_NEXT); }
-        else if (arg == "-tp"   || arg == "--temperature")     { params.temperature     = std::stof(ARGV_NEXT); }
-        else if (arg == "-tpi"  || arg == "--temperature-inc") { params.temperature_inc = std::stof(ARGV_NEXT); }
-        else if (arg == "-debug"|| arg == "--debug-mode")      { params.debug_mode      = true; }
-        else if (arg == "-tr"   || arg == "--translate")       { params.translate       = true; }
-        else if (arg == "-di"   || arg == "--diarize")         { params.diarize         = true; }
-        else if (arg == "-tdrz" || arg == "--tinydiarize")     { params.tinydiarize     = true; }
-        else if (arg == "-sow"  || arg == "--split-on-word")   { params.split_on_word   = true; }
-        else if (arg == "-nf"   || arg == "--no-fallback")     { params.no_fallback     = true; }
-        else if (arg == "-otxt" || arg == "--output-txt")      { params.output_txt      = true; }
-        else if (arg == "-ovtt" || arg == "--output-vtt")      { params.output_vtt      = true; }
-        else if (arg == "-osrt" || arg == "--output-srt")      { params.output_srt      = true; }
-        else if (arg == "-owts" || arg == "--output-words")    { params.output_wts      = true; }
-        else if (arg == "-olrc" || arg == "--output-lrc")      { params.output_lrc      = true; }
-        else if (arg == "-fp"   || arg == "--font-path")       { params.font_path       = ARGV_NEXT; }
-        else if (arg == "-ocsv" || arg == "--output-csv")      { params.output_csv      = true; }
-        else if (arg == "-oj"   || arg == "--output-json")     { params.output_jsn      = true; }
-        else if (arg == "-ojf"  || arg == "--output-json-full"){ params.output_jsn_full = params.output_jsn = true; }
-        else if (arg == "-of"   || arg == "--output-file")     { params.fname_out.emplace_back(ARGV_NEXT); }
-        else if (arg == "-np"   || arg == "--no-prints")       { params.no_prints       = true; }
-        else if (arg == "-ps"   || arg == "--print-special")   { params.print_special   = true; }
-        else if (arg == "-pc"   || arg == "--print-colors")    { params.print_colors    = true; }
-        else if (                  arg == "--print-confidence"){ params.print_confidence= true; }
-        else if (arg == "-pp"   || arg == "--print-progress")  { params.print_progress  = true; }
-        else if (arg == "-nt"   || arg == "--no-timestamps")   { params.no_timestamps   = true; }
-        else if (arg == "-l"    || arg == "--language")        { params.language        = whisper_param_turn_lowercase(ARGV_NEXT); }
-        else if (arg == "-dl"   || arg == "--detect-language") { params.detect_language = true; }
-        else if (                  arg == "--prompt")          { params.prompt          = ARGV_NEXT; }
-        else if (arg == "-m"    || arg == "--model")           { params.model           = ARGV_NEXT; }
-        else if (arg == "-f"    || arg == "--file")            { params.fname_inp.emplace_back(ARGV_NEXT); }
-        else if (arg == "-oved" || arg == "--ov-e-device")     { params.openvino_encode_device = ARGV_NEXT; }
-        else if (arg == "-dtw"  || arg == "--dtw")             { params.dtw             = ARGV_NEXT; }
-        else if (arg == "-ls"   || arg == "--log-score")       { params.log_score       = true; }
-        else if (arg == "-ng"   || arg == "--no-gpu")          { params.use_gpu         = false; }
-        else if (arg == "-fa"   || arg == "--flash-attn")      { params.flash_attn      = true; }
-        else if (arg == "-nfa"  || arg == "--no-flash-attn")   { params.flash_attn      = false; }
-        else if (arg == "-sns"  || arg == "--suppress-nst")    { params.suppress_nst    = true; }
-        else if (                  arg == "--suppress-regex")  { params.suppress_regex  = ARGV_NEXT; }
-        else if (                  arg == "--grammar")         { params.grammar         = ARGV_NEXT; }
-        else if (                  arg == "--grammar-rule")    { params.grammar_rule    = ARGV_NEXT; }
-        else if (                  arg == "--grammar-penalty") { params.grammar_penalty = std::stof(ARGV_NEXT); }
+        else if (arg == "-t"    || arg == "--threads")              { params.n_threads       = std::stoi(ARGV_NEXT); }
+        else if (arg == "-p"    || arg == "--processors")           { params.n_processors    = std::stoi(ARGV_NEXT); }
+        else if (arg == "-ot"   || arg == "--offset-t")             { params.offset_t_ms     = std::stoi(ARGV_NEXT); }
+        else if (arg == "-on"   || arg == "--offset-n")             { params.offset_n        = std::stoi(ARGV_NEXT); }
+        else if (arg == "-d"    || arg == "--duration")             { params.duration_ms     = std::stoi(ARGV_NEXT); }
+        else if (arg == "-mc"   || arg == "--max-context")          { params.max_context     = std::stoi(ARGV_NEXT); }
+        else if (arg == "-ml"   || arg == "--max-len")              { params.max_len         = std::stoi(ARGV_NEXT); }
+        else if (arg == "-bo"   || arg == "--best-of")              { params.best_of         = std::stoi(ARGV_NEXT); }
+        else if (arg == "-bs"   || arg == "--beam-size")            { params.beam_size       = std::stoi(ARGV_NEXT); }
+        else if (arg == "-ac"   || arg == "--audio-ctx")            { params.audio_ctx       = std::stoi(ARGV_NEXT); }
+        else if (arg == "-wt"   || arg == "--word-thold")           { params.word_thold      = std::stof(ARGV_NEXT); }
+        else if (arg == "-et"   || arg == "--entropy-thold")        { params.entropy_thold   = std::stof(ARGV_NEXT); }
+        else if (arg == "-lpt"  || arg == "--logprob-thold")        { params.logprob_thold   = std::stof(ARGV_NEXT); }
+        else if (arg == "-nth"  || arg == "--no-speech-thold")      { params.no_speech_thold = std::stof(ARGV_NEXT); }
+        else if (arg == "-tp"   || arg == "--temperature")          { params.temperature     = std::stof(ARGV_NEXT); }
+        else if (arg == "-tpi"  || arg == "--temperature-inc")      { params.temperature_inc = std::stof(ARGV_NEXT); }
+        else if (arg == "-debug"|| arg == "--debug-mode")           { params.debug_mode      = true; }
+        else if (arg == "-tr"   || arg == "--translate")            { params.translate       = true; }
+        else if (arg == "-di"   || arg == "--diarize")              { params.diarize         = true; }
+        else if (arg == "-tdrz" || arg == "--tinydiarize")          { params.tinydiarize     = true; }
+        else if (arg == "-sow"  || arg == "--split-on-word")        { params.split_on_word   = true; }
+        else if (arg == "-nf"   || arg == "--no-fallback")          { params.no_fallback     = true; }
+        else if (arg == "-otxt" || arg == "--output-txt")           { params.output_txt      = true; }
+        else if (arg == "-ovtt" || arg == "--output-vtt")           { params.output_vtt      = true; }
+        else if (arg == "-osrt" || arg == "--output-srt")           { params.output_srt      = true; }
+        else if (arg == "-owts" || arg == "--output-words")         { params.output_wts      = true; }
+        else if (arg == "-olrc" || arg == "--output-lrc")           { params.output_lrc      = true; }
+        else if (arg == "-fp"   || arg == "--font-path")            { params.font_path       = ARGV_NEXT; }
+        else if (arg == "-ocsv" || arg == "--output-csv")           { params.output_csv      = true; }
+        else if (arg == "-oj"   || arg == "--output-json")          { params.output_jsn      = true; }
+        else if (arg == "-ojf"  || arg == "--output-json-full")     { params.output_jsn_full = params.output_jsn = true; }
+        else if (arg == "-of"   || arg == "--output-file")          { params.fname_out.emplace_back(ARGV_NEXT); }
+        else if (arg == "-np"   || arg == "--no-prints")            { params.no_prints       = true; }
+        else if (arg == "-ps"   || arg == "--print-special")        { params.print_special   = true; }
+        else if (arg == "-pc"   || arg == "--print-colors")         { params.print_colors    = true; }
+        else if (                  arg == "--print-confidence")     { params.print_confidence= true; }
+        else if (arg == "-pp"   || arg == "--print-progress")       { params.print_progress  = true; }
+        else if (arg == "-nt"   || arg == "--no-timestamps")        { params.no_timestamps   = true; }
+        else if (arg == "-l"    || arg == "--language")             { params.language        = whisper_param_turn_lowercase(ARGV_NEXT); }
+        else if (arg == "-dl"   || arg == "--detect-language")      { params.detect_language = true; }
+        else if (                  arg == "--prompt")               { params.prompt          = ARGV_NEXT; }
+        else if (                  arg == "--carry-initial-prompt") { params.carry_initial_prompt = true; }
+        else if (arg == "-m"    || arg == "--model")                { params.model           = ARGV_NEXT; }
+        else if (arg == "-f"    || arg == "--file")                 { params.fname_inp.emplace_back(ARGV_NEXT); }
+        else if (arg == "-oved" || arg == "--ov-e-device")          { params.openvino_encode_device = ARGV_NEXT; }
+        else if (arg == "-dtw"  || arg == "--dtw")                  { params.dtw             = ARGV_NEXT; }
+        else if (arg == "-ls"   || arg == "--log-score")            { params.log_score       = true; }
+        else if (arg == "-ng"   || arg == "--no-gpu")               { params.use_gpu         = false; }
+        else if (arg == "-fa"   || arg == "--flash-attn")           { params.flash_attn      = true; }
+        else if (arg == "-nfa"  || arg == "--no-flash-attn")        { params.flash_attn      = false; }
+        else if (arg == "-sns"  || arg == "--suppress-nst")         { params.suppress_nst    = true; }
+        else if (                  arg == "--suppress-regex")       { params.suppress_regex  = ARGV_NEXT; }
+        else if (                  arg == "--grammar")              { params.grammar         = ARGV_NEXT; }
+        else if (                  arg == "--grammar-rule")         { params.grammar_rule    = ARGV_NEXT; }
+        else if (                  arg == "--grammar-penalty")      { params.grammar_penalty = std::stof(ARGV_NEXT); }
         // Voice Activity Detection (VAD)
         else if (                  arg == "--vad")                         { params.vad                         = true; }
         else if (arg == "-vm"   || arg == "--vad-model")                   { params.vad_model                   = ARGV_NEXT; }
@@ -224,61 +227,62 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
     fprintf(stderr, "supported audio formats: flac, mp3, ogg, wav\n");
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
-    fprintf(stderr, "  -h,        --help              [default] show this help message and exit\n");
-    fprintf(stderr, "  -t N,      --threads N         [%-7d] number of threads to use during computation\n",    params.n_threads);
-    fprintf(stderr, "  -p N,      --processors N      [%-7d] number of processors to use during computation\n", params.n_processors);
-    fprintf(stderr, "  -ot N,     --offset-t N        [%-7d] time offset in milliseconds\n",                    params.offset_t_ms);
-    fprintf(stderr, "  -on N,     --offset-n N        [%-7d] segment index offset\n",                           params.offset_n);
-    fprintf(stderr, "  -d  N,     --duration N        [%-7d] duration of audio to process in milliseconds\n",   params.duration_ms);
-    fprintf(stderr, "  -mc N,     --max-context N     [%-7d] maximum number of text context tokens to store\n", params.max_context);
-    fprintf(stderr, "  -ml N,     --max-len N         [%-7d] maximum segment length in characters\n",           params.max_len);
-    fprintf(stderr, "  -sow,      --split-on-word     [%-7s] split on word rather than on token\n",             params.split_on_word ? "true" : "false");
-    fprintf(stderr, "  -bo N,     --best-of N         [%-7d] number of best candidates to keep\n",              params.best_of);
-    fprintf(stderr, "  -bs N,     --beam-size N       [%-7d] beam size for beam search\n",                      params.beam_size);
-    fprintf(stderr, "  -ac N,     --audio-ctx N       [%-7d] audio context size (0 - all)\n",                   params.audio_ctx);
-    fprintf(stderr, "  -wt N,     --word-thold N      [%-7.2f] word timestamp probability threshold\n",         params.word_thold);
-    fprintf(stderr, "  -et N,     --entropy-thold N   [%-7.2f] entropy threshold for decoder fail\n",           params.entropy_thold);
-    fprintf(stderr, "  -lpt N,    --logprob-thold N   [%-7.2f] log probability threshold for decoder fail\n",   params.logprob_thold);
-    fprintf(stderr, "  -nth N,    --no-speech-thold N [%-7.2f] no speech threshold\n",                          params.no_speech_thold);
-    fprintf(stderr, "  -tp,       --temperature N     [%-7.2f] The sampling temperature, between 0 and 1\n",    params.temperature);
-    fprintf(stderr, "  -tpi,      --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc);
-    fprintf(stderr, "  -debug,    --debug-mode        [%-7s] enable debug mode (eg. dump log_mel)\n",           params.debug_mode ? "true" : "false");
-    fprintf(stderr, "  -tr,       --translate         [%-7s] translate from source language to english\n",      params.translate ? "true" : "false");
-    fprintf(stderr, "  -di,       --diarize           [%-7s] stereo audio diarization\n",                       params.diarize ? "true" : "false");
-    fprintf(stderr, "  -tdrz,     --tinydiarize       [%-7s] enable tinydiarize (requires a tdrz model)\n",     params.tinydiarize ? "true" : "false");
-    fprintf(stderr, "  -nf,       --no-fallback       [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
-    fprintf(stderr, "  -otxt,     --output-txt        [%-7s] output result in a text file\n",                   params.output_txt ? "true" : "false");
-    fprintf(stderr, "  -ovtt,     --output-vtt        [%-7s] output result in a vtt file\n",                    params.output_vtt ? "true" : "false");
-    fprintf(stderr, "  -osrt,     --output-srt        [%-7s] output result in a srt file\n",                    params.output_srt ? "true" : "false");
-    fprintf(stderr, "  -olrc,     --output-lrc        [%-7s] output result in a lrc file\n",                    params.output_lrc ? "true" : "false");
-    fprintf(stderr, "  -owts,     --output-words      [%-7s] output script for generating karaoke video\n",     params.output_wts ? "true" : "false");
-    fprintf(stderr, "  -fp,       --font-path         [%-7s] path to a monospace font for karaoke video\n",     params.font_path.c_str());
-    fprintf(stderr, "  -ocsv,     --output-csv        [%-7s] output result in a CSV file\n",                    params.output_csv ? "true" : "false");
-    fprintf(stderr, "  -oj,       --output-json       [%-7s] output result in a JSON file\n",                   params.output_jsn ? "true" : "false");
-    fprintf(stderr, "  -ojf,      --output-json-full  [%-7s] include more information in the JSON file\n",      params.output_jsn_full ? "true" : "false");
-    fprintf(stderr, "  -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n",      "");
-    fprintf(stderr, "  -np,       --no-prints         [%-7s] do not print anything other than the results\n",   params.no_prints ? "true" : "false");
-    fprintf(stderr, "  -ps,       --print-special     [%-7s] print special tokens\n",                           params.print_special ? "true" : "false");
-    fprintf(stderr, "  -pc,       --print-colors      [%-7s] print colors\n",                                   params.print_colors ? "true" : "false");
-    fprintf(stderr, "             --print-confidence  [%-7s] print confidence\n",                               params.print_confidence ? "true" : "false");
-    fprintf(stderr, "  -pp,       --print-progress    [%-7s] print progress\n",                                 params.print_progress ? "true" : "false");
-    fprintf(stderr, "  -nt,       --no-timestamps     [%-7s] do not print timestamps\n",                        params.no_timestamps ? "true" : "false");
-    fprintf(stderr, "  -l LANG,   --language LANG     [%-7s] spoken language ('auto' for auto-detect)\n",       params.language.c_str());
-    fprintf(stderr, "  -dl,       --detect-language   [%-7s] exit after automatically detecting language\n",    params.detect_language ? "true" : "false");
-    fprintf(stderr, "             --prompt PROMPT     [%-7s] initial prompt (max n_text_ctx/2 tokens)\n",       params.prompt.c_str());
-    fprintf(stderr, "  -m FNAME,  --model FNAME       [%-7s] model path\n",                                     params.model.c_str());
-    fprintf(stderr, "  -f FNAME,  --file FNAME        [%-7s] input audio file path\n",                            "");
-    fprintf(stderr, "  -oved D,   --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n",  params.openvino_encode_device.c_str());
-    fprintf(stderr, "  -dtw MODEL --dtw MODEL         [%-7s] compute token-level timestamps\n",                 params.dtw.c_str());
-    fprintf(stderr, "  -ls,       --log-score         [%-7s] log best decoder scores of tokens\n",              params.log_score?"true":"false");
-    fprintf(stderr, "  -ng,       --no-gpu            [%-7s] disable GPU\n",                                    params.use_gpu ? "false" : "true");
-    fprintf(stderr, "  -fa,       --flash-attn        [%-7s] enable flash attention\n",                         params.flash_attn ? "true" : "false");
-    fprintf(stderr, "  -nfa,      --no-flash-attn     [%-7s] disable flash attention\n",                        params.flash_attn ? "false" : "true");
-    fprintf(stderr, "  -sns,      --suppress-nst      [%-7s] suppress non-speech tokens\n",                     params.suppress_nst ? "true" : "false");
-    fprintf(stderr, "  --suppress-regex REGEX         [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
-    fprintf(stderr, "  --grammar GRAMMAR              [%-7s] GBNF grammar to guide decoding\n",                 params.grammar.c_str());
-    fprintf(stderr, "  --grammar-rule RULE            [%-7s] top-level GBNF grammar rule name\n",               params.grammar_rule.c_str());
-    fprintf(stderr, "  --grammar-penalty N            [%-7.1f] scales down logits of nongrammar tokens\n",      params.grammar_penalty);
+    fprintf(stderr, "  -h,        --help                 [default] show this help message and exit\n");
+    fprintf(stderr, "  -t N,      --threads N            [%-7d] number of threads to use during computation\n",    params.n_threads);
+    fprintf(stderr, "  -p N,      --processors N         [%-7d] number of processors to use during computation\n", params.n_processors);
+    fprintf(stderr, "  -ot N,     --offset-t N           [%-7d] time offset in milliseconds\n",                    params.offset_t_ms);
+    fprintf(stderr, "  -on N,     --offset-n N           [%-7d] segment index offset\n",                           params.offset_n);
+    fprintf(stderr, "  -d  N,     --duration N           [%-7d] duration of audio to process in milliseconds\n",   params.duration_ms);
+    fprintf(stderr, "  -mc N,     --max-context N        [%-7d] maximum number of text context tokens to store\n", params.max_context);
+    fprintf(stderr, "  -ml N,     --max-len N            [%-7d] maximum segment length in characters\n",           params.max_len);
+    fprintf(stderr, "  -sow,      --split-on-word        [%-7s] split on word rather than on token\n",             params.split_on_word ? "true" : "false");
+    fprintf(stderr, "  -bo N,     --best-of N            [%-7d] number of best candidates to keep\n",              params.best_of);
+    fprintf(stderr, "  -bs N,     --beam-size N          [%-7d] beam size for beam search\n",                      params.beam_size);
+    fprintf(stderr, "  -ac N,     --audio-ctx N          [%-7d] audio context size (0 - all)\n",                   params.audio_ctx);
+    fprintf(stderr, "  -wt N,     --word-thold N         [%-7.2f] word timestamp probability threshold\n",         params.word_thold);
+    fprintf(stderr, "  -et N,     --entropy-thold N      [%-7.2f] entropy threshold for decoder fail\n",           params.entropy_thold);
+    fprintf(stderr, "  -lpt N,    --logprob-thold N      [%-7.2f] log probability threshold for decoder fail\n",   params.logprob_thold);
+    fprintf(stderr, "  -nth N,    --no-speech-thold N    [%-7.2f] no speech threshold\n",                          params.no_speech_thold);
+    fprintf(stderr, "  -tp,       --temperature N        [%-7.2f] The sampling temperature, between 0 and 1\n",    params.temperature);
+    fprintf(stderr, "  -tpi,      --temperature-inc N    [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc);
+    fprintf(stderr, "  -debug,    --debug-mode           [%-7s] enable debug mode (eg. dump log_mel)\n",           params.debug_mode ? "true" : "false");
+    fprintf(stderr, "  -tr,       --translate            [%-7s] translate from source language to english\n",      params.translate ? "true" : "false");
+    fprintf(stderr, "  -di,       --diarize              [%-7s] stereo audio diarization\n",                       params.diarize ? "true" : "false");
+    fprintf(stderr, "  -tdrz,     --tinydiarize          [%-7s] enable tinydiarize (requires a tdrz model)\n",     params.tinydiarize ? "true" : "false");
+    fprintf(stderr, "  -nf,       --no-fallback          [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
+    fprintf(stderr, "  -otxt,     --output-txt           [%-7s] output result in a text file\n",                   params.output_txt ? "true" : "false");
+    fprintf(stderr, "  -ovtt,     --output-vtt           [%-7s] output result in a vtt file\n",                    params.output_vtt ? "true" : "false");
+    fprintf(stderr, "  -osrt,     --output-srt           [%-7s] output result in a srt file\n",                    params.output_srt ? "true" : "false");
+    fprintf(stderr, "  -olrc,     --output-lrc           [%-7s] output result in a lrc file\n",                    params.output_lrc ? "true" : "false");
+    fprintf(stderr, "  -owts,     --output-words         [%-7s] output script for generating karaoke video\n",     params.output_wts ? "true" : "false");
+    fprintf(stderr, "  -fp,       --font-path            [%-7s] path to a monospace font for karaoke video\n",     params.font_path.c_str());
+    fprintf(stderr, "  -ocsv,     --output-csv           [%-7s] output result in a CSV file\n",                    params.output_csv ? "true" : "false");
+    fprintf(stderr, "  -oj,       --output-json          [%-7s] output result in a JSON file\n",                   params.output_jsn ? "true" : "false");
+    fprintf(stderr, "  -ojf,      --output-json-full     [%-7s] include more information in the JSON file\n",      params.output_jsn_full ? "true" : "false");
+    fprintf(stderr, "  -of FNAME, --output-file FNAME    [%-7s] output file path (without file extension)\n",      "");
+    fprintf(stderr, "  -np,       --no-prints            [%-7s] do not print anything other than the results\n",   params.no_prints ? "true" : "false");
+    fprintf(stderr, "  -ps,       --print-special        [%-7s] print special tokens\n",                           params.print_special ? "true" : "false");
+    fprintf(stderr, "  -pc,       --print-colors         [%-7s] print colors\n",                                   params.print_colors ? "true" : "false");
+    fprintf(stderr, "             --print-confidence     [%-7s] print confidence\n",                               params.print_confidence ? "true" : "false");
+    fprintf(stderr, "  -pp,       --print-progress       [%-7s] print progress\n",                                 params.print_progress ? "true" : "false");
+    fprintf(stderr, "  -nt,       --no-timestamps        [%-7s] do not print timestamps\n",                        params.no_timestamps ? "true" : "false");
+    fprintf(stderr, "  -l LANG,   --language LANG        [%-7s] spoken language ('auto' for auto-detect)\n",       params.language.c_str());
+    fprintf(stderr, "  -dl,       --detect-language      [%-7s] exit after automatically detecting language\n",    params.detect_language ? "true" : "false");
+    fprintf(stderr, "             --prompt PROMPT        [%-7s] initial prompt (max n_text_ctx/2 tokens)\n",       params.prompt.c_str());
+    fprintf(stderr, "             --carry-initial-prompt [%-7s] always prepend initial prompt\n",                  params.carry_initial_prompt ? "true" : "false");
+    fprintf(stderr, "  -m FNAME,  --model FNAME          [%-7s] model path\n",                                     params.model.c_str());
+    fprintf(stderr, "  -f FNAME,  --file FNAME           [%-7s] input audio file path\n",                          "");
+    fprintf(stderr, "  -oved D,   --ov-e-device DNAME    [%-7s] the OpenVINO device used for encode inference\n",  params.openvino_encode_device.c_str());
+    fprintf(stderr, "  -dtw MODEL --dtw MODEL            [%-7s] compute token-level timestamps\n",                 params.dtw.c_str());
+    fprintf(stderr, "  -ls,       --log-score            [%-7s] log best decoder scores of tokens\n",              params.log_score?"true":"false");
+    fprintf(stderr, "  -ng,       --no-gpu               [%-7s] disable GPU\n",                                    params.use_gpu ? "false" : "true");
+    fprintf(stderr, "  -fa,       --flash-attn           [%-7s] enable flash attention\n",                         params.flash_attn ? "true" : "false");
+    fprintf(stderr, "  -nfa,      --no-flash-attn        [%-7s] disable flash attention\n",                        params.flash_attn ? "false" : "true");
+    fprintf(stderr, "  -sns,      --suppress-nst         [%-7s] suppress non-speech tokens\n",                     params.suppress_nst ? "true" : "false");
+    fprintf(stderr, "  --suppress-regex REGEX            [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
+    fprintf(stderr, "  --grammar GRAMMAR                 [%-7s] GBNF grammar to guide decoding\n",                 params.grammar.c_str());
+    fprintf(stderr, "  --grammar-rule RULE               [%-7s] top-level GBNF grammar rule name\n",               params.grammar_rule.c_str());
+    fprintf(stderr, "  --grammar-penalty N               [%-7.1f] scales down logits of nongrammar tokens\n",      params.grammar_penalty);
     // Voice Activity Detection (VAD) parameters
     fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n");
     fprintf(stderr, "             --vad                           [%-7s] enable Voice Activity Detection (VAD)\n",            params.vad ? "true" : "false");
@@ -387,7 +391,11 @@ static void whisper_print_segment_callback(struct whisper_context * ctx, struct
                 const char * text = whisper_full_get_token_text(ctx, i, j);
                 const float  p    = whisper_full_get_token_p   (ctx, i, j);
 
-                const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
+                const int n_colors = (int) k_colors.size();
+                int raw_col = (int) (std::pow(p, 3)*float(n_colors));
+                if (raw_col < 0) raw_col = 0;
+                if (raw_col > n_colors - 1) raw_col = n_colors - 1;
+                const int col = raw_col;
 
                 printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
             }
@@ -1178,7 +1186,8 @@ int main(int argc, char ** argv) {
 
             wparams.suppress_regex   = params.suppress_regex.empty() ? nullptr : params.suppress_regex.c_str();
 
-            wparams.initial_prompt   = params.prompt.c_str();
+            wparams.initial_prompt       = params.prompt.c_str();
+            wparams.carry_initial_prompt = params.carry_initial_prompt;
 
             wparams.greedy.best_of        = params.best_of;
             wparams.beam_search.beam_size = params.beam_size;
index fcd756a9fe2533c7444b0a0635dfc847e0fdeb07..f4cc6bf7abdd96c0851ed506edd071daf6c0801f 100644 (file)
@@ -525,6 +525,7 @@ extern "C" {
         // use whisper_tokenize() to convert text to tokens
         // maximum of whisper_n_text_ctx()/2 tokens are used (typically 224)
         const char * initial_prompt;
+        bool carry_initial_prompt; // if true, always prepend initial_prompt to every decode window (may reduce conditioning on previous text)
         const whisper_token * prompt_tokens;
         int prompt_n_tokens;
 
index a212b7c9272191b59d7e3207ed3cb5721530da02..18874309b7dab7891e9fcb65d7a04f4c545e72f5 100644 (file)
@@ -140,6 +140,10 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
     } while (0)
 
 #define WHISPER_MAX_DECODERS 8
+
+// temperature below which we condition on past text history
+static constexpr float WHISPER_HISTORY_CONDITIONING_TEMP_CUTOFF = 0.5f;
+
 #define WHISPER_MAX_NODES 4096
 
 static std::string format(const char * fmt, ...) {
@@ -882,7 +886,10 @@ struct whisper_state {
     std::vector<float> logits;
 
     std::vector<whisper_segment> result_all;
-    std::vector<whisper_token>   prompt_past;
+
+    // prompt history split into static prefix (prompt_past0) and dynamic rolling context (prompt_past1)
+    std::vector<whisper_token>   prompt_past0; // static carried initial prompt (if enabled)
+    std::vector<whisper_token>   prompt_past1; // dynamic context from decoded output
 
     int lang_id = 0; // english by default
 
@@ -5922,9 +5929,10 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 
         /* suppress_regex    =*/ nullptr,
 
-        /*.initial_prompt    =*/ nullptr,
-        /*.prompt_tokens     =*/ nullptr,
-        /*.prompt_n_tokens   =*/ 0,
+        /*.initial_prompt       =*/ nullptr,
+        /*.carry_initial_prompt =*/ false,
+        /*.prompt_tokens        =*/ nullptr,
+        /*.prompt_n_tokens      =*/ 0,
 
         /*.language          =*/ "en",
         /*.detect_language   =*/ false,
@@ -6880,17 +6888,22 @@ int whisper_full_with_state(
         decoder.rng = std::mt19937(j);
     }
 
-    // the accumulated text context so far
-    auto & prompt_past = state->prompt_past;
+    // the accumulated text context split into static (prompt_past0) and dynamic (prompt_past1)
+    auto & prompt_past0 = state->prompt_past0;
+    auto & prompt_past1 = state->prompt_past1;
     if (params.no_context) {
-        prompt_past.clear();
+        prompt_past0.clear();
+        prompt_past1.clear();
     }
 
+    // calculate the maximum context budget for prompt history
+    const int max_prompt_ctx = std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2);
+
     // prepare prompt
     {
         std::vector<whisper_token> prompt_tokens;
 
-        // initial prompt
+        // tokenize the initial prompt
         if (!params.prompt_tokens && params.initial_prompt) {
             prompt_tokens.resize(1024);
             int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
@@ -6902,14 +6915,25 @@ int whisper_full_with_state(
             params.prompt_tokens   = prompt_tokens.data();
             params.prompt_n_tokens = prompt_tokens.size();
         }
-
-        // prepend the prompt tokens to the prompt_past
         if (params.prompt_tokens && params.prompt_n_tokens > 0) {
-            // parse tokens from the pointer
-            for (int i = 0; i < params.prompt_n_tokens; i++) {
-                prompt_past.push_back(params.prompt_tokens[i]);
+            if (params.carry_initial_prompt) {
+                if (prompt_past0.empty()) {
+                    const int max_tokens = std::max(1, max_prompt_ctx - 1);
+
+                    if (params.prompt_n_tokens > max_tokens) {
+                        WHISPER_LOG_WARN("%s: initial prompt is too long (%d tokens), will use only the last %d tokens\n",
+                                        __func__, params.prompt_n_tokens, max_tokens);
+                    }
+
+                    const int n_tokens = std::min(params.prompt_n_tokens, max_tokens);
+                    prompt_past0.assign(params.prompt_tokens + (params.prompt_n_tokens - n_tokens), params.prompt_tokens + params.prompt_n_tokens);
+                }
+            } else {
+                for (int i = 0; i < params.prompt_n_tokens; ++i) {
+                    prompt_past1.push_back(params.prompt_tokens[i]);
+                }
+                std::rotate(prompt_past1.begin(), prompt_past1.end() - params.prompt_n_tokens, prompt_past1.end());
             }
-            std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
         }
     }
 
@@ -6995,7 +7019,8 @@ int whisper_full_with_state(
         // if there is a very short audio segment left to process, we remove any past prompt since it tends
         // to confuse the decoder and often make it repeat or hallucinate stuff
         if (seek > seek_start && seek + 500 >= seek_end) {
-            prompt_past.clear();
+            prompt_past0.clear();
+            prompt_past1.clear();
         }
 
         int best_decoder_id = 0;
@@ -7056,12 +7081,25 @@ int whisper_full_with_state(
             {
                 prompt.clear();
 
-                // if we have already generated some text, use it as a prompt to condition the next generation
-                if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
-                    int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
+                if (params.n_max_text_ctx > 0 && t_cur < WHISPER_HISTORY_CONDITIONING_TEMP_CUTOFF) {
+                    const bool can_take0 = params.carry_initial_prompt && !prompt_past0.empty();
+                    const bool can_take1 = !prompt_past1.empty();
+
+                    if (max_prompt_ctx > 0 && (can_take0 || can_take1)) {
+                        // Always start with previous token marker to connect continuity
+                        prompt.push_back(whisper_token_prev(ctx));
 
-                    prompt = { whisper_token_prev(ctx) };
-                    prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
+                        // Take static tokens (initial prompt) first
+                        int n_take0 = 0;
+                        if (can_take0) {
+                            n_take0 = prompt_past0.size();
+                            prompt.insert(prompt.end(), prompt_past0.end() - n_take0, prompt_past0.end());
+                        }
+
+                        // Fill remaining budget with dynamic tokens (rolling context)
+                        const int n_take1 = std::min<int>(max_prompt_ctx - n_take0 - 1, prompt_past1.size());
+                        prompt.insert(prompt.end(), prompt_past1.end() - n_take1, prompt_past1.end());
+                    }
                 }
 
                 // init new transcription with sot, language (opt) and task tokens
@@ -7543,14 +7581,17 @@ int whisper_full_with_state(
 
             //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
 
-            // update prompt_past
-            prompt_past.clear();
-            if (prompt.front() == whisper_token_prev(ctx)) {
-                prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
+            // update prompt_past1
+            prompt_past1.clear();
+            if (!params.carry_initial_prompt && !prompt.empty() && prompt.front() == whisper_token_prev(ctx)) {
+                prompt_past1.insert(prompt_past1.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
             }
 
-            for (int i = 0; i < result_len && !is_no_speech; ++i) {
-                prompt_past.push_back(tokens_cur[i].id);
+            // Add newly decoded tokens to the rolling context
+            if (!is_no_speech) {
+                for (int i = 0; i < result_len; ++i) {
+                    prompt_past1.push_back(tokens_cur[i].id);
+                }
             }
 
             if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {