]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
main : add diarization support for all current output types (#1031)
authorColin <redacted>
Sun, 25 Jun 2023 12:07:57 +0000 (07:07 -0500)
committerGitHub <redacted>
Sun, 25 Jun 2023 12:07:57 +0000 (15:07 +0300)
Co-authored-by: Georgi Gerganov <redacted>
examples/main/main.cpp

index 07a7591fe2151dc3c3b649c545c8f01b504aa288..ff62f74b88761ba1c7809caf233838957f5d34ac 100644 (file)
@@ -210,6 +210,39 @@ struct whisper_print_user_data {
     const std::vector<std::vector<float>> * pcmf32s;
 };
 
+std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
+    std::string speaker = "";
+    const int64_t n_samples = pcmf32s[0].size();
+
+    const int64_t is0 = timestamp_to_sample(t0, n_samples);
+    const int64_t is1 = timestamp_to_sample(t1, n_samples);
+
+    double energy0 = 0.0f;
+    double energy1 = 0.0f;
+
+    for (int64_t j = is0; j < is1; j++) {
+        energy0 += fabs(pcmf32s[0][j]);
+        energy1 += fabs(pcmf32s[1][j]);
+    }
+
+    if (energy0 > 1.1*energy1) {
+        speaker = "0";
+    } else if (energy1 > 1.1*energy0) {
+        speaker = "1";
+    } else {
+        speaker = "?";
+    }
+
+    //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = %s\n", is0, is1, energy0, energy1, speaker.c_str());
+
+    if (!id_only) {
+        speaker.insert(0, "(speaker ");
+        speaker.append(")");
+    }
+
+    return speaker;
+}
+
 void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
     const auto & params  = *((whisper_print_user_data *) user_data)->params;
     const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
@@ -239,28 +272,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
         }
 
         if (params.diarize && pcmf32s.size() == 2) {
-            const int64_t n_samples = pcmf32s[0].size();
-
-            const int64_t is0 = timestamp_to_sample(t0, n_samples);
-            const int64_t is1 = timestamp_to_sample(t1, n_samples);
-
-            double energy0 = 0.0f;
-            double energy1 = 0.0f;
-
-            for (int64_t j = is0; j < is1; j++) {
-                energy0 += fabs(pcmf32s[0][j]);
-                energy1 += fabs(pcmf32s[1][j]);
-            }
-
-            if (energy0 > 1.1*energy1) {
-                speaker = "(speaker 0)";
-            } else if (energy1 > 1.1*energy0) {
-                speaker = "(speaker 1)";
-            } else {
-                speaker = "(speaker ?)";
-            }
-
-            //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
+            speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
         }
 
         if (params.print_colors) {
@@ -294,7 +306,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
     }
 }
 
