]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Reduce memory usage even more + better sampling
authorGeorgi Gerganov <redacted>
Fri, 30 Sep 2022 16:33:09 +0000 (19:33 +0300)
committerGeorgi Gerganov <redacted>
Fri, 30 Sep 2022 16:35:27 +0000 (19:35 +0300)
- The encode/decode memory buffers are now reused
- If the 30-sec segment goes for too long without a timestamp token, we
  force one. Improves transcription for large model
- Stereo support
- Add "micro-machines.wav" sample

Makefile
README.md
main.cpp

index 363c263b0ceb8bde90fe6cbcfee61150e14241e2..35cfe781c55dc64e35d5a124e2b6126ff7a45c55 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -20,10 +20,13 @@ samples:
        @wget --quiet --show-progress -O samples/gb0.ogg https://upload.wikimedia.org/wikipedia/commons/2/22/George_W._Bush%27s_weekly_radio_address_%28November_1%2C_2008%29.oga
        @wget --quiet --show-progress -O samples/gb1.ogg https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg
        @wget --quiet --show-progress -O samples/hp0.ogg https://upload.wikimedia.org/wikipedia/en/d/d4/En.henryfphillips.ogg
+       @wget --quiet --show-progress -O samples/mm1.wav https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav
        @echo "Converting to 16-bit WAV ..."
        @ffmpeg -loglevel -0 -y -i samples/gb0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb0.wav
        @ffmpeg -loglevel -0 -y -i samples/gb1.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb1.wav
        @ffmpeg -loglevel -0 -y -i samples/hp0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/hp0.wav
+       @ffmpeg -loglevel -0 -y -i samples/mm1.wav -ar 16000 -ac 1 -c:a pcm_s16le samples/mm0.wav
+       @rm samples/mm1.wav
 
 
 # if not already downloaded, the following targets download the specified model and
index 5b06c1268342b43a3aa8876ea9054f61537adc08..068636a21a54e896a83b7c414d2eeacd425b9582 100644 (file)
--- a/README.md
+++ b/README.md
@@ -1,12 +1,13 @@
 # whisper.cpp
 
