]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
stream : add "audio_ctx" parameter
authorGeorgi Gerganov <redacted>
Sun, 20 Nov 2022 19:12:01 +0000 (21:12 +0200)
committerGeorgi Gerganov <redacted>
Sun, 20 Nov 2022 19:22:41 +0000 (21:22 +0200)
Used to overwrite the audio context size of the Encoder.
For example, setting "audio_ctx = 512" will make it run about 3 times
faster, processing about 10s of audio, instead of 30s.

The transcription quality drops, but this can be used for real-time
streaming purposes where performance is important.

examples/stream/stream.cpp
whisper.cpp
whisper.h

index 040ba9ebf4b67a3c0c87cd71efe2e505bbac6d35..4ff93d397f1e5fd493ab39f212b1a9161faec7c1 100644 (file)
@@ -40,6 +40,7 @@ struct whisper_params {
     int32_t step_ms    = 3000;
     int32_t length_ms  = 10000;
     int32_t capture_id = -1;
+    int32_t audio_ctx  = 0;
 
     bool speed_up             = false;
     bool verbose              = false;
@@ -69,6 +70,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             params.length_ms = std::stoi(argv[++i]);
         } else if (arg == "-c" || arg == "--capture") {
             params.capture_id = std::stoi(argv[++i]);
+        } else if (arg == "-ac" || arg == "--audio_ctx") {
+            params.audio_ctx = std::stoi(argv[++i]);
         } else if (arg == "-su" || arg == "--speed-up") {
             params.speed_up = true;
         } else if (arg == "-v" || arg == "--verbose") {
@@ -116,6 +119,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "            --step N         audio step size in milliseconds (default: %d)\n", params.step_ms);
     fprintf(stderr, "            --length N       audio length in milliseconds (default: %d)\n", params.length_ms);
     fprintf(stderr, "  -c ID,    --capture ID     capture device ID (default: -1)\n");
+    fprintf(stderr, "  -ac N,    --audio_ctx N    audio context size (default: %d, 0 - all)\n", params.audio_ctx);
     fprintf(stderr, "  -su,      --speed-up       speed up audio by factor of 2 (faster processing, reduced accuracy, default: %s)\n", params.speed_up ? "true" : "false");
     fprintf(stderr, "  -v,       --verbose        verbose output\n");
     fprintf(stderr, "            --translate      translate from source language to english\n");
@@ -322,7 +326,6 @@ int main(int argc, char ** argv) {
         {
             whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 
-            wparams.max_tokens           = 32;
             wparams.print_progress       = false;
             wparams.print_special_tokens = params.print_special_tokens;
             wparams.print_realtime       = false;
@@ -330,9 +333,11 @@ int main(int argc, char ** argv) {
             wparams.translate            = params.translate;
             wparams.no_context           = params.no_context;
             wparams.single_segment       = true;
+            wparams.max_tokens           = 32;
             wparams.language             = params.language.c_str();
             wparams.n_threads            = params.n_threads;
 
+            wparams.audio_ctx            = params.audio_ctx;
             wparams.speed_up             = params.speed_up;
 
             if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
index 48f93ebd89e8e0f7efe130c0632181207568944d..d35b90f317fa0f39b9a5b85cc29fda60631bb535 100644 (file)
@@ -424,6 +424,9 @@ struct whisper_context {
     int64_t t_last;
     whisper_token tid_last;
     std::vector<float> energy; // PCM signal energy
+
+    // [EXPERIMENTAL] speed-up techniques
+    int32_t exp_n_audio_ctx; // 0 - use default
 };
 
 // load the model from a ggml file
@@ -974,9 +977,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
             model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
             model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-
-            //memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
-            //memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
         }
 
         const size_t memory_size =
@@ -1079,7 +1079,7 @@ static bool whisper_encode(
     const auto & mel_inp = wctx.mel;
     const auto & hparams = model.hparams;
 
-    const int n_ctx   = WHISPER_EXPERIMENT_AUDIO_CTX;
+    const int n_ctx   = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
     const int n_state = hparams.n_audio_state;
     const int n_head  = hparams.n_audio_head;
     const int n_layer = hparams.n_audio_layer;
@@ -1133,6 +1133,8 @@ static bool whisper_encode(
         cur = ggml_gelu(ctx0, cur);
     }
 
+    // ===================================================================
+    // NOTE: experimenting with partial evaluation of the encoder (ignore)
     //static int iter = -1;
     //const int n_iter = 1500/n_ctx;
 
@@ -1151,6 +1153,10 @@ static bool whisper_encode(
     struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
 
     cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
+    // ===================================================================
+
+    // original:
+    //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
 
     struct ggml_tensor * inpL = cur;
 
@@ -1494,8 +1500,7 @@ static bool whisper_decode(
     const int n_layer = hparams.n_text_layer;
 
     const int N = n_tokens;
-    //const int M = hparams.n_audio_ctx;
-    const int M = WHISPER_EXPERIMENT_AUDIO_CTX;
+    const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
 
     struct ggml_init_params params = {
             .mem_size   = wctx.buf_compute.size(),
@@ -2405,6 +2410,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                     /*.max_tokens           =*/ 0,
 
                     /*.speed_up             =*/ false,
+                    /*.audio_ctx            =*/ 0,
 
                     /*.language             =*/ "en",
 
@@ -2447,6 +2453,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                     /*.max_tokens           =*/ 0,
 
                     /*.speed_up             =*/ false,
+                    /*.audio_ctx            =*/ 0,
 
                     /*.language             =*/ "en",
 
@@ -2577,6 +2584,9 @@ int whisper_full(
         prompt_past.clear();
     }
 
+    // overwrite audio_ctx
+    ctx->exp_n_audio_ctx = params.audio_ctx;
+
     // these tokens determine the task that will be performed
     std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
     if (whisper_is_multilingual(ctx)) {
index 0211995dcb8d5922110eaea35f9cda1cbdc9072c..88cc71131442de8000d221b8a8ec4c1f07147001 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -24,8 +24,6 @@
 #define WHISPER_HOP_LENGTH  160
 #define WHISPER_CHUNK_SIZE  30
 
-#define WHISPER_EXPERIMENT_AUDIO_CTX 512
-
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -207,7 +205,8 @@ extern "C" {
         int   max_tokens;       // max tokens per segment (0 = no limit)
 
         // [EXPERIMENTAL] speed-up techniques
-        bool speed_up; // speed-up the audio by 2x using Phase Vocoder
+        bool speed_up;  // speed-up the audio by 2x using Phase Vocoder
+        int  audio_ctx; // overwrite the audio context size (0 = use default)
 
         const char * language;