]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
main : add stereo-channel-based diarization (#64)
authorGeorgi Gerganov <redacted>
Fri, 25 Nov 2022 20:08:58 +0000 (22:08 +0200)
committerGeorgi Gerganov <redacted>
Fri, 25 Nov 2022 20:08:58 +0000 (22:08 +0200)
Not tested - I don't have stereo dialog audio

examples/main/main.cpp

index 2b9f2e1745d8bbc84f76179958755fd87ec2c903..569404caa49840566a62d3d00a298e9cfeab3f24 100644 (file)
@@ -36,6 +36,10 @@ std::string to_timestamp(int64_t t, bool comma = false) {
     return std::string(buf);
 }
 
+int timestamp_to_sample(int64_t t, int n_samples) {
+    return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
+}
+
 // helper function to replace substrings
 void replace_all(std::string & s, const std::string & search, const std::string & replace) {
     for (size_t pos = 0; ; pos += replace.length()) {
@@ -60,6 +64,7 @@ struct whisper_params {
 
     bool speed_up      = false;
     bool translate     = false;
+    bool diarize       = false;
     bool output_txt    = false;
     bool output_vtt    = false;
     bool output_srt    = false;
@@ -99,6 +104,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-wt"   || arg == "--word-thold")    { params.word_thold    = std::stof(argv[++i]); }
         else if (arg == "-su"   || arg == "--speed-up")      { params.speed_up      = true; }
         else if (arg == "-tr"   || arg == "--translate")     { params.translate     = true; }
+        else if (arg == "-di"   || arg == "--diarize")       { params.diarize       = true; }
         else if (arg == "-otxt" || arg == "--output-txt")    { params.output_txt    = true; }
         else if (arg == "-ovtt" || arg == "--output-vtt")    { params.output_vtt    = true; }
         else if (arg == "-osrt" || arg == "--output-srt")    { params.output_srt    = true; }
@@ -135,6 +141,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "  -wt N,    --word-thold N  [%-7.2f] word timestamp probability threshold\n",         params.word_thold);
     fprintf(stderr, "  -su,      --speed-up      [%-7s] speed up audio by x2 (reduced accuracy)\n",        params.speed_up ? "true" : "false");
     fprintf(stderr, "  -tr,      --translate     [%-7s] translate from source language to english\n",      params.translate ? "true" : "false");
+    fprintf(stderr, "  -di,      --diarize       [%-7s] stereo audio diarization\n",                       params.diarize ? "true" : "false");
     fprintf(stderr, "  -otxt,    --output-txt    [%-7s] output result in a text file\n",                   params.output_txt ? "true" : "false");
     fprintf(stderr, "  -ovtt,    --output-vtt    [%-7s] output result in a vtt file\n",                    params.output_vtt ? "true" : "false");
     fprintf(stderr, "  -osrt,    --output-srt    [%-7s] output result in a srt file\n",                    params.output_srt ? "true" : "false");
@@ -148,8 +155,15 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "\n");
 }
 
+struct whisper_print_user_data {
+    const whisper_params * params;
+
+    const std::vector<std::vector<float>> * pcmf32s;
+};
+
 void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
-    const whisper_params & params = *(whisper_params *) user_data;
+    const auto & params  = *((whisper_print_user_data *) user_data)->params;
+    const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
 
     const int n_segments = whisper_full_n_segments(ctx);
 
@@ -186,6 +200,33 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
             const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
             const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
+            std::string speaker = "";
+
+            if (params.diarize && pcmf32s.size() == 2) {
+                const int64_t n_samples = pcmf32s[0].size();
+
+                const int64_t is0 = timestamp_to_sample(t0, n_samples);
+                const int64_t is1 = timestamp_to_sample(t1, n_samples);
+
+                double energy0 = 0.0f;
+                double energy1 = 0.0f;
+
+                for (int64_t j = is0; j < is1; j++) {
+                    energy0 += fabs(pcmf32s[0][j]);
+                    energy1 += fabs(pcmf32s[1][j]);
+                }
+
+                if (energy0 > 1.1*energy1) {
+                    speaker = "(speaker 0)";
+                } else if (energy1 > 1.1*energy0) {
+                    speaker = "(speaker 1)";
+                } else {
+                    speaker = "(speaker ?)";
+                }
+
+                //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
+            }
+
             if (params.print_colors) {
                 printf("[%s --> %s]  ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
                 for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
@@ -201,13 +242,13 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
 
                     const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
 
-                    printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
+                    printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
                 }
                 printf("\n");
             } else {
                 const char * text = whisper_full_get_segment_text(ctx, i);
 
-                printf("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+                printf("[%s --> %s]  %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
             }
         }
     }
@@ -235,7 +276,7 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) {
     std::ofstream fout(fname);
     if (!fout.is_open()) {
         fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
-        return 9;
+        return false;
     }
 
     fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
@@ -425,6 +466,7 @@ int main(int argc, char ** argv) {
         const auto fname_inp = params.fname_inp[f];
 
         std::vector<float> pcmf32; // mono-channel F32 PCM
+        std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
 
         // WAV input
         {
@@ -453,22 +495,27 @@ int main(int argc, char ** argv) {
             }
             else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
                 fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
-                return 4;
+                return 5;
             }
 
             if (wav.channels != 1 && wav.channels != 2) {
                 fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
-                return 5;
+                return 6;
+            }
+
+            if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
+                fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str());
+                return 6;
             }
 
             if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
                 fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
-                return 6;
+                return 8;
             }
 
             if (wav.bitsPerSample != 16) {
                 fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
-                return 7;
+                return 9;
             }
 
             const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
@@ -489,6 +536,18 @@ int main(int argc, char ** argv) {
                     pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
                 }
             }
+
+            if (params.diarize) {
+                // convert to stereo, float
+                pcmf32s.resize(2);
+
+                pcmf32s[0].resize(n);
+                pcmf32s[1].resize(n);
+                for (int i = 0; i < n; i++) {
+                    pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
+                    pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
+                }
+            }
         }
 
         // print system information
@@ -540,15 +599,17 @@ int main(int argc, char ** argv) {
 
             wparams.speed_up         = params.speed_up;
 
+            whisper_print_user_data user_data = { &params, &pcmf32s };
+
             // this callback is called on each new segment
             if (!wparams.print_realtime) {
                 wparams.new_segment_callback           = whisper_print_segment_callback;
-                wparams.new_segment_callback_user_data = &params;
+                wparams.new_segment_callback_user_data = &user_data;
             }
 
             if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
                 fprintf(stderr, "%s: failed to process audio\n", argv[0]);
-                return 8;
+                return 10;
             }
         }