]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : token-level timestamp refactoring (#49, #120)
authorGeorgi Gerganov <redacted>
Wed, 2 Nov 2022 19:18:20 +0000 (21:18 +0200)
committerGeorgi Gerganov <redacted>
Wed, 2 Nov 2022 19:45:54 +0000 (21:45 +0200)
This turned out pretty good overall. The algorithm has been moved from
main.cpp to whisper.cpp and can be reused for all subtitles types. This
means that now you can specify the maximum length of the generated
lines. Simply provide the "-ml" argument specifying the max length in
number of characters

README.md
examples/main/README.md
examples/main/main.cpp
whisper.cpp
whisper.h

index fdbc65e90ec1509842a72449426eb3a76bc7d814..a888880550db944f4e1a45d1c4b650c4b0283eb7 100644 (file)
--- a/README.md
+++ b/README.md
@@ -101,13 +101,14 @@ options:
   -ot N,    --offset-t N     time offset in milliseconds (default: 0)
   -on N,    --offset-n N     segment index offset (default: 0)
   -mc N,    --max-context N  maximum number of text context tokens to store (default: max)
+  -ml N,    --max-len N      maximum segment length in characters (default: 0)
   -wt N,    --word-thold N   word timestamp probability threshold (default: 0.010000)
   -v,       --verbose        verbose output
             --translate      translate from source language to english
   -otxt,    --output-txt     output result in a text file
   -ovtt,    --output-vtt     output result in a vtt file
   -osrt,    --output-srt     output result in a srt file
-  -owts,    --output-words   output word-level timestamps to a text file
+  -owts,    --output-words   output script for generating karaoke video
   -ps,      --print_special  print special tokens
   -pc,      --print_colors   print colors
   -nt,      --no_timestamps  do not print timestamps
index 27f47ff580cf290cff6762b88289a81961cd6145..f2bf2a8dca9c94ab71f4765045951abd4b0739eb 100644 (file)
@@ -8,7 +8,6 @@ It can be used as a reference for using the `whisper.cpp` library in other proje
 \r
 usage: ./bin/main [options] file0.wav file1.wav ...\r
 \r
-options:\r
   -h,       --help           show this help message and exit\r
   -s SEED,  --seed SEED      RNG seed (default: -1)\r
   -t N,     --threads N      number of threads to use during computation (default: 4)\r
@@ -16,18 +15,20 @@ options:
   -ot N,    --offset-t N     time offset in milliseconds (default: 0)\r
   -on N,    --offset-n N     segment index offset (default: 0)\r
   -mc N,    --max-context N  maximum number of text context tokens to store (default: max)\r
+  -ml N,    --max-len N      maximum segment length in characters (default: 0)\r
   -wt N,    --word-thold N   word timestamp probability threshold (default: 0.010000)\r
   -v,       --verbose        verbose output\r
             --translate      translate from source language to english\r
   -otxt,    --output-txt     output result in a text file\r
   -ovtt,    --output-vtt     output result in a vtt file\r
   -osrt,    --output-srt     output result in a srt file\r
-  -owts,    --output-words   output word-level timestamps to a text file\r
+  -owts,    --output-words   output script for generating karaoke video\r
   -ps,      --print_special  print special tokens\r
   -pc,      --print_colors   print colors\r
   -nt,      --no_timestamps  do not print timestamps\r
   -l LANG,  --language LANG  spoken language (default: en)\r
   -m FNAME, --model FNAME    model path (default: models/ggml-base.en.bin)\r
   -f FNAME, --file FNAME     input WAV file path\r
+  -h,       --help           show this help message and exit\r
 \r
 ```\r
index 83438921654f3e66e8ac69359008aed20e9dd822..b58945998e938fbaa1af3e194a204675c444456f 100644 (file)
@@ -36,6 +36,7 @@ std::string to_timestamp(int64_t t, bool comma = false) {
     return std::string(buf);
 }
 
+// helper function to replace substrings
 void replace_all(std::string & s, const std::string & search, const std::string & replace) {
     for (size_t pos = 0; ; pos += replace.length()) {
         pos = s.find(search, pos);
@@ -45,31 +46,6 @@ void replace_all(std::string & s, const std::string & search, const std::string
     }
 }
 
-// a cost-function that is high for text that takes longer to pronounce
-float voice_length(const std::string & text) {
-    float res = 0.0f;
-
-    for (size_t i = 0; i < text.size(); ++i) {
-        if (text[i] == ' ') {
-            res += 0.01f;
-        } else if (text[i] == ',') {
-            res += 2.00f;
-        } else if (text[i] == '.') {
-            res += 3.00f;
-        } else if (text[i] == '!') {
-            res += 3.00f;
-        } else if (text[i] == '?') {
-            res += 3.00f;
-        } else if (text[i] >= '0' && text[i] <= '9') {
-            res += 3.00f;
-        } else {
-            res += 1.00f;
-        }
-    }
-
-    return res;
-}
-
 // command-line parameters
 struct whisper_params {
     int32_t seed         = -1; // RNG seed, not used currently
@@ -78,6 +54,7 @@ struct whisper_params {
     int32_t offset_t_ms  = 0;
     int32_t offset_n     = 0;
     int32_t max_context  = -1;
+    int32_t max_len      = 0;
 
     float word_thold = 0.01f;
 
@@ -120,6 +97,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             params.offset_n = std::stoi(argv[++i]);
         } else if (arg == "-mc" || arg == "--max-context") {
             params.max_context = std::stoi(argv[++i]);
+        } else if (arg == "-ml" || arg == "--max-len") {
+            params.max_len = std::stoi(argv[++i]);
         } else if (arg == "-wt" || arg == "--word-thold") {
             params.word_thold = std::stof(argv[++i]);
         } else if (arg == "-v" || arg == "--verbose") {
@@ -176,13 +155,14 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "  -ot N,    --offset-t N     time offset in milliseconds (default: %d)\n", params.offset_t_ms);
     fprintf(stderr, "  -on N,    --offset-n N     segment index offset (default: %d)\n", params.offset_n);
     fprintf(stderr, "  -mc N,    --max-context N  maximum number of text context tokens to store (default: max)\n");
+    fprintf(stderr, "  -ml N,    --max-len N      maximum segment length in characters (default: %d)\n", params.max_len);
     fprintf(stderr, "  -wt N,    --word-thold N   word timestamp probability threshold (default: %f)\n", params.word_thold);
     fprintf(stderr, "  -v,       --verbose        verbose output\n");
     fprintf(stderr, "            --translate      translate from source language to english\n");
     fprintf(stderr, "  -otxt,    --output-txt     output result in a text file\n");
     fprintf(stderr, "  -ovtt,    --output-vtt     output result in a vtt file\n");
     fprintf(stderr, "  -osrt,    --output-srt     output result in a srt file\n");
-    fprintf(stderr, "  -owts,    --output-words   output word-level timestamps to a text file\n");
+    fprintf(stderr, "  -owts,    --output-words   output script for generating karaoke video\n");
     fprintf(stderr, "  -ps,      --print_special  print special tokens\n");
     fprintf(stderr, "  -pc,      --print_colors   print colors\n");
     fprintf(stderr, "  -nt,      --no_timestamps  do not print timestamps\n");
@@ -192,65 +172,67 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "\n");
 }
 
-void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) {
+void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
     const whisper_params & params = *(whisper_params *) user_data;
 
     const int n_segments = whisper_full_n_segments(ctx);
 
-    // print the last segment
-    const int i = n_segments - 1;
-    if (i == 0) {
+    // print the last n_new segments
+    const int s0 = n_segments - n_new;
+    if (s0 == 0) {
         printf("\n");
     }
 
-    if (params.no_timestamps) {
-        if (params.print_colors) {
-            for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
-                if (params.print_special_tokens == false) {
-                    const whisper_token id = whisper_full_get_token_id(ctx, i, j);
-                    if (id >= whisper_token_eot(ctx)) {
-                        continue;
+    for (int i = s0; i < n_segments; i++) {
+        if (params.no_timestamps) {
+            if (params.print_colors) {
+                for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+                    if (params.print_special_tokens == false) {
+                        const whisper_token id = whisper_full_get_token_id(ctx, i, j);
+                        if (id >= whisper_token_eot(ctx)) {
+                            continue;
+                        }
                     }
-                }
 
-                const char * text = whisper_full_get_token_text(ctx, i, j);
-                const float  p    = whisper_full_get_token_p   (ctx, i, j);
+                    const char * text = whisper_full_get_token_text(ctx, i, j);
+                    const float  p    = whisper_full_get_token_p   (ctx, i, j);
 
-                const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+                    const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
 
-                printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
+                    printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
+                }
+            } else {
+                const char * text = whisper_full_get_segment_text(ctx, i);
+                printf("%s", text);
             }
+            fflush(stdout);
         } else {
-            const char * text = whisper_full_get_segment_text(ctx, i);
-            printf("%s", text);
-        }
-        fflush(stdout);
-    } else {
-        const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
-        const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
-
-        if (params.print_colors) {
-            printf("[%s --> %s]  ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
-            for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
-                if (params.print_special_tokens == false) {
-                    const whisper_token id = whisper_full_get_token_id(ctx, i, j);
-                    if (id >= whisper_token_eot(ctx)) {
-                        continue;
+            const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+            const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+            if (params.print_colors) {
+                printf("[%s --> %s]  ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
+                for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+                    if (params.print_special_tokens == false) {
+                        const whisper_token id = whisper_full_get_token_id(ctx, i, j);
+                        if (id >= whisper_token_eot(ctx)) {
+                            continue;
+                        }
                     }
-                }
 
-                const char * text = whisper_full_get_token_text(ctx, i, j);
-                const float  p    = whisper_full_get_token_p   (ctx, i, j);
+                    const char * text = whisper_full_get_token_text(ctx, i, j);
+                    const float  p    = whisper_full_get_token_p   (ctx, i, j);
 
-                const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+                    const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
 
-                printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
-            }
-            printf("\n");
-        } else {
-            const char * text = whisper_full_get_segment_text(ctx, i);
+                    printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
+                }
+                printf("\n");
+            } else {
+                const char * text = whisper_full_get_segment_text(ctx, i);
 
-            printf("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+                printf("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+            }
         }
     }
 }
@@ -320,297 +302,41 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
     return true;
 }
 
-// word-level timestamps (experimental)
-// TODO: make ffmpeg output optional
-// TODO: extra pass to detect unused speech and assign to tokens
+// karaoke video generation
+// outputs a bash script that uses ffmpeg to generate a video with the subtitles
 // TODO: font parameter adjustments
-// TODO: move to whisper.h/whisper.cpp and add parameter to select max line-length of subtitles
-bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, const std::vector<float> & pcmf32) {
-    std::vector<float> pcm_avg(pcmf32.size(), 0);
-
-    // average the fabs of the signal
-    {
-        const int hw = 32;
-
-        for (int i = 0; i < pcmf32.size(); i++) {
-            float sum = 0;
-            for (int j = -hw; j <= hw; j++) {
-                if (i + j >= 0 && i + j < pcmf32.size()) {
-                    sum += fabs(pcmf32[i + j]);
-                }
-            }
-            pcm_avg[i] = sum/(2*hw + 1);
-        }
-    }
-
-    struct token_info {
-        int64_t t0 = -1;
-        int64_t t1 = -1;
-
-        int64_t tt0 = -1;
-        int64_t tt1 = -1;
-
-        whisper_token id;
-        whisper_token tid;
-
-        float p     = 0.0f;
-        float pt    = 0.0f;
-        float ptsum = 0.0f;
-
-        std::string text;
-        float vlen = 0.0f; // voice length of this token
-    };
-
-    int64_t t_beg  = 0;
-    int64_t t_last = 0;
-
-    whisper_token tid_last = 0;
-
+bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
     std::ofstream fout(fname);
 
     fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
 
+    // TODO: become parameter
+    static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
+
     fout << "!/bin/bash" << "\n";
     fout << "\n";
 
-    fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE << ":rate=25:color=black -vf \"";
-
-    bool is_first = true;
+    fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << t_sec << ":rate=25:color=black -vf \"";
 
     for (int i = 0; i < whisper_full_n_segments(ctx); i++) {
         const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
         const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
-        const char *text = whisper_full_get_segment_text(ctx, i);
-
-        const int s0 = std::max(0,                   (int) (t0*WHISPER_SAMPLE_RATE/100));
-        const int s1 = std::min((int) pcmf32.size(), (int) (t1*WHISPER_SAMPLE_RATE/100));
-
         const int n = whisper_full_n_tokens(ctx, i);
 
-        std::vector<token_info> tokens(n);
-
-        if (n <= 1) {
-            continue;
-        }
-
+        std::vector<whisper_token_data> tokens(n);
         for (int j = 0; j < n; ++j) {
-            struct whisper_token_data token = whisper_full_get_token_data(ctx, i, j);
-
-            if (j == 0) {
-                if (token.id == whisper_token_beg(ctx)) {
-                    tokens[j    ].t0 = t0;
-                    tokens[j    ].t1 = t0;
-                    tokens[j + 1].t0 = t0;
-
-                    t_beg  = t0;
-                    t_last = t0;
-                    tid_last = whisper_token_beg(ctx);
-                } else {
-                    tokens[j    ].t0 = t_last;
-                }
-            }
-
-            const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
-
-            tokens[j].id    = token.id;
-            tokens[j].tid   = token.tid;
-            tokens[j].p     = token.p;
-            tokens[j].pt    = token.pt;
-            tokens[j].ptsum = token.ptsum;
-
-            tokens[j].text = whisper_token_to_str(ctx, token.id);
-            tokens[j].vlen = voice_length(tokens[j].text);
-
-            if (token.pt > params.word_thold && token.ptsum > 0.01 && token.tid > tid_last && tt <= t1) {
-                if (j > 0) {
-                    tokens[j - 1].t1 = tt;
-                }
-                tokens[j].t0 = tt;
-                tid_last = token.tid;
-            }
+            tokens[j] = whisper_full_get_token_data(ctx, i, j);
         }
 
-        tokens[n - 2].t1 = t1;
-        tokens[n - 1].t0 = t1;
-        tokens[n - 1].t1 = t1;
-
-        t_last = t1;
-
-        // find intervals of tokens with unknown timestamps
-        // fill the timestamps by proportionally splitting the interval based on the token voice lengths
-        {
-            int p0 = 0;
-            int p1 = 0;
-            while (true) {
-                while (p1 < n && tokens[p1].t1 < 0) {
-                    p1++;
-                }
-
-                if (p1 >= n) {
-                    p1--;
-                }
-
-                if (p1 > p0) {
-                    double psum = 0.0;
-                    for (int j = p0; j <= p1; j++) {
-                        psum += tokens[j].vlen;
-                    }
-
-                    //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
-
-                    const double dt = tokens[p1].t1 - tokens[p0].t0;
-
-                    // split the time proportionally to the voice length
-                    for (int j = p0 + 1; j <= p1; j++) {
-                        const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
-
-                        tokens[j - 1].t1 = ct;
-                        tokens[j    ].t0 = ct;
-                    }
-                }
-
-                p1++;
-                p0 = p1;
-                if (p1 >= n) {
-                    break;
-                }
-            }
-        }
-
-        // fix up (just in case)
-        for (int j = 0; j < n - 1; j++) {
-            if (tokens[j].t1 < 0) {
-                tokens[j + 1].t0 = tokens[j].t1;
-            }
-
-            if (j > 0) {
-                if (tokens[j - 1].t1 > tokens[j].t0) {
-                    tokens[j].t0 = tokens[j - 1].t1;
-                    tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
-                }
-            }
-
-            tokens[j].tt0 = tokens[j].t0;
-            tokens[j].tt1 = tokens[j].t1;
-        }
-
-        // VAD
-        // expand or contract tokens based on voice activity
-        {
-            const int hw = WHISPER_SAMPLE_RATE/8;
-
-            for (int j = 0; j < n; j++) {
-                if (tokens[j].id >= whisper_token_eot(ctx)) {
-                    continue;
-                }
-
-                const int64_t t0 = tokens[j].t0;
-                const int64_t t1 = tokens[j].t1;
-
-                int s0 = std::max(0,                        (int) (t0*WHISPER_SAMPLE_RATE/100));
-                int s1 = std::min((int) pcmf32.size() - 1,  (int) (t1*WHISPER_SAMPLE_RATE/100));
-
-                const int ss0 = std::max(0,                       (int) (t0*WHISPER_SAMPLE_RATE/100) - hw);
-                const int ss1 = std::min((int) pcmf32.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100) + hw);
-
-                const int n = ss1 - ss0;
-
-                float sum = 0.0f;
-
-                for (int k = ss0; k < ss1; k++) {
-                    sum += pcm_avg[k];
-                }
-
-                const float thold = 0.5*sum/n;
-
-                {
-                    int k = s0;
-                    if (pcm_avg[k] > thold && j > 0) {
-                        while (k > 0 && pcm_avg[k] > thold) {
-                            k--;
-                        }
-                        tokens[j].t0 = (int64_t) (100*k/WHISPER_SAMPLE_RATE);
-                        if (tokens[j].t0 < tokens[j - 1].t1) {
-                            tokens[j].t0 = tokens[j - 1].t1;
-                        } else {
-                            s0 = k;
-                        }
-                    } else {
-                        while (pcm_avg[k] < thold && k < s1) {
-                            k++;
-                        }
-                        s0 = k;
-                        tokens[j].t0 = 100*k/WHISPER_SAMPLE_RATE;
-                    }
-                }
-
-                {
-                    int k = s1;
-                    if (pcm_avg[k] > thold) {
-                        while (k < (int) pcmf32.size() - 1 && pcm_avg[k] > thold) {
-                            k++;
-                        }
-                        tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE;
-                        if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
-                            tokens[j].t1 = tokens[j + 1].t0;
-                        } else {
-                            s1 = k;
-                        }
-                    } else {
-                        while (pcm_avg[k] < thold && k > s0) {
-                            k--;
-                        }
-                        s1 = k;
-                        tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE;
-                    }
-                }
-            }
-        }
-
-        // fixed token expand (optional)
-        {
-            const int t_expand = 0;
-
-            for (int j = 0; j < n; j++) {
-                if (j > 0) {
-                    tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
-                }
-                if (j < n - 1) {
-                    tokens[j].t1 = tokens[j].t1 + t_expand;
-                }
-            }
-        }
-
-        // debug info
-        // TODO: toggle via parameter
-        for (int j = 0; j < n; ++j) {
-            const auto & token = tokens[j];
-            const auto tt = token.pt > params.word_thold && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
-            printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
-                    tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, token.text.c_str());
-
-            if (tokens[j].id >= whisper_token_eot(ctx)) {
-                continue;
-            }
-
-            //printf("[%s --> %s] %s\n", to_timestamp(token.t0).c_str(), to_timestamp(token.t1).c_str(), whisper_token_to_str(ctx, token.id));
-
-            //fout << "# " << to_timestamp(token.t0) << " --> " << to_timestamp(token.t1) << " " << whisper_token_to_str(ctx, token.id) << "\n";
-        }
-
-        // TODO: become parameters
-        static const int line_wrap = 60;
-        static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
-
-        if (!is_first) {
+        if (i > 0) {
             fout << ",";
         }
 
         // background text
         fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
 
-        is_first = false;
+        bool is_first = true;
 
         for (int j = 0; j < n; ++j) {
             const auto & token = tokens[j];
@@ -654,17 +380,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
                     }
 
                     ncnt += txt.size();
-
-                    if (ncnt > line_wrap) {
-                        if (k < j) {
-                            txt_bg = "> ";
-                            txt_fg = "> ";
-                            txt_ul = "\\ \\ ";
-                            ncnt = 0;
-                        } else {
-                            break;
-                        }
-                    }
                 }
 
                 ::replace_all(txt_bg, "'", "’");
@@ -673,8 +388,11 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
                 ::replace_all(txt_fg, "\"", "\\\"");
             }
 
-            // background text
-            fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << token.tt0/100.0 << "," << token.tt1/100.0 << ")'";
+            if (is_first) {
+                // background text
+                fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << t0/100.0 << "," << t1/100.0 << ")'";
+                is_first = false;
+            }
 
             // foreground text
             fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
@@ -815,6 +533,10 @@ int main(int argc, char ** argv) {
             wparams.n_max_text_ctx       = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
             wparams.offset_ms            = params.offset_t_ms;
 
+            wparams.token_timestamps     = params.output_wts || params.max_len > 0;
+            wparams.thold_pt             = params.word_thold;
+            wparams.max_len              = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
+
             // this callback is called on each new segment
             if (!wparams.print_realtime) {
                 wparams.new_segment_callback           = whisper_print_segment_callback;
@@ -852,7 +574,7 @@ int main(int argc, char ** argv) {
             // output to WTS file
             if (params.output_wts) {
                 const auto fname_wts = fname_inp + ".wts";
-                output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, pcmf32);
+                output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
             }
         }
     }
index b230d0c0ce3b31e66795d849b9bf69ef43a4943e..02ab5cbc8a63e451527f2b1902dc0457a52d9059 100644 (file)
@@ -418,6 +418,12 @@ struct whisper_context {
     std::vector<whisper_segment> result_all;
 
     std::vector<whisper_token> prompt_past;
+
+    // [EXPERIMENTAL] token-level timestamps data
+    int64_t t_beg;
+    int64_t t_last;
+    whisper_token tid_last;
+    std::vector<float> energy; // PCM signal energy
 };
 
 // load the model from a ggml file
@@ -431,7 +437,7 @@ struct whisper_context {
 //
 // see the convert-pt-to-ggml.py script for details
 //
-bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
+static bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
     fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());
 
     auto & model = wctx.model;
@@ -1062,7 +1068,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
 //   - n_threads:  number of threads to use
 //   - mel_offset: offset in the mel spectrogram (i.e. audio offset)
 //
-bool whisper_encode(
+static bool whisper_encode(
               whisper_context & wctx,
         const int n_threads,
         const int mel_offset) {
@@ -1448,7 +1454,7 @@ bool whisper_encode(
 //   - n_tokens:   number of tokens in the prompt
 //   - n_past:     number of past tokens to prefix the prompt with
 //
-bool whisper_decode(
+static bool whisper_decode(
               whisper_context & wctx,
         const int n_threads,
         const whisper_token * tokens,
@@ -1811,10 +1817,12 @@ bool whisper_decode(
 }
 
 // the most basic sampling scheme - select the top token
-whisper_token_data whisper_sample_best(
+static whisper_token_data whisper_sample_best(
         const whisper_vocab & vocab,
         const float * probs) {
-    whisper_token_data result;
+    whisper_token_data result = {
+        0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
+    };
 
     int n_logits = vocab.id_to_token.size();
 
@@ -1887,7 +1895,7 @@ whisper_token_data whisper_sample_best(
 }
 
 // samples only from the timestamps tokens
-whisper_vocab::id whisper_sample_timestamp(
+static whisper_vocab::id whisper_sample_timestamp(
         const whisper_vocab & vocab,
         const float * probs) {
     int n_logits = vocab.id_to_token.size();
@@ -1939,7 +1947,7 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
 // naive Discrete Fourier Transform
 // input is real-valued
 // output is complex-valued
-void dft(const std::vector<float> & in, std::vector<float> & out) {
+static void dft(const std::vector<float> & in, std::vector<float> & out) {
     int N = in.size();
 
     out.resize(N*2);
@@ -1963,7 +1971,7 @@ void dft(const std::vector<float> & in, std::vector<float> & out) {
 // poor man's implementation - use something better
 // input is real-valued
 // output is complex-valued
-void fft(const std::vector<float> & in, std::vector<float> & out) {
+static void fft(const std::vector<float> & in, std::vector<float> & out) {
     out.resize(in.size()*2);
 
     int N = in.size();
@@ -2014,7 +2022,7 @@ void fft(const std::vector<float> & in, std::vector<float> & out) {
 }
 
 // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
-bool log_mel_spectrogram(
+static bool log_mel_spectrogram(
     const float * samples,
     const int n_samples,
     const int sample_rate,
@@ -2339,6 +2347,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                     /*.print_realtime       =*/ false,
                     /*.print_timestamps     =*/ true,
 
+                    /*.token_timestamps     =*/ false,
+                    /*.thold_pt             =*/ 0.01f,
+                    /*.thold_ptsum          =*/ 0.01f,
+                    /*.max_len              =*/ 0,
+
                     /*.language             =*/ "en",
 
                     /*.greedy               =*/ {
@@ -2371,6 +2384,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                     /*.print_realtime       =*/ false,
                     /*.print_timestamps     =*/ true,
 
+                    /*.token_timestamps     =*/ false,
+                    /*.thold_pt             =*/ 0.01f,
+                    /*.thold_ptsum          =*/ 0.01f,
+                    /*.max_len              =*/ 0,
+
                     /*.language             =*/ "en",
 
                     /*.greedy               =*/ {
@@ -2392,6 +2410,68 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
     return result;
 }
 
+// forward declarations
+static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
+static void whisper_exp_compute_token_level_timestamps(
+        struct whisper_context * ctx,
+        int   i_segment,
+        float thold_pt,
+        float thold_ptsum);
+
+// wrap the last segment to max_len characters
+// returns the number of new segments
+static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
+    auto segment = ctx->result_all.back();
+
+    int res = 1;
+    int acc = 0;
+
+    std::string text;
+
+    for (int i = 0; i < (int) segment.tokens.size(); i++) {
+        const auto & token = segment.tokens[i];
+        if (token.id >= whisper_token_eot(ctx)) {
+            continue;
+        }
+
+        const auto txt = whisper_token_to_str(ctx, token.id);
+
+        const int cur = strlen(txt);
+
+        if (acc + cur > max_len && i > 0) {
+            // split here
+            ctx->result_all.back().text = std::move(text);
+            ctx->result_all.back().t1 = token.t0;
+            ctx->result_all.back().tokens.resize(i);
+
+            ctx->result_all.push_back({});
+            ctx->result_all.back().t0 = token.t0;
+            ctx->result_all.back().t1 = segment.t1;
+
+            // add tokens [i, end] to the new segment
+            ctx->result_all.back().tokens.insert(
+                    ctx->result_all.back().tokens.end(),
+                    segment.tokens.begin() + i,
+                    segment.tokens.end());
+
+            acc = 0;
+            text = "";
+
+            segment = ctx->result_all.back();
+            i = -1;
+
+            res++;
+        } else {
+            acc += cur;
+            text += txt;
+        }
+    }
+
+    ctx->result_all.back().text = std::move(text);
+
+    return res;
+}
+
 int whisper_full(
         struct whisper_context * ctx,
         struct whisper_full_params params,
@@ -2408,6 +2488,13 @@ int whisper_full(
         return -1;
     }
 
+    if (params.token_timestamps) {
+        ctx->t_beg = 0;
+        ctx->t_last = 0;
+        ctx->tid_last = 0;
+        ctx->energy = get_signal_energy(samples, n_samples, 32);
+    }
+
     const int seek_start = params.offset_ms/10;
 
     // if length of spectrogram is less than 1s (100 samples), then return
@@ -2557,6 +2644,7 @@ int whisper_full(
             }
         }
 
+        // shrink down to result_len
         tokens_cur.resize(result_len);
 
         for (const auto & r : tokens_cur) {
@@ -2595,8 +2683,19 @@ int whisper_full(
                         for (int j = i0; j <= i; j++) {
                             result_all.back().tokens.push_back(tokens_cur[j]);
                         }
+
+                        int n_new = 1;
+
+                        if (params.token_timestamps) {
+                            whisper_exp_compute_token_level_timestamps(
+                                    ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+
+                            if (params.max_len > 0) {
+                                n_new = whisper_wrap_segment(ctx, params.max_len);
+                            }
+                        }
                         if (params.new_segment_callback) {
-                            params.new_segment_callback(ctx, params.new_segment_callback_user_data);
+                            params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
                         }
                     }
                     text = "";
@@ -2625,8 +2724,19 @@ int whisper_full(
                 for (int j = i0; j < (int) tokens_cur.size(); j++) {
                     result_all.back().tokens.push_back(tokens_cur[j]);
                 }
+
+                int n_new = 1;
+
+                if (params.token_timestamps) {
+                    whisper_exp_compute_token_level_timestamps(
+                            ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+
+                    if (params.max_len > 0) {
+                        n_new = whisper_wrap_segment(ctx, params.max_len);
+                    }
+                }
                 if (params.new_segment_callback) {
-                    params.new_segment_callback(ctx, params.new_segment_callback_user_data);
+                    params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
                 }
             }
         }
@@ -2760,7 +2870,7 @@ int whisper_full_parallel(
 
             // call the new_segment_callback for each segment
             if (params.new_segment_callback) {
-                params.new_segment_callback(ctx, params.new_segment_callback_user_data);
+                params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
             }
         }
 
@@ -2836,3 +2946,304 @@ const char * whisper_print_system_info() {
 
     return s.c_str();
 }
+
+// =================================================================================================
+
+//
+// Experimental stuff below
+//
+// Not sure if these should be part of the library at all, because the quality of the results is not
+// guaranteed. Might get removed at some point unless a robust algorithm implementation is found
+//
+
+// =================================================================================================
+
+//
+// token-level timestamps
+//
+
+static int timestamp_to_sample(int64_t t, int n_samples) {
+    return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
+}
+
+static int64_t sample_to_timestamp(int i_sample) {
+    return (100*i_sample)/WHISPER_SAMPLE_RATE;
+}
+
+// a cost-function / heuristic that is high for text that takes longer to pronounce
+// obviously, can be improved
+static float voice_length(const std::string & text) {
+    float res = 0.0f;
+
+    for (size_t i = 0; i < text.size(); ++i) {
+        if (text[i] == ' ') {
+            res += 0.01f;
+        } else if (text[i] == ',') {
+            res += 2.00f;
+        } else if (text[i] == '.') {
+            res += 3.00f;
+        } else if (text[i] == '!') {
+            res += 3.00f;
+        } else if (text[i] == '?') {
+            res += 3.00f;
+        } else if (text[i] >= '0' && text[i] <= '9') {
+            res += 3.00f;
+        } else {
+            res += 1.00f;
+        }
+    }
+
+    return res;
+}
+
+// average the fabs of the signal
+static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
+    const int hw = n_samples_per_half_window;
+
+    std::vector<float> result(n_samples);
+
+    for (int i = 0; i < n_samples; i++) {
+        float sum = 0;
+        for (int j = -hw; j <= hw; j++) {
+            if (i + j >= 0 && i + j < n_samples) {
+                sum += fabs(signal[i + j]);
+            }
+        }
+        result[i] = sum/(2*hw + 1);
+    }
+
+    return result;
+}
+
+static void whisper_exp_compute_token_level_timestamps(
+        struct whisper_context * ctx,
+        int   i_segment,
+        float thold_pt,
+        float thold_ptsum) {
+    auto & segment = ctx->result_all[i_segment];
+    auto & tokens  = segment.tokens;
+
+    const int n_samples = ctx->energy.size();
+
+    if (n_samples == 0) {
+        fprintf(stderr, "%s: no signal data available\n", __func__);
+        return;
+    }
+
+    const int64_t t0 = segment.t0;
+    const int64_t t1 = segment.t1;
+
+    const int s0 = timestamp_to_sample(t0, n_samples);
+    const int s1 = timestamp_to_sample(t1, n_samples);
+
+    const int n = tokens.size();
+
+    if (n == 0) {
+        return;
+    }
+
+    if (n == 1) {
+        tokens[0].t0 = t0;
+        tokens[0].t1 = t1;
+
+        return;
+    }
+
+    auto & t_beg    = ctx->t_beg;
+    auto & t_last   = ctx->t_last;
+    auto & tid_last = ctx->tid_last;
+
+    for (int j = 0; j < n; ++j) {
+        auto & token = tokens[j];
+
+        if (j == 0) {
+            if (token.id == whisper_token_beg(ctx)) {
+                tokens[j    ].t0 = t0;
+                tokens[j    ].t1 = t0;
+                tokens[j + 1].t0 = t0;
+
+                t_beg    = t0;
+                t_last   = t0;
+                tid_last = whisper_token_beg(ctx);
+            } else {
+                tokens[j    ].t0 = t_last;
+            }
+        }
+
+        const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
+
+        tokens[j].id    = token.id;
+        tokens[j].tid   = token.tid;
+        tokens[j].p     = token.p;
+        tokens[j].pt    = token.pt;
+        tokens[j].ptsum = token.ptsum;
+
+        tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id));
+
+        if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
+            if (j > 0) {
+                tokens[j - 1].t1 = tt;
+            }
+            tokens[j].t0 = tt;
+            tid_last = token.tid;
+        }
+    }
+
+    tokens[n - 2].t1 = t1;
+    tokens[n - 1].t0 = t1;
+    tokens[n - 1].t1 = t1;
+
+    t_last = t1;
+
+    // find intervals of tokens with unknown timestamps
+    // fill the timestamps by proportionally splitting the interval based on the token voice lengths
+    {
+        int p0 = 0;
+        int p1 = 0;
+
+        while (true) {
+            while (p1 < n && tokens[p1].t1 < 0) {
+                p1++;
+            }
+
+            if (p1 >= n) {
+                p1--;
+            }
+
+            if (p1 > p0) {
+                double psum = 0.0;
+                for (int j = p0; j <= p1; j++) {
+                    psum += tokens[j].vlen;
+                }
+
+                //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
+
+                const double dt = tokens[p1].t1 - tokens[p0].t0;
+
+                // split the time proportionally to the voice length
+                for (int j = p0 + 1; j <= p1; j++) {
+                    const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
+
+                    tokens[j - 1].t1 = ct;
+                    tokens[j    ].t0 = ct;
+                }
+            }
+
+            p1++;
+            p0 = p1;
+            if (p1 >= n) {
+                break;
+            }
+        }
+    }
+
+    // fix up (just in case)
+    for (int j = 0; j < n - 1; j++) {
+        if (tokens[j].t1 < 0) {
+            tokens[j + 1].t0 = tokens[j].t1;
+        }
+
+        if (j > 0) {
+            if (tokens[j - 1].t1 > tokens[j].t0) {
+                tokens[j].t0 = tokens[j - 1].t1;
+                tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
+            }
+        }
+    }
+
+    // VAD
+    // expand or contract tokens based on voice activity
+    {
+        const int hw = WHISPER_SAMPLE_RATE/8;
+
+        for (int j = 0; j < n; j++) {
+            if (tokens[j].id >= whisper_token_eot(ctx)) {
+                continue;
+            }
+
+            int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
+            int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
+
+            const int ss0 = std::max(s0 - hw, 0);
+            const int ss1 = std::min(s1 + hw, n_samples);
+
+            const int ns = ss1 - ss0;
+
+            float sum = 0.0f;
+
+            for (int k = ss0; k < ss1; k++) {
+                sum += ctx->energy[k];
+            }
+
+            const float thold = 0.5*sum/ns;
+
+            {
+                int k = s0;
+                if (ctx->energy[k] > thold && j > 0) {
+                    while (k > 0 && ctx->energy[k] > thold) {
+                        k--;
+                    }
+                    tokens[j].t0 = sample_to_timestamp(k);
+                    if (tokens[j].t0 < tokens[j - 1].t1) {
+                        tokens[j].t0 = tokens[j - 1].t1;
+                    } else {
+                        s0 = k;
+                    }
+                } else {
+                    while (ctx->energy[k] < thold && k < s1) {
+                        k++;
+                    }
+                    s0 = k;
+                    tokens[j].t0 = sample_to_timestamp(k);
+                }
+            }
+
+            {
+                int k = s1;
+                if (ctx->energy[k] > thold) {
+                    while (k < n_samples - 1 && ctx->energy[k] > thold) {
+                        k++;
+                    }
+                    tokens[j].t1 = sample_to_timestamp(k);
+                    if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) {
+                        tokens[j].t1 = tokens[j + 1].t0;
+                    } else {
+                        s1 = k;
+                    }
+                } else {
+                    while (ctx->energy[k] < thold && k > s0) {
+                        k--;
+                    }
+                    s1 = k;
+                    tokens[j].t1 = sample_to_timestamp(k);
+                }
+            }
+        }
+    }
+
+    // fixed token expand (optional)
+    //{
+    //    const int t_expand = 0;
+
+    //    for (int j = 0; j < n; j++) {
+    //        if (j > 0) {
+    //            tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
+    //        }
+    //        if (j < n - 1) {
+    //            tokens[j].t1 = tokens[j].t1 + t_expand;
+    //        }
+    //    }
+    //}
+
+    // debug info
+    //for (int j = 0; j < n; ++j) {
+    //    const auto & token = tokens[j];
+    //    const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
+    //    printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
+    //            tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id));
+
+    //    if (tokens[j].id >= whisper_token_eot(ctx)) {
+    //        continue;
+    //    }
+    //}
+}
index 5d7c40d00ee452a6b6633cd3875f9d1d9699e41d..57ea5db8bf39578a43798fcbc13159bb93d7df3b 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -68,14 +68,21 @@ extern "C" {
 
     typedef int whisper_token;
 
-    struct whisper_token_data {
+    typedef struct whisper_token_data {
         whisper_token id;  // token id
         whisper_token tid; // forced timestamp token id
 
         float p;     // probability of the token
         float pt;    // probability of the timestamp token
         float ptsum; // sum of probabilities of all timestamp tokens
-    };
+
+        // token-level timestamp data
+        // do not use if you haven't computed token-level timestamps
+        int64_t t0; // start time of the token
+        int64_t t1; //   end time of the token
+
+        float vlen; // voice length of the token
+    } whisper_token_data;
 
     // Allocates all memory needed for the model and loads the model from the given file.
     // Returns NULL on failure.
@@ -129,7 +136,7 @@ extern "C" {
     // You can also implement your own sampling method using the whisper_get_probs() function.
     // whisper_sample_best() returns the token with the highest probability
     // whisper_sample_timestamp() returns the most probable timestamp token
-    WHISPER_API struct whisper_token_data whisper_sample_best(struct whisper_context * ctx);
+    WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
 
     // Return the id of the specified language, returns -1 if not found
@@ -172,7 +179,7 @@ extern "C" {
     // Text segment callback
     // Called on every newly generated text segment
     // Use the whisper_full_...() functions to obtain the text segments
-    typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data);
+    typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
 
     struct whisper_full_params {
         enum whisper_sampling_strategy strategy;
@@ -188,6 +195,12 @@ extern "C" {
         bool print_realtime;
         bool print_timestamps;
 
+        // [EXPERIMENTAL] token-level timestamps
+        bool  token_timestamps; // enable token-level timestamps
+        float thold_pt;         // timestamp token probability threshold (~0.01)
+        float thold_ptsum;      // timestamp token sum probability threshold (~0.01)
+        int   max_len;          // max segment length in characters
+
         const char * language;
 
         struct {
@@ -244,7 +257,7 @@ extern "C" {
 
     // Get token data for the specified token in the specified segment.
     // This contains probabilities, timestamps, etc.
-    WHISPER_API struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
 
     // Get the probability of the specified token in the specified segment.
     WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);