]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
main : add option for word-leve timestamps (very experimental)
authorGeorgi Gerganov <redacted>
Sun, 30 Oct 2022 08:05:58 +0000 (10:05 +0200)
committerGeorgi Gerganov <redacted>
Sun, 30 Oct 2022 15:06:57 +0000 (17:06 +0200)
examples/main/main.cpp
whisper.cpp
whisper.h

index f8d05ba67598379b75c6212fa16a8e1e861759a4..d413828d523f481bd2d5cca99697cad21875f5cb 100644 (file)
@@ -36,6 +36,40 @@ std::string to_timestamp(int64_t t, bool comma = false) {
     return std::string(buf);
 }
 
+void replace_all(std::string & s, const std::string & search, const std::string & replace) {
+    for (size_t pos = 0; ; pos += replace.length()) {
+        pos = s.find(search, pos);
+        if (pos == std::string::npos) break;
+        s.erase(pos, search.length());
+        s.insert(pos, replace);
+    }
+}
+
+// a cost-function that is high for text that takes longer to pronounce
+float voice_length(const std::string & text) {
+    float res = 0.0f;
+
+    for (size_t i = 0; i < text.size(); ++i) {
+        if (text[i] == ' ') {
+            res += 0.01f;
+        } else if (text[i] == ',') {
+            res += 2.00f;
+        } else if (text[i] == '.') {
+            res += 3.00f;
+        } else if (text[i] == '!') {
+            res += 3.00f;
+        } else if (text[i] == '?') {
+            res += 3.00f;
+        } else if (text[i] >= '0' && text[i] <= '9') {
+            res += 3.00f;
+        } else {
+            res += 1.00f;
+        }
+    }
+
+    return res;
+}
+
 // command-line parameters
 struct whisper_params {
     int32_t seed         = -1; // RNG seed, not used currently
@@ -45,11 +79,14 @@ struct whisper_params {
     int32_t offset_n     = 0;
     int32_t max_context  = -1;
 
+    float word_thold = 0.01f;
+
     bool verbose              = false;
     bool translate            = false;
     bool output_txt           = false;
     bool output_vtt           = false;
     bool output_srt           = false;
+    bool output_wts           = false;
     bool print_special_tokens = false;
     bool print_colors         = false;
     bool no_timestamps        = false;
@@ -83,6 +120,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             params.offset_n = std::stoi(argv[++i]);
         } else if (arg == "-mc" || arg == "--max-context") {
             params.max_context = std::stoi(argv[++i]);
+        } else if (arg == "-wt" || arg == "--word-thold") {
+            params.word_thold = std::stof(argv[++i]);
         } else if (arg == "-v" || arg == "--verbose") {
             params.verbose = true;
         } else if (arg == "--translate") {
@@ -100,6 +139,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             params.output_vtt = true;
         } else if (arg == "-osrt" || arg == "--output-srt") {
             params.output_srt = true;
+        } else if (arg == "-owts" || arg == "--output-words") {
+            params.output_wts = true;
         } else if (arg == "-ps" || arg == "--print_special") {
             params.print_special_tokens = true;
         } else if (arg == "-pc" || arg == "--print_colors") {
@@ -135,11 +176,13 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "  -ot N,    --offset-t N     time offset in milliseconds (default: %d)\n", params.offset_t_ms);
     fprintf(stderr, "  -on N,    --offset-n N     segment index offset (default: %d)\n", params.offset_n);
     fprintf(stderr, "  -mc N,    --max-context N  maximum number of text context tokens to store (default: max)\n");
+    fprintf(stderr, "  -wt N,    --word-thold N   word timestamp probability threshold (default: %f)\n", params.word_thold);
     fprintf(stderr, "  -v,       --verbose        verbose output\n");
     fprintf(stderr, "            --translate      translate from source language to english\n");
     fprintf(stderr, "  -otxt,    --output-txt     output result in a text file\n");
     fprintf(stderr, "  -ovtt,    --output-vtt     output result in a vtt file\n");
     fprintf(stderr, "  -osrt,    --output-srt     output result in a srt file\n");
+    fprintf(stderr, "  -owts,    --output-words   output word-level timestamps to a text file\n");
     fprintf(stderr, "  -ps,      --print_special  print special tokens\n");
     fprintf(stderr, "  -pc,      --print_colors   print colors\n");
     fprintf(stderr, "  -nt,      --no_timestamps  do not print timestamps\n");
@@ -277,6 +320,367 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
     return true;
 }
 
