]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ref #22 : add option to provide multiple input .wav files
authorGeorgi Gerganov <redacted>
Wed, 5 Oct 2022 20:44:10 +0000 (23:44 +0300)
committerGeorgi Gerganov <redacted>
Wed, 5 Oct 2022 20:44:10 +0000 (23:44 +0300)
README.md
main.cpp

index 9d5685d85889f17a8790b823b4a0a5e7bee53e80..a73dfb8a01907a75fb53a2bf232ff3ee913456aa 100644 (file)
--- a/README.md
+++ b/README.md
@@ -31,13 +31,12 @@ For a quick demo, simply run `make base.en`:
 
 ```java
 $ make base.en
-
-gcc -pthread -O3 -mavx -mavx2 -mfma -mf16c -c ggml.c
-g++ -pthread -O3 -std=c++11 -c main.cpp
-g++ -pthread -o main ggml.o main.o
+cc  -O3 -std=c11 -Wall -Wextra -Wno-unused-parameter -Wno-unused-function -pthread   -c ggml.c
+c++ -O3 -std=c++11 -Wall -Wextra -Wno-unused-parameter -Wno-unused-function -pthread -c whisper.cpp
+c++ -O3 -std=c++11 -Wall -Wextra -Wno-unused-parameter -Wno-unused-function -pthread main.cpp whisper.o ggml.o -o main
 ./main -h
 
-usage: ./main [options]
+usage: ./main [options] file0.wav file1.wav ...
 
 options:
   -h,       --help           show this help message and exit
@@ -49,11 +48,11 @@ options:
   -nt,      --no_timestamps  do not print timestamps
   -l LANG,  --language LANG  spoken language (default: en)
   -m FNAME, --model FNAME    model path (default: models/ggml-base.en.bin)
-  -f FNAME, --file FNAME     input WAV file path (default: samples/jfk.wav)
+  -f FNAME, --file FNAME     input WAV file path
 
 bash ./download-ggml-model.sh base.en
 Downloading ggml model base.en ...
-models/ggml-base.en.bin          100%[=====================================>] 141.11M  8.58MB/s    in 22s
+models/ggml-base.en.bin            100%[===================================>] 141.11M  6.49MB/s    in 23s
 Done! Model 'base.en' saved in 'models/ggml-base.en.bin'
 You can now use it like this:
 
@@ -86,20 +85,18 @@ whisper_model_load: adding 1607 extra tokens
 whisper_model_load: ggml ctx size = 163.43 MB
 whisper_model_load: memory size =    22.83 MB
 whisper_model_load: model size  =   140.54 MB
-log_mel_spectrogram: n_sample = 176000, n_len = 1100
-log_mel_spectrogram: recording length: 11.000000 s
 
-main: processing 176000 samples (11.0 sec), 4 threads, lang = english, task = transcribe, timestamps = 1 ...
+main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, lang = en, task = transcribe, timestamps = 1 ...
 
-[00:00.000 --> 00:11.000]   And so my fellow Americans ask not what your country can do for you. Ask what you can do for your country.
+[00:00.000 --> 00:11.000]   And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
 
 
-main:     load time =    82.05 ms
-main:      mel time =    44.15 ms
-main:   sample time =     1.98 ms
-main:   encode time =   674.77 ms / 112.46 ms per layer
-main:   decode time =    82.91 ms
-main:    total time =   886.29 ms
+whisper_print_timings:     load time =    77.48 ms
+whisper_print_timings:      mel time =    26.10 ms
+whisper_print_timings:   sample time =     2.19 ms
+whisper_print_timings:   encode time =   632.95 ms / 105.49 ms per layer
+whisper_print_timings:   decode time =    85.11 ms / 14.18 ms per layer
+whisper_print_timings:    total time =   824.14 ms
 ```
 
 The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
index 6d1c55dace01eb6966d959b804e0738df820ea55..d363ab7e071daa9e5bdf1ba719497f3dd22afde6 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -36,7 +36,8 @@ struct whisper_params {
 
     std::string language  = "en";
     std::string model     = "models/ggml-base.en.bin";
-    std::string fname_inp = "samples/jfk.wav";
+
+    std::vector<std::string> fname_inp = {};
 };
 
 void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -45,6 +46,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
     for (int i = 1; i < argc; i++) {
         std::string arg = argv[i];
 
+        if (arg[0] != '-') {
+            params.fname_inp.push_back(arg);
+            continue;
+        }
+
         if (arg == "-s" || arg == "--seed") {
             params.seed = std::stoi(argv[++i]);
         } else if (arg == "-t" || arg == "--threads") {
@@ -67,7 +73,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         } else if (arg == "-m" || arg == "--model") {
             params.model = argv[++i];
         } else if (arg == "-f" || arg == "--file") {
-            params.fname_inp = argv[++i];
+            params.fname_inp.push_back(argv[++i]);
         } else if (arg == "-h" || arg == "--help") {
             whisper_print_usage(argc, argv, params);
             exit(0);
@@ -83,7 +89,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
 
 void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
     fprintf(stderr, "\n");
-    fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
     fprintf(stderr, "  -h,       --help           show this help message and exit\n");
@@ -95,7 +101,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "  -nt,      --no_timestamps  do not print timestamps\n");
     fprintf(stderr, "  -l LANG,  --language LANG  spoken language (default: %s)\n", params.language.c_str());
     fprintf(stderr, "  -m FNAME, --model FNAME    model path (default: %s)\n", params.model.c_str());
-    fprintf(stderr, "  -f FNAME, --file FNAME     input WAV file path (default: %s)\n", params.fname_inp.c_str());
+    fprintf(stderr, "  -f FNAME, --file FNAME     input WAV file path\n");
     fprintf(stderr, "\n");
 }
 