-bool output_txt(struct whisper_context * ctx, const char * fname) {
+bool output_txt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
     std::ofstream fout(fname);
     if (!fout.is_open()) {
         fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@@ -306,13 +318,22 @@ bool output_txt(struct whisper_context * ctx, const char * fname) {
     const int n_segments = whisper_full_n_segments(ctx);
     for (int i = 0; i < n_segments; ++i) {
         const char * text = whisper_full_get_segment_text(ctx, i);
-        fout << text << "\n";
+        std::string speaker = "";
+
+        if (params.diarize && pcmf32s.size() == 2)
+        {
+            const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+            const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+            speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
+        }
+
+        fout << speaker << text << "\n";
     }
 
     return true;
 }
 
-bool output_vtt(struct whisper_context * ctx, const char * fname) {
+bool output_vtt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
     std::ofstream fout(fname);
     if (!fout.is_open()) {
         fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@@ -328,15 +349,23 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) {
         const char * text = whisper_full_get_segment_text(ctx, i);
         const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
         const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+        std::string speaker = "";
+
+        if (params.diarize && pcmf32s.size() == 2)
+        {
+            speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true);
+            speaker.insert(0, "<v Speaker");
+            speaker.append(">");
+        }
 
         fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
-        fout << text << "\n\n";
+        fout << speaker << text << "\n\n";
     }
 
     return true;
 }
 
-bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
+bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
     std::ofstream fout(fname);
     if (!fout.is_open()) {
         fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@@ -350,10 +379,16 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
         const char * text = whisper_full_get_segment_text(ctx, i);
         const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
         const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+        std::string speaker = "";
+
+        if (params.diarize && pcmf32s.size() == 2)
+        {
+            speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
+        }
 
         fout << i + 1 + params.offset_n << "\n";
         fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
-        fout << text << "\n\n";
+        fout << speaker << text << "\n\n";
     }
 
     return true;
@@ -390,7 +425,7 @@ char *escape_double_quotes_and_backslashes(const char *str) {
     return escaped;
 }
 
-bool output_csv(struct whisper_context * ctx, const char * fname) {
+bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
     std::ofstream fout(fname);
     if (!fout.is_open()) {
         fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@@ -400,7 +435,13 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
     fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
 
     const int n_segments = whisper_full_n_segments(ctx);
-    fout << "start,end,text\n";
+    fout << "start,end,";
+    if (params.diarize && pcmf32s.size() == 2)
+    {
+        fout << "speaker,";
+    }
+    fout << "text\n";
+
     for (int i = 0; i < n_segments; ++i) {
         const char * text = whisper_full_get_segment_text(ctx, i);
         const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
@@ -408,13 +449,18 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
         char * text_escaped = escape_double_quotes_and_backslashes(text);
 
         //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
-        fout << 10 * t0 << "," << 10 * t1 << ",\"" << text_escaped    << "\"\n";
+        fout << 10 * t0 << "," << 10 * t1 << ",";
+        if (params.diarize && pcmf32s.size() == 2)
+        {
+            fout << estimate_diarization_speaker(pcmf32s, t0, t1, true) << ",";
+        }
+        fout << "\"" << text_escaped << "\"\n";
     }
 
     return true;
 }
 
-bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
+bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
     std::ofstream fout(fname);
     int indent = 0;
 
@@ -530,7 +576,11 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
                         value_i("from", t0 * 10, false);
                         value_i("to", t1 * 10, true);
                     end_obj(false);
-                    value_s("text", text, true);
+                    value_s("text", text, !params.diarize);
+
+                    if (params.diarize && pcmf32s.size() == 2) {
+                        value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
+                    }
                 end_obj(i == (n_segments - 1));
             }
 
@@ -542,7 +592,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
 // karaoke video generation
 // outputs a bash script that uses ffmpeg to generate a video with the subtitles
 // TODO: font parameter adjustments
-bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
+bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec, std::vector<std::vector<float>> pcmf32s) {
     std::ofstream fout(fname);
 
     fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
@@ -579,6 +629,11 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
         fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
 
         bool is_first = true;
+        std::string speaker = "";
+
+        if (params.diarize && pcmf32s.size() == 2) {
+            speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
+        }
 
         for (int j = 0; j < n; ++j) {
             const auto & token = tokens[j];
@@ -587,13 +642,19 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
                 continue;
             }
 
-            std::string txt_bg;
-            std::string txt_fg; // highlight token
-            std::string txt_ul; // underline
+            std::string txt_bg = "";
+            std::string txt_fg = ""; // highlight token
+            std::string txt_ul = ""; // underline
 
-            txt_bg = "> ";
-            txt_fg = "> ";
-            txt_ul = "\\ \\ ";
+            if (params.diarize && pcmf32s.size() == 2) {
+                txt_bg = speaker;
+                txt_fg = speaker;
+                txt_ul = "\\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ ";
+            }
+
+            txt_bg.append("> ");
+            txt_fg.append("> ");
+            txt_ul.append("\\ \\ ");
 
             {
                 for (int k = 0; k < n; ++k) {
@@ -656,8 +717,7 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
     return true;
 }
 
-bool output_lrc(struct whisper_context * ctx, const char * fname) {
-
+bool output_lrc(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
     std::ofstream fout(fname);
     if (!fout.is_open()) {
         fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@@ -682,8 +742,16 @@ bool output_lrc(struct whisper_context * ctx, const char * fname) {
         char buf[16];
         snprintf(buf, sizeof(buf), "%02d:%02d.%02d", (int) min, (int) sec, (int) ( msec / 10));
         std::string timestamp_lrc = std::string(buf);
+        std::string speaker = "";
+
+        if (params.diarize && pcmf32s.size() == 2)
+        {
+            const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+            const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+            speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
+        }
 
-        fout <<  '[' << timestamp_lrc << ']' << text << "\n";
+        fout <<  '[' << timestamp_lrc << ']' << speaker << text << "\n";
     }
 
     return true;
@@ -828,43 +896,43 @@ int main(int argc, char ** argv) {
             // output to text file
             if (params.output_txt) {
                 const auto fname_txt = fname_out + ".txt";
-                output_txt(ctx, fname_txt.c_str());
+                output_txt(ctx, fname_txt.c_str(), params, pcmf32s);
             }
 
             // output to VTT file
             if (params.output_vtt) {
                 const auto fname_vtt = fname_out + ".vtt";
-                output_vtt(ctx, fname_vtt.c_str());
+                output_vtt(ctx, fname_vtt.c_str(), params, pcmf32s);
             }
 
             // output to SRT file
             if (params.output_srt) {
                 const auto fname_srt = fname_out + ".srt";
-                output_srt(ctx, fname_srt.c_str(), params);
+                output_srt(ctx, fname_srt.c_str(), params, pcmf32s);
             }
 
             // output to WTS file
             if (params.output_wts) {
                 const auto fname_wts = fname_out + ".wts";
-                output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
+                output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE, pcmf32s);
             }
 
             // output to CSV file
             if (params.output_csv) {
                 const auto fname_csv = fname_out + ".csv";
-                output_csv(ctx, fname_csv.c_str());
+                output_csv(ctx, fname_csv.c_str(), params, pcmf32s);
             }
 
             // output to JSON file
             if (params.output_jsn) {
                 const auto fname_jsn = fname_out + ".json";
-                output_json(ctx, fname_jsn.c_str(), params);
+                output_json(ctx, fname_jsn.c_str(), params, pcmf32s);
             }
 
             // output to LRC file
             if (params.output_lrc) {
                 const auto fname_lrc = fname_out + ".lrc";
-                output_lrc(ctx, fname_lrc.c_str());
+                output_lrc(ctx, fname_lrc.c_str(), params, pcmf32s);
             }
         }
     }