+// word-level timestamps (experimental)
+// TODO: probably still has bugs, needs refactoring, etc..
+// TODO: auto threshold
+// TODO: extra pass to detect unused speech and assign to tokens
+// TODO: font parameter adjustments
+bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, const std::vector<float> & pcmf32) {
+    if (params.output_wts) {
+        std::vector<float> pcm_avg(pcmf32.size(), 0);
+
+        // average the fabs of the signal
+        {
+            const int hw = 32;
+
+            for (int i = 0; i < pcmf32.size(); i++) {
+                float sum = 0;
+                for (int j = -hw; j <= hw; j++) {
+                    if (i + j >= 0 && i + j < pcmf32.size()) {
+                        sum += fabs(pcmf32[i + j]);
+                    }
+                }
+                pcm_avg[i] = sum/(2*hw + 1);
+            }
+        }
+
+        struct token_info {
+            int64_t t0 = -1;
+            int64_t t1 = -1;
+
+            int64_t tt0 = -1;
+            int64_t tt1 = -1;
+
+            whisper_token id;
+            whisper_token tid;
+
+            float p     = 0.0f;
+            float pt    = 0.0f;
+            float ptsum = 0.0f;
+
+            std::string text;
+            float vlen = 0.0f; // voice length of this token
+        };
+
+        int64_t t_beg  = 0;
+        int64_t t_last = 0;
+
+        whisper_token tid_last = 0;
+
+        std::ofstream fout(fname);
+
+        fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
+
+        fout << "!/bin/bash" << "\n";
+        fout << "\n";
+
+        fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE << ":rate=25:color=black -vf \"";
+
+        bool is_first = true;
+
+        for (int i = 0; i < whisper_full_n_segments(ctx); i++) {
+            const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+            const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+            const char *text = whisper_full_get_segment_text(ctx, i);
+
+            const int s0 = std::max(0,                   (int) (t0*WHISPER_SAMPLE_RATE/100));
+            const int s1 = std::min((int) pcmf32.size(), (int) (t1*WHISPER_SAMPLE_RATE/100));
+
+            const int n = whisper_full_n_tokens(ctx, i);
+
+            std::vector<token_info> tokens(n);
+
+            if (n <= 1) {
+                continue;
+            }
+
+            for (int j = 0; j < n; ++j) {
+                struct whisper_token_data token = whisper_full_get_token_data(ctx, i, j);
+
+                if (j == 0) {
+                    if (token.id == whisper_token_beg(ctx)) {
+                        tokens[j    ].t0 = t0;
+                        tokens[j    ].t1 = t0;
+                        tokens[j + 1].t0 = t0;
+
+                        t_beg  = t0;
+                        t_last = t0;
+                        tid_last = whisper_token_beg(ctx);
+                    } else {
+                        tokens[j    ].t0 = t_last;
+                    }
+                }
+
+                const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
+
+                tokens[j].id    = token.id;
+                tokens[j].tid   = token.tid;
+                tokens[j].p     = token.p;
+                tokens[j].pt    = token.pt;
+                tokens[j].ptsum = token.ptsum;
+
+                tokens[j].text = whisper_token_to_str(ctx, token.id);
+                //tokens[j].vlen = tokens[j].pt;
+                tokens[j].vlen = voice_length(tokens[j].text);
+
+                if (token.pt > params.word_thold && token.ptsum > 0.01 && token.tid > tid_last) {
+                    if (j > 0) {
+                        tokens[j - 1].t1 = tt;
+                    }
+                    tokens[j].t0 = tt;
+                    tid_last = token.tid;
+                }
+            }
+
+            tokens[n - 2].t1 = t1;
+            tokens[n - 1].t0 = t1;
+            tokens[n - 1].t1 = t1;
+
+            t_last = t1;
+
+            int p0 = 0;
+            int p1 = 0;
+            while (true) {
+                while (p1 < n && tokens[p1].t1 < 0) {
+                    p1++;
+                }
+
+                if (p1 >= n) {
+                    p1--;
+                }
+
+                if (p1 > p0) {
+                    double psum = 0.0;
+                    for (int j = p0; j <= p1; j++) {
+                        psum += tokens[j].vlen;
+                    }
+
+                    //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
+
+                    const double dt = tokens[p1].t1 - tokens[p0].t0;
+
+                    for (int j = p0 + 1; j <= p1; j++) {
+                        const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
+                        //const double ct = tokens[j - 1].t0 + (dt*(j - p0))/(p1 - p0 + 1);
+                        //const double ct = tokens[p0].t0 + (dt*(j - p0))/(p1 - p0 + 1);
+
+                        tokens[j - 1].t1 = ct;
+                        tokens[j    ].t0 = ct;
+                    }
+                }
+
+                p1++;
+                p0 = p1;
+                if (p1 >= n) {
+                    break;
+                }
+            }
+
+            for (int j = 0; j < n - 1; j++) {
+                if (tokens[j].t1 < 0) {
+                    tokens[j + 1].t0 = tokens[j].t1;
+                }
+
+                tokens[j].tt0 = tokens[j].t0;
+                tokens[j].tt1 = tokens[j].t1;
+            }
+
+            // VAD
+            {
+                const int hw = WHISPER_SAMPLE_RATE; // take one second of audio around the token
+
+                for (int j = 0; j < n; j++) {
+                    const int64_t t0 = tokens[j].t0;
+                    const int64_t t1 = tokens[j].t1;
+
+                    int s0 = std::max(0,                        (int) (t0*WHISPER_SAMPLE_RATE/100));
+                    int s1 = std::min((int) pcmf32.size() - 1,  (int) (t1*WHISPER_SAMPLE_RATE/100));
+
+                    const int ss0 = std::max(0,                       (int) (t0*WHISPER_SAMPLE_RATE/100) - hw);
+                    const int ss1 = std::min((int) pcmf32.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100) + hw);
+
+                    const int n = ss1 - ss0;
+
+                    float sum = 0.0f;
+                    for (int k = ss0; k < ss1; k++) {
+                        sum += pcm_avg[k];
+                    }
+
+                    const float avg = sum/n;
+
+                    const float thold = 0.5*avg;
+
+                    {
+                        int k = s0;
+                        if (pcm_avg[k] > thold && j > 0) {
+                            while (k > 0 && pcm_avg[k] > thold) {
+                                k--;
+                            }
+                            tokens[j].t0 = (int64_t) (100*k/WHISPER_SAMPLE_RATE);
+                            if (tokens[j].t0 < tokens[j - 1].t1) {
+                                tokens[j].t0 = tokens[j - 1].t1;
+                            } else {
+                                s0 = k;
+                            }
+                        } else {
+                            while (pcm_avg[k] < thold && k < s1) {
+                                k++;
+                            }
+                            s0 = k;
+                            tokens[j].t0 = 100*k/WHISPER_SAMPLE_RATE;
+                        }
+                    }
+
+                    {
+                        int k = s1;
+                        if (pcm_avg[k] > thold) {
+                            while (k < (int) pcmf32.size() - 1 && pcm_avg[k] > thold) {
+                                k++;
+                            }
+                            tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE;
+                            if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
+                                tokens[j].t1 = tokens[j + 1].t0;
+                            } else {
+                                s1 = k;
+                            }
+                        } else {
+                            while (pcm_avg[k] < thold && k > s0) {
+                                k--;
+                            }
+                            s1 = k;
+                            tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE;
+                        }
+                    }
+                }
+            }
+
+            const int t_expand = 0;
+
+            for (int j = 0; j < n; j++) {
+                if (j > 0) {
+                    tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
+                }
+                if (j < n - 1) {
+                    tokens[j].t1 = tokens[j].t1 + t_expand;
+                }
+            }
+
+            for (int j = 0; j < n; ++j) {
+                const auto & token = tokens[j];
+                const auto tt = token.pt > params.word_thold && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
+                printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
+                        tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, token.text.c_str());
+
+                if (tokens[j].id >= whisper_token_eot(ctx)) {
+                    continue;
+                }
+
+                //printf("[%s --> %s] %s\n", to_timestamp(token.t0).c_str(), to_timestamp(token.t1).c_str(), whisper_token_to_str(ctx, token.id));
+
+                //fout << "# " << to_timestamp(token.t0) << " --> " << to_timestamp(token.t1) << " " << whisper_token_to_str(ctx, token.id) << "\n";
+            }
+
+            static const int line_wrap = 60;
+            static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
+
+            if (!is_first) {
+                fout << ",";
+            }
+
+            // background text
+            fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
+
+            is_first = false;
+
+            for (int j = 0; j < n; ++j) {
+                const auto & token = tokens[j];
+
+                if (tokens[j].id >= whisper_token_eot(ctx)) {
+                    continue;
+                }
+
+                std::string txt_bg;
+                std::string txt_fg; // highlight token
+                std::string txt_ul; // underline
+
+                txt_bg = "> ";
+                txt_fg = "> ";
+                txt_ul = "\\ \\ ";
+
+                {
+                    int ncnt = 0;
+                    for (int k = 0; k < n; ++k) {
+                        const auto & token2 = tokens[k];
+
+                        if (tokens[k].id >= whisper_token_eot(ctx)) {
+                            continue;
+                        }
+
+                        const std::string txt = whisper_token_to_str(ctx, token2.id);
+
+                        txt_bg += txt;
+
+                        if (k == j) {
+                            for (int l = 0; l < (int) txt.size(); ++l) {
+                                txt_fg += txt[l];
+                                txt_ul += "_";
+                            }
+                            txt_fg += "|";
+                        } else {
+                            for (int l = 0; l < (int) txt.size(); ++l) {
+                                txt_fg += "\\ ";
+                                txt_ul += "\\ ";
+                            }
+                        }
+
+                        ncnt += txt.size();
+
+                        if (ncnt > line_wrap) {
+                            if (k < j) {
+                                txt_bg = "> ";
+                                txt_fg = "> ";
+                                txt_ul = "\\ \\ ";
+                                ncnt = 0;
+                            } else {
+                                break;
+                            }
+                        }
+                    }
+
+                    ::replace_all(txt_bg, "'", "’");
+                    ::replace_all(txt_bg, "\"", "\\\"");
+                    ::replace_all(txt_fg, "'", "’");
+                    ::replace_all(txt_fg, "\"", "\\\"");
+                }
+
+                // background text
+                fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << token.tt0/100.0 << "," << token.tt1/100.0 << ")'";
+
+                // foreground text
+                fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
+
+                // underline
+                fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2+16:text='" << txt_ul << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
+            }
+        }
+
+        fout << "\" -c:v libx264 -pix_fmt yuv420p -y " << fname_inp << ".mp4" << "\n";
+
+        fout << "\n\n";
+        fout << "echo \"Your video has been saved to " << fname_inp << ".mp4\"" << "\n";
+        fout << "\n";
+        fout << "echo \"  ffplay " << fname_inp << ".mp4\"\n";
+        fout << "\n";
+
+        fout.close();
+
+        fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname);
+    }
+
+    return true;
+}
+
 int main(int argc, char ** argv) {
     whisper_params params;
 
@@ -403,7 +807,10 @@ int main(int argc, char ** argv) {
                 fprintf(stderr, "%s: failed to process audio\n", argv[0]);
                 return 8;
             }
+        }
 