@@ -110,106 +116,116 @@ int main(int argc, char ** argv) {
         params.seed = time(NULL);
     }
 
+    if (params.fname_inp.empty()) {
+        fprintf(stderr, "error: no input files specified\n");
+        whisper_print_usage(argc, argv, params);
+        return 1;
+    }
+
     // whisper init
 
     struct whisper_context * ctx = whisper_init(params.model.c_str());
 
-    // WAV input
-    std::vector<float> pcmf32;
-    {
-        drwav wav;
-        if (!drwav_init_file(&wav, params.fname_inp.c_str(), NULL)) {
-            fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], params.fname_inp.c_str());
-            whisper_print_usage(argc, argv, {});
-            return 2;
-        }
+    for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
+        const auto fname_inp = params.fname_inp[f];
+
+        // WAV input
+        std::vector<float> pcmf32;
+        {
+            drwav wav;
+            if (!drwav_init_file(&wav, fname_inp.c_str(), NULL)) {
+                fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], fname_inp.c_str());
+                whisper_print_usage(argc, argv, {});
+                return 2;
+            }
 
-        if (wav.channels != 1 && wav.channels != 2) {
-            fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], params.fname_inp.c_str());
-            return 3;
-        }
+            if (wav.channels != 1 && wav.channels != 2) {
+                fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
+                return 3;
+            }
 
-        if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
-            fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], params.fname_inp.c_str());
-            return 4;
-        }
+            if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
+                fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
+                return 4;
+            }
 
-        if (wav.bitsPerSample != 16) {
-            fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], params.fname_inp.c_str());
-            return 5;
-        }
+            if (wav.bitsPerSample != 16) {
+                fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
+                return 5;
+            }
 
-        int n = wav.totalPCMFrameCount;
+            int n = wav.totalPCMFrameCount;
 
-        std::vector<int16_t> pcm16;
-        pcm16.resize(n*wav.channels);
-        drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
-        drwav_uninit(&wav);
+            std::vector<int16_t> pcm16;
+            pcm16.resize(n*wav.channels);
+            drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
+            drwav_uninit(&wav);
 
-        // convert to mono, float
-        pcmf32.resize(n);
-        if (wav.channels == 1) {
-            for (int i = 0; i < n; i++) {
-                pcmf32[i] = float(pcm16[i])/32768.0f;
-            }
-        } else {
-            for (int i = 0; i < n; i++) {
-                pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
+            // convert to mono, float
+            pcmf32.resize(n);
+            if (wav.channels == 1) {
+                for (int i = 0; i < n; i++) {
+                    pcmf32[i] = float(pcm16[i])/32768.0f;
+                }
+            } else {
+                for (int i = 0; i < n; i++) {
+                    pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
+                }
             }
         }
-    }
 
-    // print some info about the processing
-    {
-        printf("\n");
-        if (!whisper_is_multilingual(ctx)) {
-            if (params.language != "en" || params.translate) {
-                params.language = "en";
-                params.translate = false;
-                printf("%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
+        // print some info about the processing
+        {
+            printf("\n");
+            if (!whisper_is_multilingual(ctx)) {
+                if (params.language != "en" || params.translate) {
+                    params.language = "en";
+                    params.translate = false;
+                    printf("%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
+                }
             }
+            printf("%s: processing '%s' (%d samples, %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
+                    __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads,
+                    params.language.c_str(),
+                    params.translate ? "translate" : "transcribe",
+                    params.no_timestamps ? 0 : 1);
+            printf("\n");
         }
-        printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
-                __func__, int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads,
-                params.language.c_str(),
-                params.translate ? "translate" : "transcribe",
-                params.no_timestamps ? 0 : 1);
-        printf("\n");
-    }
 
-    // run the inference
-    {
-        whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY);
-
-        wparams.print_realtime       = true;
-        wparams.print_progress       = false;
-        wparams.print_timestamps     = !params.no_timestamps;
-        wparams.print_special_tokens = params.print_special_tokens;
-        wparams.translate            = params.translate;
-        wparams.language             = params.language.c_str();
-        wparams.n_threads            = params.n_threads;
-
-        if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
-            fprintf(stderr, "%s: failed to process audio\n", argv[0]);
-            return 6;
-        }
+        // run the inference
+        {
+            whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY);
+
+            wparams.print_realtime       = true;
+            wparams.print_progress       = false;
+            wparams.print_timestamps     = !params.no_timestamps;
+            wparams.print_special_tokens = params.print_special_tokens;
+            wparams.translate            = params.translate;
+            wparams.language             = params.language.c_str();
+            wparams.n_threads            = params.n_threads;
+
+            if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
+                fprintf(stderr, "%s: failed to process audio\n", argv[0]);
+                return 6;
+            }
 
-        // print result;
-        if (!wparams.print_realtime) {
-            printf("\n");
+            // print result;
+            if (!wparams.print_realtime) {
+                printf("\n");
 
-            const int n_segments = whisper_full_n_segments(ctx);
-            for (int i = 0; i < n_segments; ++i) {
-                const char * text = whisper_full_get_segment_text(ctx, i);
+                const int n_segments = whisper_full_n_segments(ctx);
+                for (int i = 0; i < n_segments; ++i) {
+                    const char * text = whisper_full_get_segment_text(ctx, i);
 
-                if (params.no_timestamps) {
-                    printf ("%s", text);
-                    fflush(stdout);
-                } else {
-                    const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
-                    const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+                    if (params.no_timestamps) {
+                        printf ("%s", text);
+                        fflush(stdout);
+                    } else {
+                        const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+                        const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
-                    printf ("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+                        printf ("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+                    }
                 }
             }
         }