]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ref #17 : add options to output result to file
authorGeorgi Gerganov <redacted>
Sat, 8 Oct 2022 14:22:22 +0000 (17:22 +0300)
committerGeorgi Gerganov <redacted>
Sat, 8 Oct 2022 14:22:22 +0000 (17:22 +0300)
Support for:

- plain text
- VTT
- SRT

main.cpp
whisper.cpp

index 9769b7ff61d78519bab00155edf8d1f55d9ba245..728ab6faa2f83dc5c1e82ed35c03b5bba42e2376 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -5,6 +5,7 @@
 #define DR_WAV_IMPLEMENTATION
 #include "dr_wav.h"
 
+#include <fstream>
 #include <cstdio>
 #include <string>
 #include <thread>
@@ -32,6 +33,9 @@ struct whisper_params {
 
     bool verbose              = false;
     bool translate            = false;
+    bool output_txt           = false;
+    bool output_vtt           = false;
+    bool output_srt           = false;
     bool print_special_tokens = false;
     bool no_timestamps        = false;
 
@@ -69,6 +73,12 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
                 whisper_print_usage(argc, argv, params);
                 exit(0);
             }
+        } else if (arg == "-otxt" || arg == "--output-txt") {
+            params.output_txt = true;
+        } else if (arg == "-ovtt" || arg == "--output-vtt") {
+            params.output_vtt = true;
+        } else if (arg == "-osrt" || arg == "--output-srt") {
+            params.output_srt = true;
         } else if (arg == "-ps" || arg == "--print_special") {
             params.print_special_tokens = true;
         } else if (arg == "-nt" || arg == "--no_timestamps") {
@@ -101,6 +111,8 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "  -o N,     --offset N       offset in milliseconds (default: %d)\n", params.offset_ms);
     fprintf(stderr, "  -v,       --verbose        verbose output\n");
     fprintf(stderr, "            --translate      translate from source language to english\n");
+    fprintf(stderr, "  -otxt,    --output-txt     output result in a text file\n");
+    fprintf(stderr, "  -ovtt,    --output-vtt     output result in a vtt file\n");
     fprintf(stderr, "  -ps,      --print_special  print special tokens\n");
     fprintf(stderr, "  -nt,      --no_timestamps  do not print timestamps\n");
     fprintf(stderr, "  -l LANG,  --language LANG  spoken language (default: %s)\n", params.language.c_str());
@@ -123,7 +135,7 @@ int main(int argc, char ** argv) {
     if (params.fname_inp.empty()) {
         fprintf(stderr, "error: no input files specified\n");
         whisper_print_usage(argc, argv, params);
-        return 1;
+        return 2;
     }
 
     // whisper init
@@ -140,22 +152,22 @@ int main(int argc, char ** argv) {
             if (!drwav_init_file(&wav, fname_inp.c_str(), NULL)) {
                 fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], fname_inp.c_str());
                 whisper_print_usage(argc, argv, {});
-                return 2;
+                return 3;
             }
 
             if (wav.channels != 1 && wav.channels != 2) {
                 fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
-                return 3;
+                return 4;
             }
 
             if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
                 fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
-                return 4;
+                return 5;
             }
 
             if (wav.bitsPerSample != 16) {
                 fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
-                return 5;
+                return 6;
             }
 
             int n = wav.totalPCMFrameCount;
@@ -193,9 +205,11 @@ int main(int argc, char ** argv) {
                     params.language.c_str(),
                     params.translate ? "translate" : "transcribe",
                     params.no_timestamps ? 0 : 1);
+
             printf("\n");
         }
 
+
         // run the inference
         {
             whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY);
@@ -211,10 +225,10 @@ int main(int argc, char ** argv) {
 
             if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
                 fprintf(stderr, "%s: failed to process audio\n", argv[0]);
-                return 6;
+                return 7;
             }
 
-            // print result;
+            // print result
             if (!wparams.print_realtime) {
                 printf("\n");
 
@@ -233,6 +247,76 @@ int main(int argc, char ** argv) {
                     }
                 }
             }
+
+            printf("\n");
+
+            // output to text file
+            if (params.output_txt) {
+
+                const auto fname_txt = fname_inp + ".txt";
+                std::ofstream fout_txt(fname_txt);
+                if (!fout_txt.is_open()) {
+                    fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_txt.c_str());
+                    return 8;
+                }
+
+                printf("%s: saving output to '%s.txt'\n", __func__, fname_inp.c_str());
+
+                const int n_segments = whisper_full_n_segments(ctx);
+                for (int i = 0; i < n_segments; ++i) {
+                    const char * text = whisper_full_get_segment_text(ctx, i);
+                    fout_txt << text;
+                }
+            }
+
+            // output to VTT file
+            if (params.output_vtt) {
+
+                const auto fname_vtt = fname_inp + ".vtt";
+                std::ofstream fout_vtt(fname_vtt);
+                if (!fout_vtt.is_open()) {
+                    fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_vtt.c_str());
+                    return 9;
+                }
+
+                printf("%s: saving output to '%s.vtt'\n", __func__, fname_inp.c_str());
+
+                fout_vtt << "WEBVTT\n\n";
+
+                const int n_segments = whisper_full_n_segments(ctx);
+                for (int i = 0; i < n_segments; ++i) {
+                    const char * text = whisper_full_get_segment_text(ctx, i);
+                    const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+                    const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+                    fout_vtt << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
+                    fout_vtt << text << "\n\n";
+                }
+            }
+
+            // output to SRT file
+            if (params.output_srt) {
+
+                const auto fname_srt = fname_inp + ".srt";
+                std::ofstream fout_srt(fname_srt);
+                if (!fout_srt.is_open()) {
+                    fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_srt.c_str());
+                    return 10;
+                }
+
+                printf("%s: saving output to '%s.srt'\n", __func__, fname_inp.c_str());
+
+                const int n_segments = whisper_full_n_segments(ctx);
+                for (int i = 0; i < n_segments; ++i) {
+                    const char * text = whisper_full_get_segment_text(ctx, i);
+                    const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+                    const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+                    fout_srt << i + 1 << "\n";
+                    fout_srt << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
+                    fout_srt << text << "\n\n";
+                }
+            }
         }
     }
 
index af89815180910cb96dce8d351a9a7713694e4f3f..b59cfd7f04b95a25bd7037977e110b474bf95339 100644 (file)
@@ -2242,7 +2242,7 @@ whisper_token whisper_token_transcribe() {
 void whisper_print_timings(struct whisper_context * ctx) {
     const int64_t t_end_us = ggml_time_us();
 
-    printf("\n\n");
+    printf("\n");
     printf("%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f);
     printf("%s:      mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f);
     printf("%s:   sample time = %8.2f ms\n", __func__, ctx->t_sample_us/1000.0f);