+        // output stuff
+        {
             printf("\n");
 
             // output to text file
@@ -423,6 +830,12 @@ int main(int argc, char ** argv) {
                 const auto fname_srt = fname_inp + ".srt";
                 output_srt(ctx, fname_srt.c_str(), params);
             }
+
+            // output to WTS file
+            if (params.output_wts) {
+                const auto fname_wts = fname_inp + ".wts";
+                output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, pcmf32);
+            }
         }
     }
 
index 3663259976bfd73f6c3f57730d3e19a99fcb9c85..7f2b49b893ca7dad24c6f12e08724da912c05517 100644 (file)
@@ -1843,7 +1843,8 @@ whisper_token_data whisper_sample_best(
             }
         }
 
-        result.pt = max_ts/(sum_ts + 1e-6);
+        result.pt = max_ts/(sum_ts + 1e-10);
+        result.ptsum = sum_ts;
     }
 
     // find the top K tokens
@@ -2518,7 +2519,10 @@ int whisper_full(
                 prompt.push_back(token.id);
                 tokens_cur.push_back(token);
 
-                //printf("%s: %s\n", __func__, ctx->vocab.id_to_token[id].c_str());
+                //{
+                //    const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
+                //    printf("%s: %10s %6.3f '%s'\n", __func__, tt.c_str(), token.pt, ctx->vocab.id_to_token[token.id].c_str());
+                //}
 
                 // end of text token
                 if (token.id == whisper_token_eot(ctx)) {
@@ -2803,6 +2807,10 @@ whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segm
     return ctx->result_all[i_segment].tokens[i_token].id;
 }
 
+struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
+    return ctx->result_all[i_segment].tokens[i_token];
+}
+
 float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
     return ctx->result_all[i_segment].tokens[i_token].p;
 }
index b49f35a848b5269a8a25bf9930f8a4bf6129d076..5d7c40d00ee452a6b6633cd3875f9d1d9699e41d 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -72,8 +72,9 @@ extern "C" {
         whisper_token id;  // token id
         whisper_token tid; // forced timestamp token id
 
-        float p;  // probability of the token
-        float pt; // probability of the timestamp token
+        float p;     // probability of the token
+        float pt;    // probability of the timestamp token
+        float ptsum; // sum of probabilities of all timestamp tokens
     };
 
     // Allocates all memory needed for the model and loads the model from the given file.
@@ -241,6 +242,10 @@ extern "C" {
     WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
     WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
 
+    // Get token data for the specified token in the specified segment.
+    // This contains probabilities, timestamps, etc.
+    WHISPER_API struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
+
     // Get the probability of the specified token in the specified segment.
     WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);