]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
stream : few updates to make it compatible for Vim usage (#99)
authorGeorgi Gerganov <redacted>
Thu, 27 Oct 2022 19:10:50 +0000 (22:10 +0300)
committerGeorgi Gerganov <redacted>
Thu, 27 Oct 2022 19:10:50 +0000 (22:10 +0300)
examples/stream/stream.cpp

index 65798b7bfb3185ea537588d2f0ee8e65f814783b..276ad2432caa1fce08252f4b3c7e0004e3a10547 100644 (file)
@@ -17,6 +17,7 @@
 #include <string>
 #include <thread>
 #include <vector>
+#include <fstream>
 
 //  500 -> 00:05.000
 // 6000 -> 01:00.000
@@ -47,7 +48,7 @@ struct whisper_params {
 
     std::string language  = "en";
     std::string model     = "models/ggml-base.en.bin";
-    std::string fname_inp = "samples/jfk.wav";
+    std::string fname_out = "";
 };
 
 void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -84,7 +85,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         } else if (arg == "-m" || arg == "--model") {
             params.model = argv[++i];
         } else if (arg == "-f" || arg == "--file") {
-            params.fname_inp = argv[++i];
+            params.fname_out = argv[++i];
         } else if (arg == "-h" || arg == "--help") {
             whisper_print_usage(argc, argv, params);
             exit(0);
@@ -115,7 +116,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "  -nt,      --no_timestamps  do not print timestamps\n");
     fprintf(stderr, "  -l LANG,  --language LANG  spoken language (default: %s)\n", params.language.c_str());
     fprintf(stderr, "  -m FNAME, --model FNAME    model path (default: %s)\n", params.model.c_str());
-    fprintf(stderr, "  -f FNAME, --file FNAME     input WAV file path (default: %s)\n", params.fname_inp.c_str());
+    fprintf(stderr, "  -f FNAME, --file FNAME     text output file name (default: no output to file)\n");
     fprintf(stderr, "\n");
 }
 
