]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
examples : add tinydiarization support for streaming (#1137)
authorDuncan McConnell <redacted>
Thu, 3 Aug 2023 08:24:07 +0000 (03:24 -0500)
committerGitHub <redacted>
Thu, 3 Aug 2023 08:24:07 +0000 (11:24 +0300)
examples/stream/stream.cpp

index cec5a2b7ce5974a5d0c18b0bf2d935d23ef736e3..4c7f7d1af4781d199b433424fec2ce4213ec6ffc 100644 (file)
@@ -47,6 +47,7 @@ struct whisper_params {
     bool print_special = false;
     bool no_context    = true;
     bool no_timestamps = false;
+    bool tinydiarize   = false;
 
     std::string language  = "en";
     std::string model     = "models/ggml-base.en.bin";
@@ -80,6 +81,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-l"   || arg == "--language")      { params.language      = argv[++i]; }
         else if (arg == "-m"   || arg == "--model")         { params.model         = argv[++i]; }
         else if (arg == "-f"   || arg == "--file")          { params.fname_out     = argv[++i]; }
+        else if (arg == "-tdrz" || arg == "--tinydiarize")  { params.tinydiarize   = true; }
+
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             whisper_print_usage(argc, argv, params);
@@ -113,6 +116,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -l LANG,  --language LANG [%-7s] spoken language\n",                                params.language.c_str());
     fprintf(stderr, "  -m FNAME, --model FNAME   [%-7s] model path\n",                                     params.model.c_str());
     fprintf(stderr, "  -f FNAME, --file FNAME    [%-7s] text output file name\n",                          params.fname_out.c_str());
+    fprintf(stderr, "  -tdrz,     --tinydiarize  [%-7s] enable tinydiarize (requires a tdrz model)\n",     params.tinydiarize ? "true" : "false");
     fprintf(stderr, "\n");
 }
 
@@ -299,6 +303,8 @@ int main(int argc, char ** argv) {
             wparams.audio_ctx        = params.audio_ctx;
             wparams.speed_up         = params.speed_up;
 
+            wparams.tdrz_enable      = params.tinydiarize; // [TDRZ]
+
             // disable temperature fallback
             //wparams.temperature_inc  = -1.0f;
             wparams.temperature_inc  = params.no_fallback ? 0.0f : wparams.temperature_inc;
@@ -344,10 +350,19 @@ int main(int argc, char ** argv) {
                         const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
                         const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
-                        printf ("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+                        std::string output = "[" + to_timestamp(t0) + " --> " + to_timestamp(t1) + "]  " + text;
+
+                        if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
+                            output += " [SPEAKER_TURN]";
+                        }
+
+                        output += "\n";
+
+                        printf("%s", output.c_str());
+                        fflush(stdout);
 
                         if (params.fname_out.length() > 0) {
-                            fout << "[" << to_timestamp(t0) << " --> " << to_timestamp(t1) << "]  " << text << std::endl;
+                            fout << output;
                         }
                     }
                 }