]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add "split_on_word" flag when using using "max_len" option (#455)
authorMatija Pevec <redacted>
Sun, 5 Feb 2023 12:44:23 +0000 (13:44 +0100)
committerGitHub <redacted>
Sun, 5 Feb 2023 12:44:23 +0000 (14:44 +0200)
* Update whisper.cpp

* fix: trim function

* feat: added flag to split on word

* fix: arguments for main

examples/main/main.cpp
whisper.cpp
whisper.h

index 6a697ac8ef7bdc7b50885f98d104a212893b9d11..fbc9faf8f4afa54e54fcb5e0e19c0ac3fa6b0791 100644 (file)
@@ -69,6 +69,7 @@ struct whisper_params {
     bool speed_up       = false;
     bool translate      = false;
     bool diarize        = false;
+    bool split_on_word  = false;
     bool no_fallback    = false;
     bool output_txt     = false;
     bool output_vtt     = false;
@@ -118,6 +119,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-su"   || arg == "--speed-up")       { params.speed_up       = true; }
         else if (arg == "-tr"   || arg == "--translate")      { params.translate      = true; }
         else if (arg == "-di"   || arg == "--diarize")        { params.diarize        = true; }
+        else if (arg == "-sow"  || arg == "--split-on-word")  { params.split_on_word  = true; }
         else if (arg == "-nf"   || arg == "--no-fallback")    { params.no_fallback    = true; }
         else if (arg == "-otxt" || arg == "--output-txt")     { params.output_txt     = true; }
         else if (arg == "-ovtt" || arg == "--output-vtt")     { params.output_vtt     = true; }
@@ -156,6 +158,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -d  N,     --duration N        [%-7d] duration of audio to process in milliseconds\n",   params.duration_ms);
     fprintf(stderr, "  -mc N,     --max-context N     [%-7d] maximum number of text context tokens to store\n", params.max_context);
     fprintf(stderr, "  -ml N,     --max-len N         [%-7d] maximum segment length in characters\n",           params.max_len);
+    fprintf(stderr, "  -sow,      --split-on-word     [%-7s] split on word rather than on token\n",             params.split_on_word ? "true" : "false");
     fprintf(stderr, "  -bo N,     --best-of N         [%-7d] number of best candidates to keep\n",              params.best_of);
     fprintf(stderr, "  -bs N,     --beam-size N       [%-7d] beam size for beam search\n",                      params.beam_size);
     fprintf(stderr, "  -wt N,     --word-thold N      [%-7.2f] word timestamp probability threshold\n",         params.word_thold);
@@ -651,6 +654,7 @@ int main(int argc, char ** argv) {
             wparams.token_timestamps = params.output_wts || params.max_len > 0;
             wparams.thold_pt         = params.word_thold;
             wparams.max_len          = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
+            wparams.split_on_word    = params.split_on_word;
 
             wparams.speed_up         = params.speed_up;
 
index f123ed84ded45f2993b85ba5673c30761c39b375..1a4a207157b5a9c5b119437582f2cff7956fe877 100644 (file)
@@ -2922,6 +2922,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.thold_pt         =*/ 0.01f,
         /*.thold_ptsum      =*/ 0.01f,
         /*.max_len          =*/ 0,
+        /*.split_on_word    =*/ false,
         /*.max_tokens       =*/ 0,
 
         /*.speed_up         =*/ false,
@@ -2988,9 +2989,36 @@ static void whisper_exp_compute_token_level_timestamps(
                          float   thold_pt,
                          float   thold_ptsum);
 
+// trim from start (in place)
+static inline void ltrim(std::string &s) {
+    s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
+        return !std::isspace(ch);
+    }));
+}
+
+// trim from end (in place)
+static inline void rtrim(std::string &s) {
+    s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) {
+        return !std::isspace(ch);
+    }).base(), s.end());
+}
+
+// trim from both ends (in place)
+static inline void trim(std::string &s) {
+    rtrim(s);
+    ltrim(s);
+}
+
+static inline bool should_split_on_word(const char * txt, bool split_on_word) {
+    if (!split_on_word) return true;
+
+    std::string s = txt;
+    return s.substr(0, 1) == " ";
+}
+
 // wrap the last segment to max_len characters
 // returns the number of new segments
-static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
+static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) {
     auto segment = ctx.result_all.back();
 
     int res = 1;
@@ -3005,11 +3033,11 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
         }
 
         const auto txt = whisper_token_to_str(&ctx, token.id);
-
         const int cur = strlen(txt);
 
-        if (acc + cur > max_len && i > 0) {
+        if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) {
             // split here
+            trim(text);
             ctx.result_all.back().text = std::move(text);
             ctx.result_all.back().t1 = token.t0;
             ctx.result_all.back().tokens.resize(i);
@@ -3037,6 +3065,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
         }
     }
 
+    trim(text);
     ctx.result_all.back().text = std::move(text);
 
     return res;
@@ -4069,7 +4098,7 @@ int whisper_full(
                                         *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
 
                                 if (params.max_len > 0) {
-                                    n_new = whisper_wrap_segment(*ctx, params.max_len);
+                                    n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
                                 }
                             }
                             if (params.new_segment_callback) {
@@ -4113,7 +4142,7 @@ int whisper_full(
                                 *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
 
                         if (params.max_len > 0) {
-                            n_new = whisper_wrap_segment(*ctx, params.max_len);
+                            n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
                         }
                     }
                     if (params.new_segment_callback) {
index 51a18889d8d09fdcce87295ba82646f32f87f3ca..3a426680d2c7b494c6d9cadfe897d30d9dbe258e 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -257,6 +257,7 @@ extern "C" {
         float thold_pt;         // timestamp token probability threshold (~0.01)
         float thold_ptsum;      // timestamp token sum probability threshold (~0.01)
         int   max_len;          // max segment length in characters
+        bool  split_on_word;    // split on word rather than on token (when used with max_len)
         int   max_tokens;       // max tokens per segment (0 = no limit)
 
         // [EXPERIMENTAL] speed-up techniques