@@ -143,9 +144,9 @@ bool audio_sdl_init(const int capture_id) {
 
         {
             int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
-            printf("%s: found %d capture devices:\n", __func__, nDevices);
+            fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
             for (int i = 0; i < nDevices; i++) {
-                printf("%s:    - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
+                fprintf(stderr, "%s:    - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
             }
         }
     }
@@ -163,21 +164,21 @@ bool audio_sdl_init(const int capture_id) {
         capture_spec_requested.samples  = 1024;
 
         if (capture_id >= 0) {
-            printf("%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
+            fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
             g_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
         } else {
-            printf("%s: attempt to open default capture device ...\n", __func__);
+            fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
             g_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
         }
         if (!g_dev_id_in) {
-            printf("%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
+            fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
             g_dev_id_in = 0;
         } else {
-            printf("%s: obtained spec for input device (SDL Id = %d):\n", __func__, g_dev_id_in);
-            printf("%s:     - sample rate:       %d\n", __func__, capture_spec_obtained.freq);
-            printf("%s:     - format:            %d (required: %d)\n", __func__, capture_spec_obtained.format, capture_spec_requested.format);
-            printf("%s:     - channels:          %d (required: %d)\n", __func__, capture_spec_obtained.channels, capture_spec_requested.channels);
-            printf("%s:     - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
+            fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, g_dev_id_in);
+            fprintf(stderr, "%s:     - sample rate:       %d\n", __func__, capture_spec_obtained.freq);
+            fprintf(stderr, "%s:     - format:            %d (required: %d)\n", __func__, capture_spec_obtained.format, capture_spec_requested.format);
+            fprintf(stderr, "%s:     - channels:          %d (required: %d)\n", __func__, capture_spec_obtained.channels, capture_spec_requested.channels);
+            fprintf(stderr, "%s:     - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
         }
     }
 
@@ -212,6 +213,7 @@ int main(int argc, char ** argv) {
     const int n_samples = (params.step_ms/1000.0)*WHISPER_SAMPLE_RATE;
     const int n_samples_len = (params.length_ms/1000.0)*WHISPER_SAMPLE_RATE;
     const int n_samples_30s = 30*WHISPER_SAMPLE_RATE;
+
     std::vector<float> pcmf32(n_samples_30s, 0.0f);
     std::vector<float> pcmf32_old;
 
@@ -219,15 +221,15 @@ int main(int argc, char ** argv) {
 
     // print some info about the processing
     {
-        printf("\n");
+        fprintf(stderr, "\n");
         if (!whisper_is_multilingual(ctx)) {
             if (params.language != "en" || params.translate) {
                 params.language = "en";
                 params.translate = false;
-                printf("%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
+                fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
             }
         }
-        printf("%s: processing %d samples (step = %.1f sec / len = %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
+        fprintf(stderr, "%s: processing %d samples (step = %.1f sec / len = %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
                 __func__,
                 n_samples,
                 float(n_samples)/WHISPER_SAMPLE_RATE,
@@ -237,8 +239,8 @@ int main(int argc, char ** argv) {
                 params.translate ? "translate" : "transcribe",
                 params.no_timestamps ? 0 : 1);
 
-        printf("%s: n_new_line = %d\n", __func__, n_new_line);
-        printf("\n");
+        fprintf(stderr, "%s: n_new_line = %d\n", __func__, n_new_line);
+        fprintf(stderr, "\n");
     }
 
     SDL_PauseAudioDevice(g_dev_id_in, 0);
@@ -246,6 +248,18 @@ int main(int argc, char ** argv) {
     int n_iter = 0;
     bool is_running = true;
 
+    std::ofstream fout;
+    if (params.fname_out.length() > 0) {
+        fout.open(params.fname_out);
+        if (!fout.is_open()) {
+            fprintf(stderr, "%s: failed to open output file '%s'!\n", __func__, params.fname_out.c_str());
+            return 1;
+        }
+    }
+
+    printf("[Start speaking]");
+    fflush(stdout);
+
     // main audio loop
     while (is_running) {
         // process SDL events:
@@ -253,13 +267,18 @@ int main(int argc, char ** argv) {
         while (SDL_PollEvent(&event)) {
             switch (event.type) {
                 case SDL_QUIT:
-                    is_running = false;
-                    break;
+                    {
+                        is_running = false;
+                    } break;
                 default:
                     break;
             }
         }
 
+        if (!is_running) {
+            break;
+        }
+
         // process new audio
         if (n_iter > 0 && SDL_GetQueuedAudioSize(g_dev_id_in) > 2*n_samples*sizeof(float)) {
             fprintf(stderr, "\n\n%s: WARNING: cannot process audio fast enough, dropping audio ...\n\n", __func__);
@@ -312,20 +331,37 @@ int main(int argc, char ** argv) {
             {
                 printf("\33[2K\r");
 
+                // print long empty line to clear the previous line
+                printf("%s", std::string(100, ' ').c_str());
+
+                printf("\33[2K\r");
+
                 const int n_segments = whisper_full_n_segments(ctx);
                 for (int i = 0; i < n_segments; ++i) {
                     const char * text = whisper_full_get_segment_text(ctx, i);
 
                     if (params.no_timestamps) {
-                        printf ("%s", text);
+                        printf("%s", text);
                         fflush(stdout);
+
+                        if (params.fname_out.length() > 0) {
+                            fout << text;
+                        }
                     } else {
                         const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
                         const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
                         printf ("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+
+                        if (params.fname_out.length() > 0) {
+                            fout << "[" << to_timestamp(t0) << " --> " << to_timestamp(t1) << "]  " << text << std::endl;
+                        }
                     }
                 }
+
+                if (params.fname_out.length() > 0) {
+                    fout << std::endl;
+                }
             }
 
             ++n_iter;