]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
server : add option to suppress non-speech tokens (#2649)
authorSacha Arbonel <redacted>
Sat, 21 Dec 2024 10:05:05 +0000 (11:05 +0100)
committerGitHub <redacted>
Sat, 21 Dec 2024 10:05:05 +0000 (12:05 +0200)
* The parameter will suppress non-speech tokens like [LAUGH], [SIGH], etc. from the output when enabled.

* add to whisper_params_parse

* add missing param

examples/server/server.cpp

index af46751393599c07c467a0517521bf60067afcf3..f0484b3492063c6f4c35fdfc6a877a237593a2ac 100644 (file)
@@ -76,6 +76,7 @@ struct whisper_params {
     bool no_timestamps   = false;
     bool use_gpu         = true;
     bool flash_attn      = false;
+    bool suppress_non_speech_tokens = false;
 
     std::string language        = "en";
     std::string prompt          = "";
@@ -135,6 +136,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  --request-path PATH,           [%-7s] Request path for all requests\n", sparams.request_path.c_str());
     fprintf(stderr, "  --inference-path PATH,         [%-7s] Inference path for all requests\n", sparams.inference_path.c_str());
     fprintf(stderr, "  --convert,                     [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
+    fprintf(stderr, "  -sns,      --suppress-non-speech [%-7s] suppress non-speech tokens\n", params.suppress_non_speech_tokens ? "true" : "false");
     fprintf(stderr, "\n");
 }
 
@@ -179,6 +181,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
         else if (arg == "-dtw"  || arg == "--dtw")             { params.dtw             = argv[++i]; }
         else if (arg == "-ng"   || arg == "--no-gpu")          { params.use_gpu         = false; }
         else if (arg == "-fa"   || arg == "--flash-attn")      { params.flash_attn      = true; }
+        else if (arg == "-sns"  || arg == "--suppress-non-speech") { params.suppress_non_speech_tokens = true; }
         // server params
         else if (                  arg == "--port")            { sparams.port        = std::stoi(argv[++i]); }
         else if (                  arg == "--host")            { sparams.hostname    = argv[++i]; }
@@ -472,6 +475,10 @@ void get_req_parameters(const Request & req, whisper_params & params)
     {
         params.temperature_inc = std::stof(req.get_file_value("temperature_inc").content);
     }
+    if (req.has_file("suppress_non_speech"))
+    {
+        params.suppress_non_speech_tokens = parse_str_to_bool(req.get_file_value("suppress_non_speech").content);
+    }
 }
 
 }  // namespace
@@ -786,6 +793,8 @@ int main(int argc, char ** argv) {
             wparams.no_timestamps    = params.no_timestamps;
             wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format;
 
+            wparams.suppress_non_speech_tokens = params.suppress_non_speech_tokens;
+
             whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
 
             // this callback is called on each new segment