]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
main : add options for temperature control (#2088)
authorDaniel Ziegenberg <redacted>
Mon, 13 May 2024 11:59:44 +0000 (13:59 +0200)
committerGitHub <redacted>
Mon, 13 May 2024 11:59:44 +0000 (14:59 +0300)
Add two options:

```
-tp,       --temperature N     [0.00   ] The sampling temperature, between 0 and 1
-tpi,      --temperature-inc N [0.20   ] The increment of temperature, between 0 and 1
```

The sampling temperature, between 0 and 1. Higher values like 0.8 will
make the output more random, while lower values like 0.2 will make it
more focused and deterministic. If set to 0, the model will use log
probability to automatically increase the temperature until certain
thresholds are hit.

Signed-off-by: Daniel Ziegenberg <redacted>
examples/main/main.cpp

index 6a3db73d87aca7939ea6a583c44f422e4a64e78a..bb1931869d338fe541bfc9ddc4963938400b2db3 100644 (file)
@@ -44,6 +44,8 @@ struct whisper_params {
     float entropy_thold   =  2.40f;
     float logprob_thold   = -1.00f;
     float grammar_penalty = 100.0f;
+    float temperature     = 0.0f;
+    float temperature_inc = 0.2f;
 
     bool speed_up        = false;
     bool debug_mode      = false;
@@ -133,6 +135,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-wt"   || arg == "--word-thold")      { params.word_thold      = std::stof(argv[++i]); }
         else if (arg == "-et"   || arg == "--entropy-thold")   { params.entropy_thold   = std::stof(argv[++i]); }
         else if (arg == "-lpt"  || arg == "--logprob-thold")   { params.logprob_thold   = std::stof(argv[++i]); }
+        else if (arg == "-tp"   || arg == "--temperature")     { params.temperature     = std::stof(argv[++i]); }
+        else if (arg == "-tpi"  || arg == "--temperature-inc") { params.temperature_inc = std::stof(argv[++i]); }
         // else if (arg == "-su"   || arg == "--speed-up")        { params.speed_up        = true; }
         else if (arg == "-debug"|| arg == "--debug-mode")      { params.debug_mode      = true; }
         else if (arg == "-tr"   || arg == "--translate")       { params.translate       = true; }
@@ -198,6 +202,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -wt N,     --word-thold N      [%-7.2f] word timestamp probability threshold\n",         params.word_thold);
     fprintf(stderr, "  -et N,     --entropy-thold N   [%-7.2f] entropy threshold for decoder fail\n",           params.entropy_thold);
     fprintf(stderr, "  -lpt N,    --logprob-thold N   [%-7.2f] log probability threshold for decoder fail\n",   params.logprob_thold);
+    fprintf(stderr, "  -tp,       --temperature N     [%-7.2f] The sampling temperature, between 0 and 1\n",    params.temperature);
+    fprintf(stderr, "  -tpi,      --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc);
     // fprintf(stderr, "  -su,       --speed-up          [%-7s] speed up audio by x2 (reduced accuracy)\n",        params.speed_up ? "true" : "false");
     fprintf(stderr, "  -debug,    --debug-mode        [%-7s] enable debug mode (eg. dump log_mel)\n",           params.debug_mode ? "true" : "false");
     fprintf(stderr, "  -tr,       --translate         [%-7s] translate from source language to english\n",      params.translate ? "true" : "false");
@@ -1107,7 +1113,9 @@ int main(int argc, char ** argv) {
             wparams.greedy.best_of        = params.best_of;
             wparams.beam_search.beam_size = params.beam_size;
 
-            wparams.temperature_inc  = params.no_fallback ? 0.0f : wparams.temperature_inc;
+            wparams.temperature_inc  = params.no_fallback ? 0.0f : params.temperature_inc;
+            wparams.temperature      = params.temperature;
+
             wparams.entropy_thold    = params.entropy_thold;
             wparams.logprob_thold    = params.logprob_thold;