]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
stream : add beam size parameter(#2836)
authormasahji <redacted>
Tue, 25 Feb 2025 09:39:33 +0000 (01:39 -0800)
committerGitHub <redacted>
Tue, 25 Feb 2025 09:39:33 +0000 (11:39 +0200)
* feat: Add beam size parameter to stream.cpp for beam search configuration

* feat: Add beam size parameter to whisper full params in stream example

* fix: Remove duplicate beam search size assignment in server.cpp

examples/stream/stream.cpp

index 190f68a2c3be8fddaded2a0c900eb47d8c032221..833c240edd74b54754d57c0dc968191bbdc7c5d5 100644 (file)
@@ -23,6 +23,7 @@ struct whisper_params {
     int32_t capture_id = -1;
     int32_t max_tokens = 32;
     int32_t audio_ctx  = 0;
+    int32_t beam_size  = -1;
 
     float vad_thold    = 0.6f;
     float freq_thold   = 100.0f;
@@ -59,6 +60,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
         else if (arg == "-c"    || arg == "--capture")       { params.capture_id    = std::stoi(argv[++i]); }
         else if (arg == "-mt"   || arg == "--max-tokens")    { params.max_tokens    = std::stoi(argv[++i]); }
         else if (arg == "-ac"   || arg == "--audio-ctx")     { params.audio_ctx     = std::stoi(argv[++i]); }
+        else if (arg == "-bs"   || arg == "--beam-size")     { params.beam_size     = std::stoi(argv[++i]); }
         else if (arg == "-vth"  || arg == "--vad-thold")     { params.vad_thold     = std::stof(argv[++i]); }
         else if (arg == "-fth"  || arg == "--freq-thold")    { params.freq_thold    = std::stof(argv[++i]); }
         else if (arg == "-tr"   || arg == "--translate")     { params.translate     = true; }
@@ -96,6 +98,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -c ID,    --capture ID    [%-7d] capture device ID\n",                              params.capture_id);
     fprintf(stderr, "  -mt N,    --max-tokens N  [%-7d] maximum number of tokens per audio chunk\n",       params.max_tokens);
     fprintf(stderr, "  -ac N,    --audio-ctx N   [%-7d] audio context size (0 - all)\n",                   params.audio_ctx);
+    fprintf(stderr, "  -bs N,    --beam-size N   [%-7d] beam size for beam search\n",                      params.beam_size);
     fprintf(stderr, "  -vth N,   --vad-thold N   [%-7.2f] voice activity detection threshold\n",           params.vad_thold);
     fprintf(stderr, "  -fth N,   --freq-thold N  [%-7.2f] high-pass frequency cutoff\n",                   params.freq_thold);
     fprintf(stderr, "  -tr,      --translate     [%-7s] translate from source language to english\n",      params.translate ? "true" : "false");
@@ -298,7 +301,7 @@ int main(int argc, char ** argv) {
 
         // run the inference
         {
-            whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+            whisper_full_params wparams = whisper_full_default_params(params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY);
 
             wparams.print_progress   = false;
             wparams.print_special    = params.print_special;
@@ -309,6 +312,7 @@ int main(int argc, char ** argv) {
             wparams.max_tokens       = params.max_tokens;
             wparams.language         = params.language.c_str();
             wparams.n_threads        = params.n_threads;
+            wparams.beam_search.beam_size = params.beam_size;
 
             wparams.audio_ctx        = params.audio_ctx;