-C/C++ port of [OpenAI's Whisper](https://github.com/openai/whisper) speech-to-text model
+High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
 
 - Plain C/C++ implementation without dependencies
 - ARM_NEON and AVX intrinsics support
 - Mixed F16 / F32 support
 - Low memory usage (Flash Attention + Flash Forward)
 - Zero memory allocations at runtime
+- Runs on the CPU (Mac and Linux support)
 
 ## Usage
 
@@ -50,7 +51,12 @@ options:
 
 bash ./download-ggml-model.sh base.en
 Downloading ggml model base.en ...
-Model base.en already exists. Skipping download.
+models/ggml-base.en.bin          100%[=====================================>] 141.11M  8.58MB/s    in 22s
+Done! Model 'base.en' saved in 'models/ggml-base.en.bin'
+You can now use it like this:
+
+  $ ./main -m models/ggml-base.en.bin -f samples/jfk.wav
+
 
 ===============================================
 Running base.en on all samples in ./samples ...
@@ -73,7 +79,7 @@ whisper_model_load: n_text_layer  = 6
 whisper_model_load: n_mels        = 80
 whisper_model_load: f16           = 1
 whisper_model_load: type          = 2
-whisper_model_load: mem_required  = 611.00 MB
+whisper_model_load: mem_required  = 377.00 MB
 whisper_model_load: adding 1607 extra tokens
 whisper_model_load: ggml ctx size = 163.43 MB
 whisper_model_load: memory size =    22.83 MB
@@ -86,12 +92,12 @@ main: processing 176000 samples (11.0 sec), 4 threads, lang = english, task = tr
 [00:00.000 --> 00:11.000]   And so my fellow Americans ask not what your country can do for you. Ask what you can do for your country.
 
 
-main:     load time =    61.78 ms
-main:      mel time =    41.74 ms
-main:   sample time =     2.10 ms
-main:   encode time =   718.60 ms / 119.77 ms per layer
-main:   decode time =    83.55 ms
-main:    total time =   908.15 ms
+main:     load time =    82.05 ms
+main:      mel time =    44.15 ms
+main:   sample time =     1.98 ms
+main:   encode time =   674.77 ms / 112.46 ms per layer
+main:   decode time =    82.91 ms
+main:    total time =   886.29 ms
 ```
 
 The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
@@ -131,10 +137,12 @@ make large
 
 ## Another example
 
-Here is another example of transcribing a [3:24 min speech](https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg) in less than a minute, using `medium.en` model:
+Here is another example of transcribing a [3:24 min speech](https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg)
+in less than a minute on a MacBook M1 Pro, using `medium.en` model:
 
 ```java
 $ ./main -m models/ggml-medium.en.bin -f samples/gb1.wav -t 8
+
 whisper_model_load: loading model from 'models/ggml-medium.en.bin'
 whisper_model_load: n_vocab       = 51864
 whisper_model_load: n_audio_ctx   = 1500
@@ -148,7 +156,7 @@ whisper_model_load: n_text_layer  = 24
 whisper_model_load: n_mels        = 80
 whisper_model_load: f16           = 1
 whisper_model_load: type          = 4
-whisper_model_load: mem_required  = 2786.00 MB
+whisper_model_load: mem_required  = 2502.00 MB
 whisper_model_load: adding 1607 extra tokens
 whisper_model_load: ggml ctx size = 1644.97 MB
 whisper_model_load: memory size =   182.62 MB
@@ -187,30 +195,30 @@ main: processing 3179750 samples (198.7 sec), 8 threads, lang = english, task =
 [03:14.000 --> 03:24.000]   [Music]
 
 
-main:     load time =   438.55 ms
-main:      mel time =   440.22 ms
-main:   sample time =    32.23 ms
-main:   encode time = 42329.63 ms / 1763.73 ms per layer
-main:   decode time = 15190.00 ms
-main:    total time = 58444.63 ms
+main:     load time =   522.18 ms
+main:      mel time =   423.43 ms
+main:   sample time =    31.42 ms
+main:   encode time = 41518.51 ms / 1729.94 ms per layer
+main:   decode time = 14907.22 ms
+main:    total time = 57416.63 ms
 ```
 
 ## Limitations
 
 - Very basic greedy sampling scheme - always pick up the top token
+- Only 16-bit WAV at 16 kHz is supported
 - Inference only
-- Runs on the CPU
-- Only mono-channel 16-bit WAV is supported
+- No GPU support
 
 ## Memory usage
 
-| Model | Disk | Mem |
-| ---   | --- | --- |
-| tiny | 75 MB | ~460 MB |
-| base | 142 MB | ~620 MB |
-| small | 466 MB | ~1.3 GB |
-| medium | 1.5 GB | ~2.8 GB |
-| large | 2.9 GB | ~4.9 GB |
+| Model  | Disk   | Mem     |
+| ---    | ---    | ---     |
+| tiny   |  75 MB | ~240 MB |
+| base   | 142 MB | ~380 MB |
+| small  | 466 MB | ~970 MB |
+| medium | 1.5 GB | ~2.5 GB |
+| large  | 2.9 GB | ~4.6 GB |
 
 ## ggml format
 
index ac205315638e8bb20618aecf5d59066514f50131..4822aa92965c8cebf28c1b1fc16ffcbcb42db997 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -158,11 +158,11 @@ const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
 };
 
 const std::map<e_model, size_t> MEM_REQ_DECODE = {
-    { MODEL_TINY,    190ull*MB },
-    { MODEL_BASE,    190ull*MB },
-    { MODEL_SMALL,   190ull*MB },
-    { MODEL_MEDIUM,  200ull*MB },
-    { MODEL_LARGE,   200ull*MB },
+    { MODEL_TINY,     94ull*MB },
+    { MODEL_BASE,     96ull*MB },
+    { MODEL_SMALL,    98ull*MB },
+    { MODEL_MEDIUM,  100ull*MB },
+    { MODEL_LARGE,   102ull*MB },
 };
 
 const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
@@ -173,6 +173,11 @@ const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
     { MODEL_LARGE,   110ull*MB },
 };
 
+// the memory buffers used to store the model in memory and perform the inference computations
+std::vector<uint8_t> g_buf_model;
+std::vector<uint8_t> g_buf_compute;
+std::vector<uint8_t> g_buf_compute_layer;
+
 const int SAMPLE_RATE = 16000;
 const int N_FFT       = 400;
 const int N_MEL       = 80;
@@ -542,13 +547,15 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe
         printf("%s: f16           = %d\n", __func__, hparams.f16);
         printf("%s: type          = %d\n", __func__, model.type);
 
+        g_buf_model.resize(MEM_REQ_MODEL.at(model.type));
+        g_buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
+        g_buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
+
         // this is the total memory required to run the inference
         const size_t mem_required =
-                   MEM_REQ_MODEL.at(model.type) +
-                  MEM_REQ_ENCODE.at(model.type) +
-            MEM_REQ_ENCODE_LAYER.at(model.type) +
-                  MEM_REQ_DECODE.at(model.type) +
-            MEM_REQ_DECODE_LAYER.at(model.type);
+                   g_buf_model.size() +
+                   g_buf_compute.size() +
+                   g_buf_compute_layer.size();
 
         printf("%s: mem_required  = %.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
     }
@@ -752,8 +759,8 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe
     // create the ggml context
     {
         struct ggml_init_params params = {
-            .mem_size   = ctx_size,
-            .mem_buffer = NULL,
+            .mem_size   = g_buf_model.size(),
+            .mem_buffer = g_buf_model.data(),
         };
 
         model.ctx = ggml_init(params);
@@ -1089,17 +1096,10 @@ bool whisper_encode(
     const int n_mels = hparams.n_mels;
     assert(mel_inp.n_mel == n_mels);
 
-    struct ggml_init_params params;
-
-    {
-        static size_t buf_size = MEM_REQ_ENCODE.at(model.type);
-        static void * buf = malloc(buf_size);
-
-        params = {
-            .mem_size   = buf_size,
-            .mem_buffer = buf,
-        };
-    }
+    struct ggml_init_params params = {
+        .mem_size   = g_buf_compute.size(),
+        .mem_buffer = g_buf_compute.data(),
+    };
 
     struct ggml_context * ctx0 = ggml_init(params);
 
@@ -1151,16 +1151,10 @@ bool whisper_encode(
 
         // create separate context for each layer to reduce memory usage
 
-        struct ggml_init_params paramsL;
-        {
-            static size_t buf_size = MEM_REQ_ENCODE_LAYER.at(model.type);
-            static void * buf = malloc(buf_size);
-
-            paramsL = {
-                .mem_size   = buf_size,
-                .mem_buffer = buf,
-            };
-        }
+        struct ggml_init_params paramsL = {
+            .mem_size   = g_buf_compute_layer.size(),
+            .mem_buffer = g_buf_compute_layer.data(),
+        };
 
         struct ggml_context * ctxL = ggml_init(paramsL);
 
@@ -1492,17 +1486,10 @@ bool whisper_decode(
     const int N = prompt.size();
     const int M = hparams.n_audio_ctx;
 
-    struct ggml_init_params params;
-
-    {
-        static size_t buf_size = MEM_REQ_DECODE.at(model.type);
-        static void * buf = malloc(buf_size);
-
-        params = {
-            .mem_size   = buf_size,
-            .mem_buffer = buf,
+    struct ggml_init_params params = {
+            .mem_size   = g_buf_compute.size(),
+            .mem_buffer = g_buf_compute.data(),
         };
-    }
 
     struct ggml_context * ctx0 = ggml_init(params);
 
@@ -1525,17 +1512,10 @@ bool whisper_decode(
     for (int il = 0; il < n_layer; ++il) {
         const auto & layer = model.layers_decoder[il];
 
-        struct ggml_init_params paramsL;
-
-        {
-            static size_t buf_size = MEM_REQ_DECODE_LAYER.at(model.type);
-            static void * buf = malloc(buf_size);
-
-            paramsL = {
-                .mem_size   = buf_size,
-                .mem_buffer = buf,
-            };
-        }
+        struct ggml_init_params paramsL = {
+            .mem_size   = g_buf_compute_layer.size(),
+            .mem_buffer = g_buf_compute_layer.data(),
+        };
 
         struct ggml_context * ctxL = ggml_init(paramsL);
         struct ggml_cgraph gf = { .n_threads = n_threads };
@@ -1849,7 +1829,7 @@ bool whisper_decode(
 // TODO: temperature
 whisper_vocab::id whisper_sample_best(
         const whisper_vocab & vocab,
-        const float * probs) {
+        const float * probs, bool need_timestamp) {
     int n_logits = vocab.id_to_token.size();
 
     std::vector<std::pair<double, whisper_vocab::id>> probs_id;
@@ -1859,7 +1839,7 @@ whisper_vocab::id whisper_sample_best(
         probs_id.push_back(std::make_pair(probs[i], i));
     }
 
-    const int top_k = 10;
+    const int top_k = 4;
 
     // find the top K tokens
     std::partial_sort(
@@ -1876,6 +1856,15 @@ whisper_vocab::id whisper_sample_best(
     //    printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
     //}
 
+    if (need_timestamp) {
+        // at the end of the 30-second audio segment, we start giving preference to time tokens
+        for (int i = 0; i < top_k; i++) {
+            if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > probs_id[0].first*0.1) {
+                return probs_id[i].second;
+            }
+        }
+    }
+
     int res = 0;
     while ((probs_id[res].second == vocab.token_sot ||
             probs_id[res].second == vocab.token_solm ||
@@ -2136,8 +2125,8 @@ int main(int argc, char ** argv) {
             return 2;
         }
 
-        if (wav.channels != 1) {
-            fprintf(stderr, "%s: WAV file '%s' must be mono\n", argv[0], params.fname_inp.c_str());
+        if (wav.channels != 1 && wav.channels != 2) {
+            fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], params.fname_inp.c_str());
             return 3;
         }
 
@@ -2158,8 +2147,14 @@ int main(int argc, char ** argv) {
 
         // convert to float
         pcmf32.resize(pcm16.size());
-        for (size_t i = 0; i < pcm16.size(); i++) {
-            pcmf32[i] = float(pcm16[i])/32768.0f;
+        if (wav.channels == 1) {
+            for (size_t i = 0; i < pcm16.size(); i++) {
+                pcmf32[i] = float(pcm16[i])/32768.0f;
+            }
+        } else {
+            for (size_t i = 0; i < pcm16.size(); i++) {
+                pcmf32[i] = float(pcm16[i*2 + 0] + pcm16[i*2 + 1])/32768.0f/2.0f;
+            }
         }
     }
 
@@ -2252,6 +2247,7 @@ int main(int argc, char ** argv) {
         int seek_delta = 100*CHUNK_SIZE;
         whisper_vocab::id last_id = 0;
 
+        // print the prompt
         //printf("\n\n");
         //for (int i = 0; i < prompt.size(); i++) {
         //    printf("%s: prompt[%d] = %s\n", __func__, i, vocab.id_to_token[prompt[i]].c_str());
@@ -2294,7 +2290,7 @@ int main(int argc, char ** argv) {
                 {
                     const int64_t t_start_sample_us = ggml_time_us();
 
-                    id = whisper_sample_best(vocab, probs.data() + (probs.size() - n_vocab));
+                    id = whisper_sample_best(vocab, probs.data() + (probs.size() - n_vocab), result_len == 0);
                     if (i > 0) {
                         tid = whisper_sample_timestamp(vocab, probs.data() + (probs.size() - n_vocab));
                     }
@@ -2313,6 +2309,8 @@ int main(int argc, char ** argv) {
                 prompt.push_back(id);
                 result_cur.push_back({ id, seek + 2*(tid - vocab.token_beg) });
 
+                //printf("%s: %s\n", __func__, vocab.id_to_token[id].c_str());
+
                 // end of text token
                 if (id == vocab.token_eot) {
                     break;