git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : rename suppress_non_speech_tokens to suppress_nst (#2653)
author Georgi Gerganov <redacted>
Sat, 21 Dec 2024 10:54:35 +0000 (12:54 +0200)
committer GitHub <redacted>
Sat, 21 Dec 2024 10:54:35 +0000 (12:54 +0200)
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java
bindings/ruby/ext/ruby_whisper.cpp
bindings/ruby/tests/test_params.rb
examples/lsp/lsp.cpp
examples/server/server.cpp
include/whisper.h
src/whisper.cpp

index 90d8c15767c9763bc243f87c12a7ae21b14a8043..18c209fc83cf8d37c9831ccf2a314908ad99040e 100644 (file)
@@ -181,11 +181,11 @@ public class WhisperFullParams extends Structure {
     }\r
 \r
     /** Flag to suppress non-speech tokens. */\r
-    public CBool suppress_non_speech_tokens;\r
+    public CBool suppress_nst;\r
 \r
     /** Flag to suppress non-speech tokens. */\r
     public void suppressNonSpeechTokens(boolean enable) {\r
-        suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE;\r
+        suppress_nst = enable ? CBool.TRUE : CBool.FALSE;\r
     }\r
 \r
     /** Initial decoding temperature. */\r
@@ -315,7 +315,7 @@ public class WhisperFullParams extends Structure {
                 "print_special", "print_progress", "print_realtime", "print_timestamps",  "token_timestamps",\r
                 "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx",\r
                 "tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",\r
-                "suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",\r
+                "suppress_blank", "suppress_nst", "temperature", "max_initial_ts", "length_penalty",\r
                 "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",\r
                 "new_segment_callback", "new_segment_callback_user_data",\r
                 "progress_callback", "progress_callback_user_data",\r
index 26e9def4b03da9d2c76a9f630253f77bbe0cf977..aa526577fbed32add0398ab8dc982e161984b263 100644 (file)
@@ -979,19 +979,19 @@ static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) {
 }
 /*
  * call-seq:
- *   suppress_non_speech_tokens = force_suppress -> force_suppress
+ *   suppress_nst = force_suppress -> force_suppress
  */
-static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) {
-  BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value)
+static VALUE ruby_whisper_params_set_suppress_nst(VALUE self, VALUE value) {
+  BOOL_PARAMS_SETTER(self, suppress_nst, value)
 }
 /*
  * If true, suppresses non-speech-tokens.
  *
  * call-seq:
- *   suppress_non_speech_tokens -> bool
+ *   suppress_nst -> bool
  */
