]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
server : add dtw (#2044)
authorEmmanuel Schmidbauer <redacted>
Mon, 15 Apr 2024 19:16:58 +0000 (15:16 -0400)
committerGitHub <redacted>
Mon, 15 Apr 2024 19:16:58 +0000 (22:16 +0300)
* server.cpp: add dtw

* Update examples/server/server.cpp

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/server/server.cpp

index a11d0e78d152af39d8ff11bd498cd059b78448d3..a98e156c26b411a25b5a01e39754f79300bc4af9 100644 (file)
@@ -87,6 +87,8 @@ struct whisper_params {
     std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
 
     std::string openvino_encode_device = "CPU";
+
+    std::string dtw = "";
 };
 
 void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params, const server_params& sparams) {
@@ -126,6 +128,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -m FNAME,  --model FNAME       [%-7s] model path\n",                                     params.model.c_str());
     fprintf(stderr, "  -oved D,   --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n",  params.openvino_encode_device.c_str());
     // server params
+    fprintf(stderr, "  -dtw MODEL --dtw MODEL         [%-7s] compute token-level timestamps\n", params.dtw.c_str());
     fprintf(stderr, "  --host HOST,                   [%-7s] Hostname/ip-adress for the server\n", sparams.hostname.c_str());
     fprintf(stderr, "  --port PORT,                   [%-7d] Port number for the server\n", sparams.port);
     fprintf(stderr, "  --public PATH,                 [%-7s] Path to the public folder\n", sparams.public_path.c_str());
@@ -173,6 +176,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
         else if (                  arg == "--prompt")          { params.prompt          = argv[++i]; }
         else if (arg == "-m"    || arg == "--model")           { params.model           = argv[++i]; }
         else if (arg == "-oved" || arg == "--ov-e-device")     { params.openvino_encode_device = argv[++i]; }
+        else if (arg == "-dtw"  || arg == "--dtw")             { params.dtw             = argv[++i]; }
         else if (arg == "-ng"   || arg == "--no-gpu")          { params.use_gpu         = false; }
         // server params
         else if (                  arg == "--port")            { sparams.port        = std::stoi(argv[++i]); }
@@ -499,6 +503,49 @@ int main(int argc, char ** argv) {
     // whisper init
     struct whisper_context_params cparams = whisper_context_default_params();
     cparams.use_gpu = params.use_gpu;
+    if (!params.dtw.empty()) {
+        cparams.dtw_token_timestamps = true;
+        cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;
+
+        if (params.dtw == "tiny") {
+            cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY;
+        }
+        if (params.dtw == "tiny.en") {
+            cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN;
+        }
+        if (params.dtw == "base") {
+            cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE;
+        }
+        if (params.dtw == "base.en") {
+            cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE_EN;
+        }
+        if (params.dtw == "small") {
+            cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL;
+        }
+        if (params.dtw == "small.en") {
+            cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL_EN;
+        }
+        if (params.dtw == "medium") {
+            cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM;
+        }
+        if (params.dtw == "medium.en") {
+            cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM_EN;
+        }
+        if (params.dtw == "large.v1") {
+            cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V1;
+        }
+        if (params.dtw == "large.v2") {
+            cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V2;
+        }
+        if (params.dtw == "large.v3") {
+            cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3;
+        }
+
+        if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
+            fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str());
+            return 3;
+        }
+    }
 
     struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
 
@@ -865,6 +912,7 @@ int main(int argc, char ** argv) {
                     if (!params.no_timestamps) {
                         word["start"] = token.t0 * 0.01;
                         word["end"] = token.t1 * 0.01;
+                        word["t_dtw"] = token.t_dtw;
                     }
                     word["probability"] = token.p;
                     total_logprob += token.plog;