]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
stream : improve real-time transcription
authorGeorgi Gerganov <redacted>
Mon, 10 Oct 2022 19:06:18 +0000 (22:06 +0300)
committerGeorgi Gerganov <redacted>
Mon, 10 Oct 2022 19:06:27 +0000 (22:06 +0300)
stream.cpp

index f927819f8d499a42f14334a52db7597c0cb238a4..86b09a0e7d603430b54cf22b97a81fc33b9bdb09 100644 (file)
@@ -37,6 +37,7 @@ struct whisper_params {
     int32_t seed      = -1; // RNG seed, not used currently
     int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
     int32_t step_ms   = 3000;
+    int32_t length_ms = 10000;
 
     bool verbose              = false;
     bool translate            = false;
@@ -61,6 +62,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             params.n_threads = std::stoi(argv[++i]);
         } else if (arg == "--step") {
             params.step_ms = std::stoi(argv[++i]);
+        } else if (arg == "--length") {
+            params.length_ms = std::stoi(argv[++i]);
         } else if (arg == "-v" || arg == "--verbose") {
             params.verbose = true;
         } else if (arg == "--translate") {
@@ -104,9 +107,10 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "  -s SEED,  --seed SEED      RNG seed (default: -1)\n");
     fprintf(stderr, "  -t N,     --threads N      number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, "            --step N         audio step size in milliseconds (default: %d)\n", params.step_ms);
+    fprintf(stderr, "            --length N       audio length in milliseconds (default: %d)\n", params.length_ms);
     fprintf(stderr, "  -v,       --verbose        verbose output\n");
     fprintf(stderr, "            --translate      translate from source language to english\n");
-    fprintf(stderr, "  -nc,      --no-context     disable context from earlier audio (default: false)\n");
+    fprintf(stderr, "  -kc,      --keep-context   keep text context from earlier audio (default: false)\n");
     fprintf(stderr, "  -ps,      --print_special  print special tokens\n");
     fprintf(stderr, "  -nt,      --no_timestamps  do not print timestamps\n");
     fprintf(stderr, "  -l LANG,  --language LANG  spoken language (default: %s)\n", params.language.c_str());
@@ -206,6 +210,7 @@ int main(int argc, char ** argv) {
     struct whisper_context * ctx = whisper_init(params.model.c_str());
 
     const int n_samples = (params.step_ms/1000.0)*WHISPER_SAMPLE_RATE;
+    const int n_samples_len = (params.length_ms/1000.0)*WHISPER_SAMPLE_RATE;
     const int n_samples_30s = 30*WHISPER_SAMPLE_RATE;
     std::vector<float> pcmf32(n_samples_30s, 0.0f);
     std::vector<float> pcmf32_old;
@@ -220,8 +225,12 @@ int main(int argc, char ** argv) {
                 printf("%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
             }
         }
-        printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
-                __func__, n_samples, float(n_samples)/WHISPER_SAMPLE_RATE, params.n_threads,
+        printf("%s: processing %d samples (step = %.1f sec / len = %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
+                __func__,
+                n_samples,
+                float(n_samples)/WHISPER_SAMPLE_RATE,
+                float(n_samples_len)/WHISPER_SAMPLE_RATE,
+                params.n_threads,
                 params.language.c_str(),
                 params.translate ? "translate" : "transcribe",
                 params.no_timestamps ? 0 : 1);
@@ -230,6 +239,7 @@ int main(int argc, char ** argv) {
 
     SDL_PauseAudioDevice(g_dev_id_in, 0);
 
+    int n_iter = 0;
     bool is_running = true;
 
     // main audio loop
@@ -253,8 +263,10 @@ int main(int argc, char ** argv) {
         const int n_samples_new = SDL_GetQueuedAudioSize(g_dev_id_in)/sizeof(float);
 
         // take one second from previous iteration
-        // TODO: better strategy
-        const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_30s/30 - n_samples_new));
+        //const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_30s/30 - n_samples_new));
+
+        // take up to params.length_ms audio from previous iteration
+        const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_len - n_samples_new));
 
         //printf("processing: take = %d, new = %d, old = %d\n", n_samples_take, n_samples_new, (int) pcmf32_old.size());
 
@@ -288,7 +300,9 @@ int main(int argc, char ** argv) {
 
             // print result;
             {
-                printf("\n");
+                if ((n_iter % (params.length_ms / params.step_ms - 1)) != 0) {
+                    printf("\33[2K\r");
+                }
 
                 const int n_segments = whisper_full_n_segments(ctx);
                 for (int i = 0; i < n_segments; ++i) {
@@ -305,6 +319,13 @@ int main(int argc, char ** argv) {
                     }
                 }
             }
+
+            ++n_iter;
+            if ((n_iter % (params.length_ms / params.step_ms - 1)) == 0) {
+                printf("\n");
+
+                pcmf32_old.clear();
+            }
         }
     }