-static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) {
-  BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens)
+static VALUE ruby_whisper_params_get_suppress_nst(VALUE self) {
+  BOOL_PARAMS_GETTER(self, suppress_nst)
 }
 /*
  * If true, enables token-level timestamps.
@@ -1832,8 +1832,8 @@ void Init_whisper() {
   rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1);
   rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0);
   rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1);
-  rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0);
-  rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1);
+  rb_define_method(cParams, "suppress_nst", ruby_whisper_params_get_suppress_nst, 0);
+  rb_define_method(cParams, "suppress_nst=", ruby_whisper_params_set_suppress_nst, 1);
   rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0);
   rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
   rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
index d2667ef0d3e16d10e50f74286dc906e715695060..7981bfaab5076db929337697463ace77475a0471 100644 (file)
@@ -89,11 +89,11 @@ class TestParams < TestBase
     assert !@params.suppress_blank
   end
 
-  def test_suppress_non_speech_tokens
-    @params.suppress_non_speech_tokens = true
-    assert @params.suppress_non_speech_tokens
-    @params.suppress_non_speech_tokens = false
-    assert !@params.suppress_non_speech_tokens
+  def test_suppress_nst
+    @params.suppress_nst = true
+    assert @params.suppress_nst
+    @params.suppress_nst = false
+    assert !@params.suppress_nst
   end
 
   def test_token_timestamps
index 1afc159f60395cf6ed44d9dcfbb7c275e3745a30..803cd6d55da06cdaae91df739874e013c51a1f69 100644 (file)
@@ -181,7 +181,7 @@ static json unguided_transcription(struct whisper_context * ctx, audio_async &au
     wparams.n_threads        = params.n_threads;
 
     wparams.audio_ctx        = params.audio_ctx;
-    wparams.suppress_non_speech_tokens = true;
+    wparams.suppress_nst     = true;
     // run the transformer and a single decoding pass
     if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
         fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
@@ -225,7 +225,7 @@ static json guided_transcription(struct whisper_context * ctx, audio_async &audi
     wparams.prompt_tokens    = cs.prompt_tokens.data();
     wparams.prompt_n_tokens  = cs.prompt_tokens.size();
     // TODO: properly expose as option
-    wparams.suppress_non_speech_tokens = true;
+    wparams.suppress_nst     = true;
 
     // run the transformer and a single decoding pass
     if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
index f0484b3492063c6f4c35fdfc6a877a237593a2ac..0608bb6bd1a50fb275941e9b990749574085edc4 100644 (file)
@@ -76,7 +76,7 @@ struct whisper_params {
     bool no_timestamps   = false;
     bool use_gpu         = true;
     bool flash_attn      = false;
-    bool suppress_non_speech_tokens = false;
+    bool suppress_nst    = false;
 
     std::string language        = "en";
     std::string prompt          = "";
@@ -136,7 +136,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  --request-path PATH,           [%-7s] Request path for all requests\n", sparams.request_path.c_str());
     fprintf(stderr, "  --inference-path PATH,         [%-7s] Inference path for all requests\n", sparams.inference_path.c_str());
     fprintf(stderr, "  --convert,                     [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
-    fprintf(stderr, "  -sns,      --suppress-non-speech [%-7s] suppress non-speech tokens\n", params.suppress_non_speech_tokens ? "true" : "false");
+    fprintf(stderr, "  -sns,      --suppress-nst      [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
     fprintf(stderr, "\n");
 }
 
@@ -181,7 +181,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
         else if (arg == "-dtw"  || arg == "--dtw")             { params.dtw             = argv[++i]; }
         else if (arg == "-ng"   || arg == "--no-gpu")          { params.use_gpu         = false; }
         else if (arg == "-fa"   || arg == "--flash-attn")      { params.flash_attn      = true; }
-        else if (arg == "-sns"  || arg == "--suppress-non-speech") { params.suppress_non_speech_tokens = true; }
+        else if (arg == "-sns"  || arg == "--suppress-nst")    { params.suppress_nst    = true; }
         // server params
         else if (                  arg == "--port")            { sparams.port        = std::stoi(argv[++i]); }
         else if (                  arg == "--host")            { sparams.hostname    = argv[++i]; }
@@ -477,7 +477,11 @@ void get_req_parameters(const Request & req, whisper_params & params)
     }
     if (req.has_file("suppress_non_speech"))
     {
-        params.suppress_non_speech_tokens = parse_str_to_bool(req.get_file_value("suppress_non_speech").content);
+        params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_non_speech").content);
+    }
+    if (req.has_file("suppress_nst"))
+    {
+        params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_nst").content);
     }
 }
 
@@ -793,7 +797,7 @@ int main(int argc, char ** argv) {
             wparams.no_timestamps    = params.no_timestamps;
             wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format;
 
-            wparams.suppress_non_speech_tokens = params.suppress_non_speech_tokens;
+            wparams.suppress_nst     = params.suppress_nst;
 
             whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
 
index 71949bdd39727fb1fe92391bfcfbd4309e1e9660..6e0db505ee18f11b05a4b095f0c41b28f963e6e4 100644 (file)
@@ -522,8 +522,8 @@ extern "C" {
         bool detect_language;
 
         // common decoding parameters:
-        bool suppress_blank;    // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
-        bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
+        bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
+        bool suppress_nst;   // non-speech tokens, ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
 
         float temperature;      // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
         float max_initial_ts;   // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
index bcc530ae8918f6806316b4b0e3da055ac8ae1545..ff200ea0f8d6b15f4c42a2c99eedc824f38390a6 100644 (file)
@@ -4676,7 +4676,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.detect_language   =*/ false,
 
         /*.suppress_blank    =*/ true,
-        /*.suppress_non_speech_tokens =*/ false,
+        /*.suppress_nst      =*/ false,
 
         /*.temperature       =*/  0.0f,
         /*.max_initial_ts    =*/  1.0f,
@@ -4960,7 +4960,7 @@ static void whisper_process_logits(
 
         // suppress non-speech tokens
         // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
-        if (params.suppress_non_speech_tokens) {
+        if (params.suppress_nst) {
             for (const std::string & token : non_speech_tokens) {
                 const std::string suppress_tokens[] = {token, " " + token};
                 for (const std::string & suppress_token : suppress_tokens) {