]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Improve decoding (#291)
authorGeorgi Gerganov <redacted>
Sun, 15 Jan 2023 09:29:57 +0000 (11:29 +0200)
committerGitHub <redacted>
Sun, 15 Jan 2023 09:29:57 +0000 (11:29 +0200)
* whisper : prepare infra for new decoding strategies

* whisper : apply logit filters and compute logprobs

* whisper : add whisper_get_logits()

* whisper : separate self and cross attention memory

Initial step needed for supporting parallel decoders

* whisper : move probs_id buffer to whisper_context

* whisper : refactor kv cache into separate struct

* whisper : move self-attention kv cache to whisper_decoder

* whisper : wip decoding parameters + strategies

* whisper : wip decoding parameters + strategies (part 2)

* whisper : wip decoding parameters + strategies (part 3)

* whisper : wip decoding parameters + strategies (part 4)

* whisper : fix prompt_past update to not include prompt_init

* whisper : temperature + best_of support

* whisper : support for compression_ration_threshold

We actually use entropy, but it is similar

* command : fix example to use logits instead of obsolete probs

* whisper : handle empty sequence ranking

* whisper : add WHISPER_DEBUG + diagnostic prints + new main args

* whisper : minor fixes

* whisper : add beam-search support

* whisper : bug fix when there no previous context

* whisper : add comments

* stream : disable temperature fallback

For real-time processing, we always want a single decoder running at T=0

* whisper.swiftui : update example - fix paths + add empty folders

.gitignore
README.md
examples/command/command.cpp
examples/main/main.cpp
examples/stream.wasm/emscripten.cpp
examples/stream/stream.cpp
examples/whisper.swiftui/whisper.swiftui.demo/Resources/models/.gitignore [new file with mode: 0644]
examples/whisper.swiftui/whisper.swiftui.demo/Resources/samples/.gitignore [new file with mode: 0644]
examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj
whisper.cpp
whisper.h

index 8a495199e756be70f24aa204f9730962d1a96077..5ca3702c3310ed9622907cf5b31e8058ed9fb0f8 100644 (file)
@@ -8,6 +8,7 @@ build/
 build-em/
 build-debug/
 build-release/
+build-static/
 build-sanitize-addr/
 build-sanitize-thread/
 
@@ -18,6 +19,7 @@ build-sanitize-thread/
 /bench
 
 sync.sh
+libwhisper.a
 libwhisper.so
 compile_commands.json
 
index f22724a5054bad9f0fad986167a2bd9d44254d8c..448e7588059aef5dff9858715affe3458636a6a2 100644 (file)
--- a/README.md
+++ b/README.md
@@ -212,17 +212,7 @@ make large
 ## Limitations
 
 - Inference only
-- No GPU support
-- Very basic greedy sampling scheme - always pick up the token with highest probability.
-  This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274)
-  from the original python implementation, so in order to make a fair comparison between the 2 implementations, make sure
-  to run the python code with the following parameters:
-
-  ```
-  whisper --best_of None --beam_size None ...
-  ```
-
-  In the future, `whisper.cpp` will support more sampling strategies.
+- No GPU support (yet)
 
 ## Another example
 
index 3dae3a5e31c7b2b6f358ebe9475abdafe2c1f1b3..2bdaf87c45c305f30d2617b6d4d43e5bab8cc1c2 100644 (file)
@@ -671,56 +671,81 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
                 break;
             }
 
-            const auto * probs = whisper_get_probs(ctx);
-            std::vector<std::pair<float, int>> probs_id;
-
-            double psum = 0.0;
-            for (int i = 0; i < (int) allowed_commands.size(); ++i) {
-                probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
-                for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
-                    probs_id.back().first += probs[allowed_tokens[i][j]];
-                }
-                probs_id.back().first /= allowed_tokens[i].size();
-                psum += probs_id.back().first;
-            }
+            // estimate command probability
+            // NOTE: not optimal
+            {
+                const auto * logits = whisper_get_logits(ctx);
 
-            // normalize
-            for (auto & p : probs_id) {
-                p.first /= psum;
-            }
+                std::vector<float> probs(whisper_n_vocab(ctx), 0.0f);
 
-            // sort descending
-            {
-                using pair_type = decltype(probs_id)::value_type;
-                std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
-                    return a.first > b.first;
-                });
-            }
+                // compute probs from logits via softmax
+                {
+                    float max = -1e9;
+                    for (int i = 0; i < (int) probs.size(); ++i) {
+                        max = std::max(max, logits[i]);
+                    }
 
-            // print the commands and the respective probabilities
-            {
-                fprintf(stdout, "\n");
-                for (const auto & cmd : probs_id) {
-                    fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
-                    for (int token : allowed_tokens[cmd.second]) {
-                        fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
+                    float sum = 0.0f;
+                    for (int i = 0; i < (int) probs.size(); ++i) {
+                        probs[i] = expf(logits[i] - max);
+                        sum += probs[i];
+                    }
+
+                    for (int i = 0; i < (int) probs.size(); ++i) {
+                        probs[i] /= sum;
                     }
+                }
+
+                std::vector<std::pair<float, int>> probs_id;
+
+                double psum = 0.0;
+                for (int i = 0; i < (int) allowed_commands.size(); ++i) {
+                    probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
+                    for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
+                        probs_id.back().first += probs[allowed_tokens[i][j]];
+                    }
+                    probs_id.back().first /= allowed_tokens[i].size();
+                    psum += probs_id.back().first;
+                }
+
+                // normalize
+                for (auto & p : probs_id) {
+                    p.first /= psum;
+                }
+
+                // sort descending
+                {
+                    using pair_type = decltype(probs_id)::value_type;
+                    std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
+                        return a.first > b.first;
+                    });
+                }
+
+                // print the commands and the respective probabilities
+                {
                     fprintf(stdout, "\n");
+                    for (const auto & cmd : probs_id) {
+                        fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
+                        for (int token : allowed_tokens[cmd.second]) {
+                            fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
+                        }
+                        fprintf(stdout, "\n");
+                    }
                 }
-            }
 
-            // best command
-            {
-                const auto t_end = std::chrono::high_resolution_clock::now();
+                // best command
+                {
+                    const auto t_end = std::chrono::high_resolution_clock::now();
 
-                const float prob = probs_id[0].first;
-                const int index = probs_id[0].second;
+                    const float prob = probs_id[0].first;
+                    const int index = probs_id[0].second;
 
-                fprintf(stdout, "\n");
-                fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
-                        "\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
-                        (int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
-                fprintf(stdout, "\n");
+                    fprintf(stdout, "\n");
+                    fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
+                            "\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
+                            (int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
+                    fprintf(stdout, "\n");
+                }
             }
 
             audio.clear();
index 48e02923d017fd01c326046a2f83f3ac98ef7fd0..65b06ca516ae0aadf7694b8c274b7018b17baa6b 100644 (file)
@@ -59,8 +59,12 @@ struct whisper_params {
     int32_t duration_ms  = 0;
     int32_t max_context  = -1;
     int32_t max_len      = 0;
+    int32_t best_of      = 5;
+    int32_t beam_size    = -1;
 
-    float word_thold = 0.01f;
+    float word_thold    = 0.01f;
+    float entropy_thold = 2.4f;
+    float logprob_thold = -1.0f;
 
     bool speed_up       = false;
     bool translate      = false;
@@ -104,7 +108,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-d"    || arg == "--duration")       { params.duration_ms    = std::stoi(argv[++i]); }
         else if (arg == "-mc"   || arg == "--max-context")    { params.max_context    = std::stoi(argv[++i]); }
         else if (arg == "-ml"   || arg == "--max-len")        { params.max_len        = std::stoi(argv[++i]); }
+        else if (arg == "-bo"   || arg == "--best-of")        { params.best_of        = std::stoi(argv[++i]); }
+        else if (arg == "-bs"   || arg == "--beam-size")      { params.beam_size      = std::stoi(argv[++i]); }
         else if (arg == "-wt"   || arg == "--word-thold")     { params.word_thold     = std::stof(argv[++i]); }
+        else if (arg == "-et"   || arg == "--entropy-thold")  { params.entropy_thold  = std::stof(argv[++i]); }
+        else if (arg == "-lpt"  || arg == "--logprob-thold")  { params.logprob_thold  = std::stof(argv[++i]); }
         else if (arg == "-su"   || arg == "--speed-up")       { params.speed_up       = true; }
         else if (arg == "-tr"   || arg == "--translate")      { params.translate      = true; }
         else if (arg == "-di"   || arg == "--diarize")        { params.diarize        = true; }
@@ -136,31 +144,35 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
-    fprintf(stderr, "  -h,       --help           [default] show this help message and exit\n");
-    fprintf(stderr, "  -t N,     --threads N      [%-7d] number of threads to use during computation\n",    params.n_threads);
-    fprintf(stderr, "  -p N,     --processors N   [%-7d] number of processors to use during computation\n", params.n_processors);
-    fprintf(stderr, "  -ot N,    --offset-t N     [%-7d] time offset in milliseconds\n",                    params.offset_t_ms);
-    fprintf(stderr, "  -on N,    --offset-n N     [%-7d] segment index offset\n",                           params.offset_n);
-    fprintf(stderr, "  -d  N,    --duration N     [%-7d] duration of audio to process in milliseconds\n",   params.duration_ms);
-    fprintf(stderr, "  -mc N,    --max-context N  [%-7d] maximum number of text context tokens to store\n", params.max_context);
-    fprintf(stderr, "  -ml N,    --max-len N      [%-7d] maximum segment length in characters\n",           params.max_len);
-    fprintf(stderr, "  -wt N,    --word-thold N   [%-7.2f] word timestamp probability threshold\n",         params.word_thold);
-    fprintf(stderr, "  -su,      --speed-up       [%-7s] speed up audio by x2 (reduced accuracy)\n",        params.speed_up ? "true" : "false");
-    fprintf(stderr, "  -tr,      --translate      [%-7s] translate from source language to english\n",      params.translate ? "true" : "false");
-    fprintf(stderr, "  -di,      --diarize        [%-7s] stereo audio diarization\n",                       params.diarize ? "true" : "false");
-    fprintf(stderr, "  -otxt,    --output-txt     [%-7s] output result in a text file\n",                   params.output_txt ? "true" : "false");
-    fprintf(stderr, "  -ovtt,    --output-vtt     [%-7s] output result in a vtt file\n",                    params.output_vtt ? "true" : "false");
-    fprintf(stderr, "  -osrt,    --output-srt     [%-7s] output result in a srt file\n",                    params.output_srt ? "true" : "false");
-    fprintf(stderr, "  -owts,    --output-words   [%-7s] output script for generating karaoke video\n",     params.output_wts ? "true" : "false");
-    fprintf(stderr, "  -ocsv,    --output-csv     [%-7s] output result in a CSV file\n",                    params.output_csv ? "true" : "false");
-    fprintf(stderr, "  -ps,      --print-special  [%-7s] print special tokens\n",                           params.print_special ? "true" : "false");
-    fprintf(stderr, "  -pc,      --print-colors   [%-7s] print colors\n",                                   params.print_colors ? "true" : "false");
-    fprintf(stderr, "  -pp,      --print-progress [%-7s] print progress\n",                                 params.print_progress ? "true" : "false");
-    fprintf(stderr, "  -nt,      --no-timestamps  [%-7s] do not print timestamps\n",                        params.no_timestamps ? "false" : "true");
-    fprintf(stderr, "  -l LANG,  --language LANG  [%-7s] spoken language ('auto' for auto-detect)\n",       params.language.c_str());
-    fprintf(stderr, "            --prompt PROMPT  [%-7s] initial prompt\n",                                 params.prompt.c_str());
-    fprintf(stderr, "  -m FNAME, --model FNAME    [%-7s] model path\n",                                     params.model.c_str());
-    fprintf(stderr, "  -f FNAME, --file FNAME     [%-7s] input WAV file path\n",                            "");
+    fprintf(stderr, "  -h,       --help            [default] show this help message and exit\n");
+    fprintf(stderr, "  -t N,     --threads N       [%-7d] number of threads to use during computation\n",    params.n_threads);
+    fprintf(stderr, "  -p N,     --processors N    [%-7d] number of processors to use during computation\n", params.n_processors);
+    fprintf(stderr, "  -ot N,    --offset-t N      [%-7d] time offset in milliseconds\n",                    params.offset_t_ms);
+    fprintf(stderr, "  -on N,    --offset-n N      [%-7d] segment index offset\n",                           params.offset_n);
+    fprintf(stderr, "  -d  N,    --duration N      [%-7d] duration of audio to process in milliseconds\n",   params.duration_ms);
+    fprintf(stderr, "  -mc N,    --max-context N   [%-7d] maximum number of text context tokens to store\n", params.max_context);
+    fprintf(stderr, "  -ml N,    --max-len N       [%-7d] maximum segment length in characters\n",           params.max_len);
+    fprintf(stderr, "  -bo N,    --best-of N       [%-7d] number of best candidates to keep\n",              params.best_of);
+    fprintf(stderr, "  -bs N,    --beam-size N     [%-7d] beam size for beam search\n",                      params.beam_size);
+    fprintf(stderr, "  -wt N,    --word-thold N    [%-7.2f] word timestamp probability threshold\n",         params.word_thold);
+    fprintf(stderr, "  -et N,    --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n",           params.entropy_thold);
+    fprintf(stderr, "  -lpt N,   --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n",   params.logprob_thold);
+    fprintf(stderr, "  -su,      --speed-up        [%-7s] speed up audio by x2 (reduced accuracy)\n",        params.speed_up ? "true" : "false");
+    fprintf(stderr, "  -tr,      --translate       [%-7s] translate from source language to english\n",      params.translate ? "true" : "false");
+    fprintf(stderr, "  -di,      --diarize         [%-7s] stereo audio diarization\n",                       params.diarize ? "true" : "false");
+    fprintf(stderr, "  -otxt,    --output-txt      [%-7s] output result in a text file\n",                   params.output_txt ? "true" : "false");
+    fprintf(stderr, "  -ovtt,    --output-vtt      [%-7s] output result in a vtt file\n",                    params.output_vtt ? "true" : "false");
+    fprintf(stderr, "  -osrt,    --output-srt      [%-7s] output result in a srt file\n",                    params.output_srt ? "true" : "false");
+    fprintf(stderr, "  -owts,    --output-words    [%-7s] output script for generating karaoke video\n",     params.output_wts ? "true" : "false");
+    fprintf(stderr, "  -ocsv,    --output-csv      [%-7s] output result in a CSV file\n",                    params.output_csv ? "true" : "false");
+    fprintf(stderr, "  -ps,      --print-special   [%-7s] print special tokens\n",                           params.print_special ? "true" : "false");
+    fprintf(stderr, "  -pc,      --print-colors    [%-7s] print colors\n",                                   params.print_colors ? "true" : "false");
+    fprintf(stderr, "  -pp,      --print-progress  [%-7s] print progress\n",                                 params.print_progress ? "true" : "false");
+    fprintf(stderr, "  -nt,      --no-timestamps   [%-7s] do not print timestamps\n",                        params.no_timestamps ? "false" : "true");
+    fprintf(stderr, "  -l LANG,  --language LANG   [%-7s] spoken language ('auto' for auto-detect)\n",       params.language.c_str());
+    fprintf(stderr, "            --prompt PROMPT   [%-7s] initial prompt\n",                                 params.prompt.c_str());
+    fprintf(stderr, "  -m FNAME, --model FNAME     [%-7s] model path\n",                                     params.model.c_str());
+    fprintf(stderr, "  -f FNAME, --file FNAME      [%-7s] input WAV file path\n",                            "");
     fprintf(stderr, "\n");
 }
 
@@ -235,7 +247,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
                 const char * text = whisper_full_get_token_text(ctx, i, j);
                 const float  p    = whisper_full_get_token_p   (ctx, i, j);
 
-                const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+                const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
 
                 printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
             }
@@ -331,20 +343,19 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
     const int n_segments = whisper_full_n_segments(ctx);
     for (int i = 0; i < n_segments; ++i) {
         const char * text = whisper_full_get_segment_text(ctx, i);
-       if (text[0] == ' ')
-         text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
+        if (text[0] == ' ') {
+            text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
+        }
         const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
         const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
-       //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
-        fout << 10 * t0 << ", " 
-            << 10 * t1 << ", \"" 
-            << text    << "\"\n";
+
+        //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
+        fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text    << "\"\n";
     }
 
     return true;
 }
 
-
 // karaoke video generation
 // outputs a bash script that uses ffmpeg to generate a video with the subtitles
 // TODO: font parameter adjustments
@@ -620,6 +631,8 @@ int main(int argc, char ** argv) {
         {
             whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 
+            wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
+
             wparams.print_realtime   = false;
             wparams.print_progress   = params.print_progress;
             wparams.print_timestamps = !params.no_timestamps;
@@ -633,12 +646,18 @@ int main(int argc, char ** argv) {
 
             wparams.token_timestamps = params.output_wts || params.max_len > 0;
             wparams.thold_pt         = params.word_thold;
+            wparams.entropy_thold    = params.entropy_thold;
+            wparams.logprob_thold    = params.logprob_thold;
             wparams.max_len          = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
 
             wparams.speed_up         = params.speed_up;
 
-            wparams.prompt_tokens    = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
-            wparams.prompt_n_tokens  = prompt_tokens.empty() ? 0       : prompt_tokens.size();
+            wparams.greedy.best_of        = params.best_of;
+            wparams.beam_search.beam_size = params.beam_size;
+            wparams.temperature_inc = -1;
+
+            wparams.prompt_tokens     = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
+            wparams.prompt_n_tokens   = prompt_tokens.empty() ? 0       : prompt_tokens.size();
 
             whisper_print_user_data user_data = { &params, &pcmf32s };
 
index e4cdf639a406f1488a27efbde47c3b3025cbef5d..144a14d268fee0402ee903c935aa59b149261cc8 100644 (file)
@@ -49,6 +49,9 @@ void stream_main(size_t index) {
     wparams.max_tokens       = 32;
     wparams.audio_ctx        = 768; // partial encoder context for better performance
 
+    // disable temperature fallback
+    wparams.temperature_inc  = -1.0f;
+
     wparams.language         = "en";
 
     printf("stream: using %d threads\n", wparams.n_threads);
index 9f0c16c669a722c4aecc71fcb8603d6c9f4ace82..e1251704f5d91b881ffac4e4f2de2da1d9d0f171 100644 (file)
@@ -615,6 +615,9 @@ int main(int argc, char ** argv) {
             wparams.audio_ctx        = params.audio_ctx;
             wparams.speed_up         = params.speed_up;
 
+            // disable temperature fallback
+            wparams.temperature_inc  = -1.0f;
+
             wparams.prompt_tokens    = params.no_context ? nullptr : prompt_tokens.data();
             wparams.prompt_n_tokens  = params.no_context ? 0       : prompt_tokens.size();
 
diff --git a/examples/whisper.swiftui/whisper.swiftui.demo/Resources/models/.gitignore b/examples/whisper.swiftui/whisper.swiftui.demo/Resources/models/.gitignore
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/examples/whisper.swiftui/whisper.swiftui.demo/Resources/samples/.gitignore b/examples/whisper.swiftui/whisper.swiftui.demo/Resources/samples/.gitignore
new file mode 100644 (file)
index 0000000..e69de29
index 9cc09c09b5202c391e3a8d562e60c72074b2d0a5..cc0afbcae4ff43b8df46807d7744e5743c9904c4 100644 (file)
                0AAC5DA029539CD0003032C3 /* WhisperCppDemo.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = WhisperCppDemo.entitlements; sourceTree = "<group>"; };
                0AAC5DA229539CD0003032C3 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = "<group>"; };
                0AAC5DC629539EAF003032C3 /* WhisperCppDemo-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "WhisperCppDemo-Bridging-Header.h"; sourceTree = "<group>"; };
-               0AAC5DC729539EB0003032C3 /* whisper.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = whisper.cpp; path = ../../../whisper.cpp; sourceTree = "<group>"; };
-               0AAC5DC829539EB0003032C3 /* whisper.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = whisper.h; path = ../../../whisper.h; sourceTree = "<group>"; };
-               0AAC5DC929539EB0003032C3 /* ggml.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = ggml.c; path = ../../../ggml.c; sourceTree = "<group>"; };
-               0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = ggml.h; path = ../../../ggml.h; sourceTree = "<group>"; };
+               0AAC5DC729539EB0003032C3 /* whisper.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = whisper.cpp; sourceTree = "<group>"; };
+               0AAC5DC829539EB0003032C3 /* whisper.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = whisper.h; sourceTree = "<group>"; };
+               0AAC5DC929539EB0003032C3 /* ggml.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = ggml.c; sourceTree = "<group>"; };
+               0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ggml.h; sourceTree = "<group>"; };
                0AAC5DCD2953A05C003032C3 /* WhisperState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperState.swift; sourceTree = "<group>"; };
                0AAC5DD02953A394003032C3 /* LibWhisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibWhisper.swift; sourceTree = "<group>"; };
 /* End PBXFileReference section */
                                0AAC5DC729539EB0003032C3 /* whisper.cpp */,
                                0AAC5DC829539EB0003032C3 /* whisper.h */,
                        );
-                       path = whisper.cpp;
+                       name = whisper.cpp;
+                       path = ../..;
                        sourceTree = "<group>";
                };
                0AAC5DCF2953A36C003032C3 /* whisper.cpp.swift */ = {
index a64505693f718ec55c5a3c438ec619d271865a20..c40085675baaa7b6ff65852d14402c2c2be7ea8a 100644 (file)
 #include <thread>
 #include <vector>
 #include <regex>
+#include <random>
+
+#define WHISPER_ASSERT(x) \
+    do { \
+        if (!(x)) { \
+            fprintf(stderr, "WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+            abort(); \
+        } \
+    } while (0)
+
+// define this to enable verbose trace logging - useful for debugging purposes
+//#define WHISPER_DEBUG
+
+#if defined(WHISPER_DEBUG)
+#define WHISPER_PRINT_DEBUG(...) \
+    do { \
+        fprintf(stderr, __VA_ARGS__); \
+    } while (0)
+#else
+#define WHISPER_PRINT_DEBUG(...)
+#endif
 
-#define USE_FLASH_ATTN
-//#define USE_FLASH_FF
+#define WHISPER_USE_FLASH_ATTN
+//#define WHISPER_USE_FLASH_FF
+#define WHISPER_MAX_DECODERS 16
 
 // available whisper models
 enum e_model {
@@ -141,12 +163,20 @@ static const std::map<e_model, size_t> MEM_REQ_MODEL = {
     { MODEL_LARGE,  2952ull*MB },
 };
 
-static const std::map<e_model, size_t> MEM_REQ_MEMORY = {
-    { MODEL_TINY,     12ull*MB },
-    { MODEL_BASE,     24ull*MB },
-    { MODEL_SMALL,    70ull*MB },
-    { MODEL_MEDIUM,  184ull*MB },
-    { MODEL_LARGE,   306ull*MB },
+static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
+    { MODEL_TINY,      3ull*MB },
+    { MODEL_BASE,      6ull*MB },
+    { MODEL_SMALL,    16ull*MB },
+    { MODEL_MEDIUM,   43ull*MB },
+    { MODEL_LARGE,    71ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
+    { MODEL_TINY,      9ull*MB },
+    { MODEL_BASE,     18ull*MB },
+    { MODEL_SMALL,    53ull*MB },
+    { MODEL_MEDIUM,  141ull*MB },
+    { MODEL_LARGE,   235ull*MB },
 };
 
 static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
@@ -204,10 +234,6 @@ struct whisper_vocab {
     std::map<token, id> token_to_id;
     std::map<id, token> id_to_token;
 
-    // used to avoid memory allocations during sampling
-    // TODO: move to whisper_context in the future
-    std::vector<std::pair<double, whisper_vocab::id>> probs_id;
-
     id token_eot  = 50256;
     id token_sot  = 50257;
     id token_prev = 50360;
@@ -349,6 +375,17 @@ struct whisper_layer_decoder {
     struct ggml_tensor * mlp_1_b;
 };
 
+struct whisper_kv_cache {
+    struct ggml_tensor * k;
+    struct ggml_tensor * v;
+
+    struct ggml_context * ctx;
+
+    std::vector<uint8_t> buf;
+
+    int n; // number of tokens currently in the cache
+};
+
 struct whisper_model {
     e_model type = MODEL_UNKNOWN;
 
@@ -371,34 +408,64 @@ struct whisper_model {
     struct ggml_tensor * e_ln_b;
 
     // decoder.positional_embedding
-    struct ggml_tensor * d_pe; // DD
+    struct ggml_tensor * d_pe;
 
     // decoder.token_embedding
-    struct ggml_tensor * d_te; // DD
+    struct ggml_tensor * d_te;
 
     // decoder.ln
-    struct ggml_tensor * d_ln_w; // DD
-    struct ggml_tensor * d_ln_b; // DD
+    struct ggml_tensor * d_ln_w;
+    struct ggml_tensor * d_ln_b;
 
     std::vector<whisper_layer_encoder> layers_encoder;
     std::vector<whisper_layer_decoder> layers_decoder;
 
-    // key + value memory
-    struct ggml_tensor * memory_k;
-    struct ggml_tensor * memory_v;
-
-    struct ggml_tensor * memory_cross_k;
-    struct ggml_tensor * memory_cross_v;
-
     // context
     struct ggml_context * ctx;
-    struct ggml_context * ctx_mem;
+
+    // the model memory buffer is read-only and can be shared between processors
+    std::vector<uint8_t> * buf;
 
     // tensors
     int n_loaded;
     std::map<std::string, struct ggml_tensor *> tensors;
 };
 
+struct whisper_sequence {
+    std::vector<whisper_token_data> tokens;
+
+    // the accumulated transcription in the current interation (used to truncate the tokens array)
+    int result_len;
+
+    double sum_logprobs_all; // the sum of the log probabilities of the tokens
+    double sum_logprobs;     // the sum of the log probabilities of the tokens (first result_len tokens)
+    double avg_logprobs;     // the average log probability of the tokens
+    double entropy;          // the entropy of the tokens
+    double score;            // likelihood rank score
+};
+
+// TAGS: WHISPER_DECODER_INIT
+struct whisper_decoder {
+    // each decoders keeps its own KV-cache
+    whisper_kv_cache kv_self;
+
+    // the currently generated sequence of tokens
+    whisper_sequence sequence;
+
+    int seek_delta; // the window shift found so far based on the decoded timestamp tokens
+
+    bool failed;    // has the current segment failed to decode?
+    bool completed; // has the decoder completed the current segment?
+    bool has_ts;    // have we already sampled a non-beg timestamp token for the current segment?
+
+    // new token probs, logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab])
+    std::vector<float> probs;
+    std::vector<float> logits;
+    std::vector<float> logprobs;
+
+    std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
+};
+
 struct whisper_context {
     int64_t t_load_us   = 0;
     int64_t t_mel_us    = 0;
@@ -407,24 +474,33 @@ struct whisper_context {
     int64_t t_decode_us = 0;
     int64_t t_start_us  = 0;
 
-    std::vector<uint8_t> * buf_model; // the model buffer is read-only and can be shared between processors
-    std::vector<uint8_t>   buf_memory;
-    std::vector<uint8_t>   buf_compute;
-    std::vector<uint8_t>   buf_compute_layer;
-
     ggml_type wtype; // weight type (FP32 or FP16)
 
+    whisper_mel mel;
+
     whisper_model model;
     whisper_vocab vocab;
 
-    whisper_mel mel;
+    // cross-attention KV cache for the decoders
+    // shared between all decoders
+    whisper_kv_cache kv_cross;
 
-    std::vector<float> probs;
+    whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
+
+    // memory buffers used by encode / decode contexts
+    std::vector<uint8_t> buf_compute;
+    std::vector<uint8_t> buf_compute_layer;
+
+    // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
 
     std::vector<whisper_segment> result_all;
+    std::vector<whisper_token>   prompt_past;
+
+    // work container used to avoid memory allocations
+    std::vector<std::pair<double, whisper_vocab::id>> logits_id;
 
-    std::vector<whisper_token> prompt_past;
+    mutable std::mt19937 rng; // used for sampling at t > 0.0
 
     // [EXPERIMENTAL] token-level timestamps data
     int64_t t_beg;
@@ -441,6 +517,72 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
     loader->read(loader->context, &dest, sizeof(T));
 }
 
+static bool kv_cache_init(
+        const struct whisper_hparams & hparams,
+                        const size_t   mem_bytes,
+             struct whisper_kv_cache & cache,
+                           ggml_type   wtype,
+                                 int   n_ctx) {
+    cache.buf.resize(mem_bytes);
+
+    struct ggml_init_params params;
+    params.mem_size   = cache.buf.size();
+    params.mem_buffer = cache.buf.data();
+
+    cache.ctx = ggml_init(params);
+
+    if (!cache.ctx) {
+        fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
+        return false;
+    }
+
+    const int n_text_state = hparams.n_text_state;
+    const int n_text_layer = hparams.n_text_layer;
+
+    const int n_mem      = n_text_layer*n_ctx;
+    const int n_elements = n_text_state*n_mem;
+
+    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+
+    return true;
+}
+
+static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
+    WHISPER_ASSERT(cache.ctx);
+
+    const int n_elements = ggml_nelements(cache.k);
+    WHISPER_ASSERT(n_elements == ggml_nelements(cache.v));
+
+    const ggml_type wtype = cache.k->type;
+    WHISPER_ASSERT(wtype == cache.v->type);
+
+    WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype));
+
+    struct ggml_init_params params;
+    params.mem_size   = cache.buf.size();
+    params.mem_buffer = cache.buf.data();
+
+    cache.ctx = ggml_init(params);
+
+    if (!cache.ctx) {
+        fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
+        return false;
+    }
+
+    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+
+    return true;
+}
+
+static void kv_cache_free(struct whisper_kv_cache & cache) {
+    if (cache.ctx) {
+        ggml_free(cache.ctx);
+        cache.ctx = nullptr;
+    }
+}
+
 // load the model from a ggml file
 //
 // file format:
@@ -455,6 +597,10 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
 static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
     fprintf(stderr, "%s: loading model\n", __func__);
 
+    const int64_t t_start_us = ggml_time_us();
+
+    wctx.t_start_us = t_start_us;
+
     auto & model = wctx.model;
     auto & vocab = wctx.vocab;
 
@@ -506,6 +652,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             model.type = e_model::MODEL_LARGE;
         }
 
+        // for the big tensors, we have the option to store the data in 16-bit floats
+        // in order to save memory and also to speed up the computation
+        wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+        const size_t scale = model.hparams.f16 ? 1 : 2;
+
         fprintf(stderr, "%s: n_vocab       = %d\n", __func__, hparams.n_vocab);
         fprintf(stderr, "%s: n_audio_ctx   = %d\n", __func__, hparams.n_audio_ctx);
         fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
@@ -519,11 +671,51 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         fprintf(stderr, "%s: f16           = %d\n", __func__, hparams.f16);
         fprintf(stderr, "%s: type          = %d\n", __func__, model.type);
 
-        wctx.buf_model = new std::vector<uint8_t>();
-        wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
-        wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
-        wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
-        wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
+        // print memory requirements
+        {
+            // this is the total memory required to run the inference
+            const size_t mem_required =
+                scale*MEM_REQ_MODEL.at       (model.type) +
+                scale*MEM_REQ_KV_CROSS.at    (model.type) +
+                scale*std::max(MEM_REQ_ENCODE.at(model.type),       MEM_REQ_DECODE.at(model.type)) +
+                scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type));
+
+            // this is the memory required by one decoder
+            const size_t mem_required_decoder =
+                scale*MEM_REQ_KV_SELF.at(model.type);
+
+            fprintf(stderr, "%s: mem required  = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
+                    mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
+        }
+
+        // initialize all memory buffers
+        // always have at least one decoder
+
+        wctx.model.buf = new std::vector<uint8_t>();
+        wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
+
+        if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
+            fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
+            return false;
+        }
+
+        {
+            const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v);
+            fprintf(stderr, "%s: kv self size  = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
+        }
+
+        if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) {
+            fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
+            return false;
+        }
+
+        {
+            const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
+            fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
+        }
+
+        wctx.buf_compute.resize      (scale*std::max(MEM_REQ_ENCODE.at(model.type),       MEM_REQ_DECODE.at(model.type)));
+        wctx.buf_compute_layer.resize(scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
     }
 
     // load mel filters
@@ -607,30 +799,21 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
 
         wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
-        wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
 
-        vocab.probs_id.reserve(n_vocab);
-    }
+        wctx.logits_id.reserve(n_vocab);
 
-    {
-        // this is the total memory required to run the inference
-        const size_t mem_required =
-                   wctx.buf_model->size() +
-                   wctx.buf_memory.size() +
-                   wctx.buf_compute.size() +
-                   wctx.buf_compute_layer.size();
+        // TAGS: WHISPER_DECODER_INIT
+        wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
 
-        fprintf(stderr, "%s: mem_required  = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
+        wctx.decoders[0].probs.reserve   (vocab.n_vocab);
+        wctx.decoders[0].logits.reserve  (vocab.n_vocab);
+        wctx.decoders[0].logprobs.reserve(vocab.n_vocab);
     }
 
-    // for the big tensors, we have the option to store the data in 16-bit floats
-    // in order to save memory and also to speed up the computation
-    wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+    size_t ctx_size = 0;
 
     const ggml_type wtype = wctx.wtype;
 
-    size_t ctx_size = 0;
-
     {
         const auto & hparams = model.hparams;
 
@@ -738,14 +921,14 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
         ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
 
-        fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+        fprintf(stderr, "%s: model ctx     = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
     }
 
     // create the ggml context
     {
         struct ggml_init_params params;
-        params.mem_size   = wctx.buf_model->size();
-        params.mem_buffer = wctx.buf_model->data();
+        params.mem_size   = wctx.model.buf->size();
+        params.mem_buffer = wctx.model.buf->data();
 
         model.ctx = ggml_init(params);
         if (!model.ctx) {
@@ -950,56 +1133,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
-    // create the ggml memory context
-    {
-        struct ggml_init_params params;
-        params.mem_size   = wctx.buf_memory.size();
-        params.mem_buffer = wctx.buf_memory.data();
-
-        model.ctx_mem = ggml_init(params);
-        if (!model.ctx_mem) {
-            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
-            return false;
-        }
-    }
-
-    // key + value memory
-    {
-        auto & ctx = model.ctx_mem;
-
-        const auto & hparams = model.hparams;
-
-        const int n_text_state = hparams.n_text_state;
-        const int n_text_layer = hparams.n_text_layer;
-        const int n_text_ctx   = hparams.n_text_ctx;
-
-        // key/value memory for the self-attention layer
-        {
-            const int n_mem      = n_text_layer*n_text_ctx;
-            const int n_elements = n_text_state*n_mem;
-
-            model.memory_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
-            model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
-        }
-
-        // key/value memory for the cross-attention layer
-        {
-            const int n_audio_ctx = hparams.n_audio_ctx;
-
-            const int n_mem      = n_text_layer*n_audio_ctx;
-            const int n_elements = n_text_state*n_mem;
-
-            model.memory_cross_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
-            model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
-        }
-
-        const size_t memory_size =
-            ggml_nbytes(model.memory_k)       + ggml_nbytes(model.memory_v) +
-            ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
-
-        fprintf(stderr, "%s: memory size   = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
-    }
-
     // load weights
     {
         size_t total_size = 0;
@@ -1073,6 +1206,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
+    wctx.rng = std::mt19937(0);
+
+    wctx.t_load_us = ggml_time_us() - t_start_us;
+
     return true;
 }
 
@@ -1086,9 +1223,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 //   - mel_offset: offset in the mel spectrogram (i.e. audio offset)
 //
 static bool whisper_encode(
-              whisper_context & wctx,
-        const int n_threads,
-        const int mel_offset) {
+        whisper_context & wctx,
+              const int   mel_offset,
+              const int   n_threads) {
+    const int64_t t_start_us = ggml_time_us();
+
     const auto & model   = wctx.model;
     const auto & mel_inp = wctx.mel;
     const auto & hparams = model.hparams;
@@ -1229,7 +1368,7 @@ static bool whisper_encode(
 
             // ------
 
-#ifdef USE_FLASH_ATTN
+#ifdef WHISPER_USE_FLASH_ATTN
             struct ggml_tensor * Q =
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
@@ -1340,7 +1479,7 @@ static bool whisper_encode(
                         ggml_repeat(ctxL, layer.mlp_ln_b, cur));
             }
 
-#ifdef USE_FLASH_FF
+#ifdef WHISPER_USE_FLASH_FF
             cur = ggml_flash_ff(ctxL,
                     ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)),
                     layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
@@ -1461,10 +1600,10 @@ static bool whisper_encode(
                         Vcross),
                     Vcross);
 
-            //struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
-            //struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
-            struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
-            struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
+            //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
+            //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
+            struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx));
+            struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*n_ctx));
 
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
@@ -1480,6 +1619,8 @@ static bool whisper_encode(
 
     ggml_free(ctx0);
 
+    wctx.t_encode_us += ggml_time_us() - t_start_us;
+
     return true;
 }
 
@@ -1494,16 +1635,22 @@ static bool whisper_encode(
 //   - n_past:     number of past tokens to prefix the prompt with
 //
 static bool whisper_decode(
-              whisper_context & wctx,
-        const int n_threads,
-        const whisper_token * tokens,
-        const int n_tokens,
-        const int n_past) {
+        whisper_context & wctx,
+        whisper_decoder & decoder,
+    const whisper_token * tokens,
+              const int   n_tokens,
+              const int   n_past,
+              const int   n_threads) {
+    const int64_t t_start_us = ggml_time_us();
+
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
+    auto & kv_self = decoder.kv_self;
+
+    WHISPER_ASSERT(!!kv_self.ctx);
+
     auto & logits_out = wctx.logits;
-    auto & probs_out  = wctx.probs;
 
     const int n_vocab = hparams.n_vocab;
 
@@ -1515,6 +1662,8 @@ static bool whisper_decode(
     const int N = n_tokens;
     const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
 
+    //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
+
     struct ggml_init_params params;
     params.mem_size   = wctx.buf_compute.size();
     params.mem_buffer = wctx.buf_compute.data();
@@ -1593,8 +1742,8 @@ static bool whisper_decode(
 
             // store key and value to memory
             {
-                struct ggml_tensor * k = ggml_view_1d(ctxL, model.memory_k, N*n_state, (ggml_element_size(model.memory_k)*n_state)*(il*n_ctx + n_past));
-                struct ggml_tensor * v = ggml_view_1d(ctxL, model.memory_v, N*n_state, (ggml_element_size(model.memory_v)*n_state)*(il*n_ctx + n_past));
+                struct ggml_tensor * k = ggml_view_1d(ctxL, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_1d(ctxL, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past));
 
                 ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k));
                 ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v));
@@ -1612,7 +1761,7 @@ static bool whisper_decode(
             struct ggml_tensor * K =
                 ggml_permute(ctxL,
                         ggml_reshape_3d(ctxL,
-                            ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
+                            ggml_view_1d(ctxL, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state),
                             n_state/n_head, n_head, n_past + N),
                         0, 2, 1, 3);
 
@@ -1632,7 +1781,7 @@ static bool whisper_decode(
             struct ggml_tensor * V_trans =
                 ggml_permute(ctxL,
                         ggml_reshape_3d(ctxL,
-                            ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
+                            ggml_view_1d(ctxL, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
                             n_state/n_head, n_head, n_past + N),
                         1, 2, 0, 3);
 
@@ -1687,12 +1836,12 @@ static bool whisper_decode(
             // Kcross is already scaled
             struct ggml_tensor * Kcross =
                 ggml_reshape_3d(ctxL,
-                        ggml_view_1d(ctxL, model.memory_cross_k, M*n_state, il*M*ggml_element_size(model.memory_cross_k)*n_state),
+                        ggml_view_1d(ctxL, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state),
                         n_state/n_head, n_head, M);
 
             struct ggml_tensor * Vcross =
                 ggml_reshape_3d(ctxL,
-                        ggml_view_1d(ctxL, model.memory_cross_v, M*n_state, il*M*ggml_element_size(model.memory_cross_v)*n_state),
+                        ggml_view_1d(ctxL, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state),
                         n_state/n_head, n_head, M);
 
             // ------
@@ -1823,25 +1972,18 @@ static bool whisper_decode(
 
     struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
 
-    // logits -> probs
-    cur = ggml_dup(ctx0, logits);
-    cur = ggml_soft_max(ctx0, cur); // in-place
-
     // run the computation
     {
         struct ggml_cgraph gf = {};
         gf.n_threads = n_threads;
 
-        ggml_build_forward_expand(&gf, cur);
+        ggml_build_forward_expand(&gf, logits);
         ggml_graph_compute       (ctx0, &gf);
     }
 
     logits_out.resize(N*n_vocab);
     memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
 
-    probs_out.resize(N*n_vocab);
-    memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab);
-
     if (N > 1) {
         //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
         //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
@@ -1850,98 +1992,9 @@ static bool whisper_decode(
 
     ggml_free(ctx0);
 
-    return true;
-}
-
-// the most basic sampling scheme - select the top token
-static whisper_token_data whisper_sample_best(
-              whisper_vocab & vocab,
-        const float * probs,
-              bool force_timestamp,
-              bool is_initial) {
-    whisper_token_data result = {
-        0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
-    };
-
-    const int n_logits = vocab.n_vocab;
-
-    auto & probs_id = vocab.probs_id;
-
-    probs_id.clear();
-    for (int i = 0; i < n_logits; i++) {
-        probs_id.emplace_back(probs[i], i);
-    }
-
-    {
-        double sum_ts =  0.0;
-        double max_ts = -1.0;
-        double max_tx = -1.0;
-
-        for (int i = 0; i < vocab.token_beg; i++) {
-            max_tx = std::max(max_tx, probs_id[i].first);
-        }
-
-        const auto i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg;
-        const auto i1 = is_initial ? vocab.token_beg + 101 : n_logits;
-
-        // the initial timestamp cannot be larger than 100
-        // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
-        if (is_initial) {
-            for (int i = i0; i < n_logits; ++ i) {
-                probs_id[i].first = -INFINITY;
-            }
-        }
-
-        for (int i = vocab.token_beg; i < i1; i++) {
-            sum_ts += probs_id[i].first;
-            if  (probs_id[i].first > max_ts) {
-                max_ts = probs_id[i].first;
-                result.tid = probs_id[i].second;
-            }
-        }
-
-        // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
-        // timestamp token
-        if (sum_ts > max_tx || force_timestamp) {
-            // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
-            for (int i = 0; i < vocab.token_beg; i++) {
-                probs_id[i].first = -INFINITY;
-            }
-        }
-
-        result.pt = max_ts/(sum_ts + 1e-10);
-        result.ptsum = sum_ts;
-    }
-
-    // find the top K tokens
-    const int top_k = 4;
-
-    std::partial_sort(
-            probs_id.begin(),
-            probs_id.begin() + top_k, probs_id.end(),
-            [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
-        return a.first > b.first;
-    });
-
-    probs_id.resize(top_k);
-
-    //printf("\n");
-    //for (int i = 0; i < (int) probs_id.size(); i++) {
-    //    printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
-    //}
-
-    int res = 0;
-    while ((probs_id[res].second == vocab.token_sot ||
-            probs_id[res].second == vocab.token_solm ||
-            probs_id[res].second == vocab.token_not) &&
-            res < (int) probs_id.size() - 1) {
-        res++;
-    }
-
-    result.id = probs_id[res].second;
-    result.p  = probs_id[res].first;
+    wctx.t_decode_us += ggml_time_us() - t_start_us;
 
-    return result;
+    return true;
 }
 
 //  500 -> 00:05.000
@@ -2043,16 +2096,18 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
 
 // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
 static bool log_mel_spectrogram(
-    const float * samples,
-    const int n_samples,
-    const int /*sample_rate*/,
-    const int fft_size,
-    const int fft_step,
-    const int n_mel,
-    const int n_threads,
-    const whisper_filters & filters,
-    const bool speed_up,
-    whisper_mel & mel) {
+        whisper_context & wctx,
+            const float * samples,
+              const int   n_samples,
+              const int   /*sample_rate*/,
+              const int   fft_size,
+              const int   fft_step,
+              const int   n_mel,
+              const int   n_threads,
+  const whisper_filters & filters,
+             const bool   speed_up,
+            whisper_mel & mel) {
+    const int64_t t_start_us = ggml_time_us();
 
     // Hanning window
     std::vector<float> hann;
@@ -2161,6 +2216,8 @@ static bool log_mel_spectrogram(
         mel.data[i] = (mel.data[i] + 4.0)/4.0;
     }
 
+    wctx.t_mel_us += ggml_time_us() - t_start_us;
+
     return true;
 }
 
@@ -2305,10 +2362,6 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
 
     whisper_context * ctx = new whisper_context;
 
-    const int64_t t_start_us = ggml_time_us();
-
-    ctx->t_start_us = t_start_us;
-
     if (!whisper_model_load(loader, *ctx)) {
         loader->close(loader->context);
         fprintf(stderr, "%s: failed to load model\n", __func__);
@@ -2316,8 +2369,6 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
         return nullptr;
     }
 
-    ctx->t_load_us = ggml_time_us() - t_start_us;
-
     loader->close(loader->context);
 
     return ctx;
@@ -2328,40 +2379,37 @@ void whisper_free(struct whisper_context * ctx) {
         if (ctx->model.ctx) {
             ggml_free(ctx->model.ctx);
         }
-        if (ctx->model.ctx_mem) {
-            ggml_free(ctx->model.ctx_mem);
+        if (ctx->model.buf) {
+            delete ctx->model.buf;
+        }
+        if (ctx->kv_cross.ctx) {
+            ggml_free(ctx->kv_cross.ctx);
         }
-        if (ctx->buf_model) {
-            delete ctx->buf_model;
+        for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
+            if (ctx->decoders[i].kv_self.ctx) {
+                ggml_free(ctx->decoders[i].kv_self.ctx);
+            }
         }
         delete ctx;
     }
 }
 
 int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
-    const int64_t t_start_us = ggml_time_us();
-
-    if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
+    if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
         fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }
 
-    ctx->t_mel_us = ggml_time_us() - t_start_us;
-
     return 0;
 }
 
 // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
 int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
-    const int64_t t_start_us = ggml_time_us();
-
-    if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
+    if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
         fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }
 
-    ctx->t_mel_us = ggml_time_us() - t_start_us;
-
     return 0;
 }
 
@@ -2385,51 +2433,26 @@ int whisper_set_mel(
 }
 
 int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
-    const int64_t t_start_us = ggml_time_us();
-
-    if (!whisper_encode(*ctx, n_threads, offset)) {
+    if (!whisper_encode(*ctx, offset, n_threads)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return -1;
     }
 
-    ctx->t_encode_us += ggml_time_us() - t_start_us;
-
     return 0;
 }
 
 int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
-    const int64_t t_start_us = ggml_time_us();
+    // TODO: add selected_decoder_id to context
+    const int selected_decoder_id = 0;
 
-    if (!whisper_decode(*ctx, n_threads, tokens, n_tokens, n_past)) {
+    if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
     }
 
-    ctx->t_decode_us += ggml_time_us() - t_start_us;
-
     return 0;
 }
 
-struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false);
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-
-    return res;
-}
-
-struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial);
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-
-    return res;
-}
-
 int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
     const auto res = tokenize(ctx->vocab, text);
 
@@ -2510,34 +2533,39 @@ int whisper_lang_auto_detect(
         return -7;
     }
 
-    std::vector<std::pair<float, int>> probs_id;
+    auto & logits_id = ctx->logits_id;
+    logits_id.clear();
+
     for (const auto & kv : g_lang) {
         const auto token_lang = whisper_token_lang(ctx, kv.second.first);
-        probs_id.emplace_back(ctx->probs[token_lang], kv.second.first);
+        logits_id.emplace_back(ctx->logits[token_lang], kv.second.first);
     }
 
     // sort descending
     {
-        using pair_type = decltype(probs_id)::value_type;
-        std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
+        using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
+        std::sort(logits_id.begin(), logits_id.end(), [](const pair_type & a, const pair_type & b) {
             return a.first > b.first;
         });
     }
 
     // softmax
     {
-        float sum = 0;
-        for (const auto & kv : probs_id) {
-            sum += exp(kv.first);
+        const auto max = logits_id[0].first;
+
+        double sum = 0.0f;
+        for (auto & kv : logits_id) {
+            kv.first = exp(kv.first - max);
+            sum += kv.first;
         }
 
-        for (auto & kv : probs_id) {
-            kv.first = exp(kv.first) / sum;
+        for (auto & kv : logits_id) {
+            kv.first /= sum;
         }
     }
 
     {
-        for (const auto & prob : probs_id) {
+        for (const auto & prob : logits_id) {
             if (lang_probs) {
                 lang_probs[prob.second] = prob.first;
             }
@@ -2546,7 +2574,7 @@ int whisper_lang_auto_detect(
         }
     }
 
-    return probs_id[0].second;
+    return logits_id[0].second;
 }
 
 int whisper_n_len(struct whisper_context * ctx) {
@@ -2569,8 +2597,8 @@ int whisper_is_multilingual(struct whisper_context * ctx) {
     return ctx->vocab.is_multilingual() ? 1 : 0;
 }
 
-float * whisper_get_probs(struct whisper_context * ctx) {
-    return ctx->probs.data();
+float * whisper_get_logits(struct whisper_context * ctx) {
+    return ctx->logits.data();
 }
 
 const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
@@ -2654,105 +2682,77 @@ const char * whisper_print_system_info(void) {
 ////////////////////////////////////////////////////////////////////////////
 
 struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
-    struct whisper_full_params result;
+    struct whisper_full_params result = {
+        /*.strategy         =*/ WHISPER_SAMPLING_GREEDY,
+
+        /*.n_threads        =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
+        /*.n_max_text_ctx   =*/ 16384,
+        /*.offset_ms        =*/ 0,
+        /*.duration_ms      =*/ 0,
+
+        /*.translate        =*/ false,
+        /*.no_context       =*/ false,
+        /*.single_segment   =*/ false,
+        /*.print_special    =*/ false,
+        /*.print_progress   =*/ true,
+        /*.print_realtime   =*/ false,
+        /*.print_timestamps =*/ true,
+
+        /*.token_timestamps =*/ false,
+        /*.thold_pt         =*/ 0.01f,
+        /*.thold_ptsum      =*/ 0.01f,
+        /*.max_len          =*/ 0,
+        /*.max_tokens       =*/ 0,
+
+        /*.speed_up         =*/ false,
+        /*.audio_ctx        =*/ 0,
+
+        /*.prompt_tokens    =*/ nullptr,
+        /*.prompt_n_tokens  =*/ 0,
+
+        /*.language         =*/ "en",
+
+        /*.suppress_blank   =*/ true,
+
+        /*.temperature      =*/  0.0f,
+        /*.max_initial_ts   =*/  1.0f,
+        /*.length_penalty   =*/ -1.0f,
+
+        /*.temperature_inc  =*/  0.2f,
+        /*.entropy_thold    =*/  2.4f,
+        /*.logprob_thold    =*/ -1.0f,
+        /*.no_speech_thold  =*/  0.6f,
+
+        /*.greedy           =*/ {
+            /*.best_of   =*/ -1,
+        },
+
+        /*.beam_search      =*/ {
+            /*.beam_size =*/ -1,
+
+            /*.patience  =*/ -1.0f,
+        },
+
+        /*.new_segment_callback           =*/ nullptr,
+        /*.new_segment_callback_user_data =*/ nullptr,
+
+        /*.encoder_begin_callback           =*/ nullptr,
+        /*.encoder_begin_callback_user_data =*/ nullptr,
+    };
 
     switch (strategy) {
         case WHISPER_SAMPLING_GREEDY:
             {
-                result = {
-                    /*.strategy         =*/ WHISPER_SAMPLING_GREEDY,
-
-                    /*.n_threads        =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
-                    /*.n_max_text_ctx   =*/ 16384,
-                    /*.offset_ms        =*/ 0,
-                    /*.duration_ms      =*/ 0,
-
-                    /*.translate        =*/ false,
-                    /*.no_context       =*/ false,
-                    /*.single_segment   =*/ false,
-                    /*.print_special    =*/ false,
-                    /*.print_progress   =*/ true,
-                    /*.print_realtime   =*/ false,
-                    /*.print_timestamps =*/ true,
-
-                    /*.token_timestamps =*/ false,
-                    /*.thold_pt         =*/ 0.01f,
-                    /*.thold_ptsum      =*/ 0.01f,
-                    /*.max_len          =*/ 0,
-                    /*.max_tokens       =*/ 0,
-
-                    /*.speed_up         =*/ false,
-                    /*.audio_ctx        =*/ 0,
-
-                    /*.prompt_tokens    =*/ nullptr,
-                    /*.prompt_n_tokens  =*/ 0,
-
-                    /*.language         =*/ "en",
-
-                    /*.greedy           =*/ {
-                        /*.n_past =*/ 0,
-                    },
-
-                    /*.beam_search      =*/ {
-                        /*.n_past     =*/ -1,
-                        /*.beam_width =*/ -1,
-                        /*.n_best     =*/ -1,
-                    },
-
-                    /*.new_segment_callback           =*/ nullptr,
-                    /*.new_segment_callback_user_data =*/ nullptr,
-
-                    /*.encoder_begin_callback           =*/ nullptr,
-                    /*.encoder_begin_callback_user_data =*/ nullptr,
+                result.greedy = {
+                    /*.best_of   =*/ 1,
                 };
             } break;
         case WHISPER_SAMPLING_BEAM_SEARCH:
             {
-                result = {
-                    /*.strategy         =*/ WHISPER_SAMPLING_BEAM_SEARCH,
-
-                    /*.n_threads        =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
-                    /*.n_max_text_ctx   =*/ 16384,
-                    /*.offset_ms        =*/ 0,
-                    /*.duration_ms      =*/ 0,
-
-                    /*.translate        =*/ false,
-                    /*.no_context       =*/ false,
-                    /*.single_segment   =*/ false,
-                    /*.print_special    =*/ false,
-                    /*.print_progress   =*/ true,
-                    /*.print_realtime   =*/ false,
-                    /*.print_timestamps =*/ true,
-
-                    /*.token_timestamps =*/ false,
-                    /*.thold_pt         =*/ 0.01f,
-                    /*.thold_ptsum      =*/ 0.01f,
-                    /*.max_len          =*/ 0,
-                    /*.max_tokens       =*/ 0,
-
-                    /*.speed_up         =*/ false,
-                    /*.audio_ctx        =*/ 0,
-
-                    /*.prompt_tokens    =*/ nullptr,
-                    /*.prompt_n_tokens  =*/ 0,
-
-                    /*.language         =*/ "en",
-
-                    /*.greedy           =*/ {
-                        /*.n_past =*/ -1,
-                    },
-
-                    /*.beam_search      =*/ {
-                        /*.n_past     =*/ 0,
-                        /*.beam_width =*/ 10,
-                        /*.n_best     =*/ 5,
-                    },
-
-                    /*.new_segment_callback           =*/ nullptr,
-                    /*.new_segment_callback_user_data =*/ nullptr,
-
-                    /*.encoder_begin_callback           =*/ nullptr,
-                    /*.encoder_begin_callback_user_data =*/ nullptr,
+                result.beam_search = {
+                    /*.beam_size =*/ 5,
+
+                    /*.patience  =*/ -1.0f,
                 };
             } break;
     }
@@ -2763,15 +2763,15 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 // forward declarations
 static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
 static void whisper_exp_compute_token_level_timestamps(
-        struct whisper_context * ctx,
-        int   i_segment,
-        float thold_pt,
-        float thold_ptsum);
+        struct whisper_context & ctx,
+                           int   i_segment,
+                         float   thold_pt,
+                         float   thold_ptsum);
 
 // wrap the last segment to max_len characters
 // returns the number of new segments
-static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
-    auto segment = ctx->result_all.back();
+static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
+    auto segment = ctx.result_all.back();
 
     int res = 1;
     int acc = 0;
@@ -2780,34 +2780,34 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
 
     for (int i = 0; i < (int) segment.tokens.size(); i++) {
         const auto & token = segment.tokens[i];
-        if (token.id >= whisper_token_eot(ctx)) {
+        if (token.id >= whisper_token_eot(&ctx)) {
             continue;
         }
 
-        const auto txt = whisper_token_to_str(ctx, token.id);
+        const auto txt = whisper_token_to_str(&ctx, token.id);
 
         const int cur = strlen(txt);
 
         if (acc + cur > max_len && i > 0) {
             // split here
-            ctx->result_all.back().text = std::move(text);
-            ctx->result_all.back().t1 = token.t0;
-            ctx->result_all.back().tokens.resize(i);
+            ctx.result_all.back().text = std::move(text);
+            ctx.result_all.back().t1 = token.t0;
+            ctx.result_all.back().tokens.resize(i);
 
-            ctx->result_all.push_back({});
-            ctx->result_all.back().t0 = token.t0;
-            ctx->result_all.back().t1 = segment.t1;
+            ctx.result_all.push_back({});
+            ctx.result_all.back().t0 = token.t0;
+            ctx.result_all.back().t1 = segment.t1;
 
             // add tokens [i, end] to the new segment
-            ctx->result_all.back().tokens.insert(
-                    ctx->result_all.back().tokens.end(),
+            ctx.result_all.back().tokens.insert(
+                    ctx.result_all.back().tokens.end(),
                     segment.tokens.begin() + i,
                     segment.tokens.end());
 
             acc = 0;
             text = "";
 
-            segment = ctx->result_all.back();
+            segment = ctx.result_all.back();
             i = -1;
 
             res++;
@@ -2817,52 +2817,409 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
         }
     }
 
-    ctx->result_all.back().text = std::move(text);
+    ctx.result_all.back().text = std::move(text);
 
     return res;
 }
 
-int whisper_full(
-        struct whisper_context * ctx,
-        struct whisper_full_params params,
-        const float * samples,
-        int n_samples) {
-    // clear old results
-    auto & result_all = ctx->result_all;
-
-    result_all.clear();
+// process the logits for the selected decoder
+// - applies logit filters
+// - computes logprobs and probs
+static void whisper_process_logits(
+        const struct whisper_context & ctx,
+    const struct whisper_full_params   params,
+              struct whisper_decoder & decoder,
+                               float   temperature) {
+    const auto & vocab      = ctx.vocab;
+    const auto & tokens_cur = decoder.sequence.tokens;
+
+    const bool is_initial = tokens_cur.size() == 0;
+    const int  n_logits   = vocab.id_to_token.size();
+
+    WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab);
+
+    // extract the logits for the last token
+    // we will be mutating and therefore we don't want to use the ctx.logits buffer directly
+    auto & probs    = decoder.probs;
+    auto & logits   = decoder.logits;
+    auto & logprobs = decoder.logprobs;
+    {
+        logits.resize(n_logits);
+        memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
 
-    // compute log mel spectrogram
-    if (params.speed_up) {
-        if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
-            fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
-            return -1;
-        }
-    } else {
-        if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
-            fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
-            return -2;
+        if (temperature > 0.0f) {
+            for (int i = 0; i < n_logits; i++) {
+                logits[i] /= temperature;
+            }
         }
-    }
 
-    // auto-detect language if not specified
-    if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
-        std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
+        // will be populated a bit later
+        probs.resize(n_logits);
+        logprobs.resize(n_logits);
+    }
 
-        const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
-        if (lang_id < 0) {
-            fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
-            return -3;
+    // apply logit filters here
+    // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493
+    {
+        // suppress blank
+        // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390
+        if (params.suppress_blank) {
+            if (is_initial) {
+                logits[vocab.token_eot]           = -INFINITY;
+                logits[vocab.token_to_id.at(" ")] = -INFINITY;
+            }
         }
 
-        params.language = whisper_lang_str(lang_id);
+        // suppress <|notimestamps|> token
+        // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
+        logits[vocab.token_not] = -INFINITY;
 
-        fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
+        // suppress sot and solm tokens
+        logits[vocab.token_sot]  = -INFINITY;
+        logits[vocab.token_solm] = -INFINITY;
+
+        // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
+        // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
+        {
+            const bool last_was_timestamp        = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
+            const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
+
+            //fprintf(stderr, "last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
+
+            if (last_was_timestamp) {
+                if (penultimate_was_timestamp) {
+                    for (int i = vocab.token_beg; i < n_logits; ++i) {
+                        logits[i] = -INFINITY;
+                    }
+                } else {
+                    for (int i = 0; i < vocab.token_eot; ++i) {
+                        logits[i] = -INFINITY;
+                    }
+                }
+            }
+        }
+
+        // the initial timestamp cannot be larger than max_initial_ts
+        // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
+        if (is_initial && params.max_initial_ts > 0.0f) {
+            const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx;
+            const int   tid0      = std::round(params.max_initial_ts/precision);
+
+            for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) {
+                logits[i] = -INFINITY;
+            }
+        }
+
+        // populate the logprobs array (log_softmax)
+        {
+            const float logit_max = *std::max_element(logits.begin(), logits.end());
+            float logsumexp = 0.0f;
+            for (int i = 0; i < n_logits; ++i) {
+                if (logits[i] > -INFINITY) {
+                    logsumexp += expf(logits[i] - logit_max);
+                }
+            }
+            logsumexp = logf(logsumexp) + logit_max;
+
+            for (int i = 0; i < n_logits; ++i) {
+                if (logits[i] > -INFINITY) {
+                    logprobs[i] = logits[i] - logsumexp;
+                } else {
+                    logprobs[i] = -INFINITY;
+                }
+            }
+        }
+
+        // if sum of probability over timestamps is above any other token, sample timestamp
+        // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
+        {
+            // logsumexp over timestamps
+            float timestamp_logprob = -INFINITY;
+            {
+                float logsumexp = 0.0f;
+                const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end());
+                for (int i = vocab.token_beg; i < n_logits; ++i) {
+                    if (logprobs[i] > -INFINITY) {
+                        logsumexp += expf(logprobs[i] - logprob_max);
+                    }
+                }
+                if (logsumexp > 0.0f) {
+                    timestamp_logprob = logf(logsumexp) + logprob_max;
+                }
+            }
+
+            const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
+
+            //fprintf(stderr, "timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
+
+            if (timestamp_logprob > max_text_token_logprob) {
+                for (int i = 0; i < vocab.token_beg; ++i) {
+                    logits[i]   = -INFINITY;
+                    logprobs[i] = -INFINITY;
+                }
+            }
+        }
+    }
+
+    // compute probs
+    {
+        for (int i = 0; i < n_logits; ++i) {
+            if (logits[i] == -INFINITY) {
+                probs[i] = 0.0f;
+            } else {
+                probs[i] = expf(logprobs[i]);
+            }
+        }
+    }
+
+#if 0
+    // print first 100 logits - token string : logit
+    for (int i = 0; i < 100; i++) {
+        const auto token   = vocab.id_to_token.at(i);
+        const auto prob    = probs[i];
+        const auto logit   = logits[i];
+        const auto logprob = logprobs[i];
+        printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
+    }
+
+    // "And", "and", " And", " and"
+    printf("logits[\"and\"]  = %f\n", logits[vocab.token_to_id.at("and")]);
+    printf("logits[\"And\"]  = %f\n", logits[vocab.token_to_id.at("And")]);
+    printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
+    printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
+    printf("logits[\" so\"]  = %f\n", logits[vocab.token_to_id.at(" so")]);
+
+    printf("logprobs[\"and\"]  = %f\n", logprobs[vocab.token_to_id.at("and")]);
+    printf("logprobs[\"And\"]  = %f\n", logprobs[vocab.token_to_id.at("And")]);
+    printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
+    printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
+    printf("logprobs[\" so\"]  = %f\n", logprobs[vocab.token_to_id.at(" so")]);
+
+    printf("probs[\"and\"]  = %f\n", probs[vocab.token_to_id.at("and")]);
+    printf("probs[\"And\"]  = %f\n", probs[vocab.token_to_id.at("And")]);
+    printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
+    printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
+    printf("probs[\" so\"]  = %f\n", probs[vocab.token_to_id.at(" so")]);
+#endif
+}
+
+static whisper_token_data whisper_sample_token(
+      const whisper_context & ctx,
+      const whisper_decoder & decoder,
+                       bool   best) {
+    whisper_token_data result = {
+        0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
+    };
+
+    const auto & vocab = ctx.vocab;
+
+    const auto & probs    = decoder.probs;
+    const auto & logprobs = decoder.logprobs;
+
+    const int n_logits = vocab.n_vocab;
+
+    {
+        double sum_ts = 0.0;
+        double max_ts = 0.0;
+
+        for (int i = vocab.token_beg; i < n_logits; i++) {
+            if (probs[i] == -INFINITY) {
+                continue;
+            }
+
+            sum_ts += probs[i];
+            if (max_ts < probs[i]) {
+                max_ts = probs[i];
+                result.tid = i;
+            }
+        }
+
+        result.pt    = max_ts/(sum_ts + 1e-10);
+        result.ptsum = sum_ts;
+    }
+
+    if (best) {
+        for (int i = 0; i < n_logits; ++i) {
+            if (result.p < probs[i]) {
+                result.id   = i;
+                result.p    = probs[i];
+                result.plog = logprobs[i];
+            }
+        }
+    } else {
+        std::discrete_distribution<> dist(probs.begin(), probs.end());
+
+        result.id   = dist(ctx.rng);
+        result.p    = probs[result.id];
+        result.plog = logprobs[result.id];
+    }
+
+    if (result.id >= vocab.token_beg) {
+        result.tid = result.id;
+        result.pt  = result.p;
+    }
+
+    return result;
+}
+
+static std::vector<whisper_token_data> whisper_sample_token_topk(
+            whisper_context & ctx,
+      const whisper_decoder & decoder,
+                        int   k) {
+    const auto & vocab = ctx.vocab;
+
+    const auto & probs    = decoder.probs;
+    const auto & logits   = decoder.logits;
+    const auto & logprobs = decoder.logprobs;
+
+    const int n_logits = vocab.n_vocab;
+
+    auto & logits_id = ctx.logits_id;
+
+    logits_id.clear();
+    for (int i = 0; i < n_logits; ++i) {
+        logits_id.push_back({ logits[i], i });
+    }
+
+    std::partial_sort(
+            logits_id.begin(),
+            logits_id.begin() + k, logits_id.end(),
+            [](const std::pair<double, whisper_token> & a, const std::pair<double, whisper_token> & b) {
+                return a.first > b.first;
+            });
+
+    std::vector<whisper_token_data> result;
+    result.reserve(k);
+
+    whisper_token tid;
+
+    float pt;
+    float ptsum;
+
+    {
+        double sum_ts = 0.0;
+        double max_ts = 0.0;
+
+        for (int i = vocab.token_beg; i < n_logits; i++) {
+            if (probs[i] == -INFINITY) {
+                continue;
+            }
+
+            sum_ts += probs[i];
+            if (max_ts < probs[i]) {
+                max_ts = probs[i];
+                tid = i;
+            }
+        }
+
+        pt    = max_ts/(sum_ts + 1e-10);
+        ptsum = sum_ts;
+    }
+
+    for (int i = 0; i < k; ++i) {
+        const auto id = logits_id[i].second;
+
+        result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
+
+        if (result[i].id >= vocab.token_beg) {
+            result[i].tid = result[i].id;
+            result[i].pt  = result[i].p;
+        }
+    }
+
+    return result;
+}
+
+// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192
+static void whisper_sequence_score(
+        const struct whisper_full_params & params,
+                        whisper_sequence & sequence) {
+    if (sequence.result_len == 0) {
+        return;
+    }
+
+    double result = 0.0f;
+
+    for (int i = 0; i < sequence.result_len; ++i) {
+        result += sequence.tokens[i].plog;
+    }
+
+    sequence.sum_logprobs = result;
+    sequence.avg_logprobs = result/sequence.result_len;
+
+    double penalty = sequence.result_len;
+
+    if (params.length_penalty > 0.0f) {
+        penalty = pow((5.0 + penalty)/6.0, params.length_penalty);
+    }
+
+    sequence.score = result/penalty;
+
+    // compute the entropy of the sequence of the last 32 tokens
+    {
+        const int n = 32;
+
+        int cnt = 0;
+        double entropy = 0.0f;
+
+        std::map<whisper_token, int> token_counts;
+        for (int i = std::max(0, sequence.result_len - n); i < sequence.result_len; ++i) {
+            token_counts[sequence.tokens[i].id]++;
+            cnt++;
+        }
+
+        for (const auto & kv : token_counts) {
+            const auto p = kv.second/(double)cnt;
+            entropy -= p*log(p);
+
+            //WHISPER_PRINT_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
+        }
+
+        sequence.entropy = entropy;
+    }
+}
+
+int whisper_full(
+        struct whisper_context * ctx,
+        struct whisper_full_params params,
+        const float * samples,
+        int n_samples) {
+    // clear old results
+    auto & result_all = ctx->result_all;
+
+    result_all.clear();
+
+    // compute log mel spectrogram
+    if (params.speed_up) {
+        if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
+            fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
+            return -1;
+        }
+    } else {
+        if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
+            fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
+            return -2;
+        }
+    }
+
+    // auto-detect language if not specified
+    if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
+        std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
+
+        const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
+        if (lang_id < 0) {
+            fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
+            return -3;
+        }
+
+        params.language = whisper_lang_str(lang_id);
+
+        fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
     }
 
     if (params.token_timestamps) {
-        ctx->t_beg = 0;
-        ctx->t_last = 0;
+        ctx->t_beg    = 0;
+        ctx->t_last   = 0;
         ctx->tid_last = 0;
         ctx->energy = get_signal_energy(samples, n_samples, 32);
     }
@@ -2877,6 +3234,54 @@ int whisper_full(
         return 0;
     }
 
+    // a set of temperatures to use
+    // [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ]
+    std::vector<float> temperatures;
+    if (params.temperature_inc > 0.0f) {
+        for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) {
+            temperatures.push_back(t);
+        }
+    } else {
+        temperatures.push_back(params.temperature);
+    }
+
+    // initialize the decoders
+    int n_decoders = 1;
+
+    switch (params.strategy) {
+        case WHISPER_SAMPLING_GREEDY:
+            {
+                n_decoders = params.greedy.best_of;
+            } break;
+        case WHISPER_SAMPLING_BEAM_SEARCH:
+            {
+                n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size);
+            } break;
+    };
+
+    n_decoders = std::max(1, n_decoders);
+
+    // TAGS: WHISPER_DECODER_INIT
+    for (int j = 1; j < n_decoders; j++) {
+        auto & decoder = ctx->decoders[j];
+
+        if (decoder.kv_self.ctx == nullptr) {
+            decoder.kv_self = ctx->decoders[0].kv_self;
+            if (!kv_cache_reinit(decoder.kv_self)) {
+                fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
+                return -4;
+            }
+
+            WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
+
+            decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
+
+            decoder.probs.resize   (ctx->vocab.n_vocab);
+            decoder.logits.resize  (ctx->vocab.n_vocab);
+            decoder.logprobs.resize(ctx->vocab.n_vocab);
+        }
+    }
+
     // the accumulated text context so far
     auto & prompt_past = ctx->prompt_past;
     if (params.no_context) {
@@ -2895,7 +3300,7 @@ int whisper_full(
     // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
     if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
         fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
-        return -4;
+        return -5;
     }
     ctx->exp_n_audio_ctx = params.audio_ctx;
 
@@ -2914,14 +3319,31 @@ int whisper_full(
     int progress_prev = 0;
     int progress_step = 5;
 
-    std::vector<whisper_token_data> tokens_cur;
-    tokens_cur.reserve(whisper_n_text_ctx(ctx));
+    int seek = seek_start;
 
     std::vector<whisper_token> prompt;
     prompt.reserve(whisper_n_text_ctx(ctx));
 
+    // beam-search helpers
+    struct kv_buf {
+        std::vector<uint8_t> k;
+        std::vector<uint8_t> v;
+    };
+
+    std::vector<kv_buf> kv_bufs;
+
+    struct beam_candidate {
+        int decoder_idx;
+        int seek_delta;
+
+        bool has_ts;
+
+        whisper_sequence sequence;
+    };
+
+    std::vector<beam_candidate> beam_candidates;
+
     // main loop
-    int seek = seek_start;
     while (true) {
         const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
         while (progress_cur >= progress_prev + progress_step) {
@@ -2936,12 +3358,6 @@ int whisper_full(
             break;
         }
 
-        // if there is a very short audio segment left to process, we remove any past prompt since it tends
-        // to confuse the decoder and often make it repeat or hallucinate stuff
-        if (seek > seek_start && seek + 500 >= seek_end) {
-            prompt_past.clear();
-        }
-
         if (params.encoder_begin_callback) {
             if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
                 fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
@@ -2950,239 +3366,526 @@ int whisper_full(
         }
 
         // encode audio features starting at offset seek
-        if (whisper_encode(ctx, seek, params.n_threads) != 0) {
+        if (!whisper_encode(*ctx, seek, params.n_threads)) {
             fprintf(stderr, "%s: failed to encode\n", __func__);
-            return -4;
+            return -6;
+        }
+
+        // if there is a very short audio segment left to process, we remove any past prompt since it tends
+        // to confuse the decoder and often make it repeat or hallucinate stuff
+        if (seek > seek_start && seek + 500 >= seek_end) {
+            prompt_past.clear();
         }
 
-        int n_past = 0;
-        prompt.clear();
+        int best_decoder_id = 0;
 
-        // if we have already generated some text, use it as a prompt to condition the next generation
-        if (!prompt_past.empty()) {
-            int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
+        for (int it = 0; it < (int) temperatures.size(); ++it) {
+            const float t_cur = temperatures[it];
 
-            prompt = { whisper_token_prev(ctx) };
-            prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
+            int n_decoders_cur = 1;
 
-            prompt_past.clear();
-            prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
-        }
+            switch (params.strategy) {
+                case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
+                    {
+                        if (t_cur > 0.0f) {
+                            n_decoders_cur = params.greedy.best_of;
+                        }
+                    } break;
+                case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
+                    {
+                        if (t_cur > 0.0f) {
+                            n_decoders_cur = params.greedy.best_of;
+                        } else {
+                            n_decoders_cur = params.beam_search.beam_size;
+                        }
+                    } break;
+            };
 
-        prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
+            n_decoders_cur = std::max(1, n_decoders_cur);
 
-        int seek_delta = 100*WHISPER_CHUNK_SIZE;
+            WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
 
-        // print the prompt
-        //printf("\n\n");
-        //for (int i = 0; i < prompt.size(); i++) {
-        //    printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str());
-        //}
-        //printf("\n\n");
+            // TAGS: WHISPER_DECODER_INIT
+            for (int j = 0; j < n_decoders_cur; ++j) {
+                auto & decoder = ctx->decoders[j];
 
-        // the accumulated transcription in the current interation
-        int result_len = 0;
-        tokens_cur.clear();
+                decoder.kv_self.n = 0;
 
-        bool failed = false;
-        bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
+                decoder.sequence.tokens.clear();
+                decoder.sequence.result_len       = 0;
+                decoder.sequence.sum_logprobs_all = 0.0;
+                decoder.sequence.sum_logprobs     = -INFINITY;
+                decoder.sequence.avg_logprobs     = -INFINITY;
+                decoder.sequence.entropy          = 0.0;
+                decoder.sequence.score            = -INFINITY;
 
-        for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
-            if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
-                fprintf(stderr, "%s: failed to decode\n", __func__);
-                return -5;
-            }
+                decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;
 
-            n_past += prompt.size();
-            prompt.clear();
+                decoder.failed    = false;
+                decoder.completed = false;
+                decoder.has_ts    = false;
+            }
 
-            // very basic greedy sampling strategy:
-            //
-            //   - always take the most probable token
-            //
-            // more sophisticated sampling strategies could be implemented here, but we keep it simple
-            // feel free to experiment!
-            //
+            // init prompt and kv cache for the current iteration
+            // run whisper_decoder() only for decoder 0 and copy the results for the other decoders
             {
-                const auto token = (i == 0) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
+                prompt.clear();
 
-                // timestamp token - update sliding window
-                if (token.id > whisper_token_beg(ctx)) {
-                    const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
+                // if we have already generated some text, use it as a prompt to condition the next generation
+                if (!prompt_past.empty() && t_cur > 0.5f) {
+                    int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
 
-                    // do not allow to go back in time
-                    if (has_ts && seek_delta > seek_delta_new && result_len < i) {
-                        break;
+                    prompt = { whisper_token_prev(ctx) };
+                    prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
+                }
+
+                // init new transcription with sot, language (opt) and task tokens
+                prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
+
+                // print the prompt
+                //WHISPER_PRINT_DEBUG("\n\n");
+                //for (int i = 0; i < (int) prompt.size(); i++) {
+                //    WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
+                //}
+                //WHISPER_PRINT_DEBUG("\n\n");
+
+                if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
+                    fprintf(stderr, "%s: failed to decode\n", __func__);
+                    return -7;
+                }
+
+                {
+                    const int64_t t_start_sample_us = ggml_time_us();
+
+                    whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur);
+
+                    ctx->decoders[0].kv_self.n += prompt.size();
+
+                    for (int j = 1; j < n_decoders_cur; ++j) {
+                        auto & decoder = ctx->decoders[j];
+
+                        memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
+                        memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
+
+                        decoder.kv_self.n += prompt.size();
+
+                        memcpy(decoder.probs.data(),    ctx->decoders[0].probs.data(),    decoder.probs.size()*sizeof(decoder.probs[0]));
+                        memcpy(decoder.logits.data(),   ctx->decoders[0].logits.data(),   decoder.logits.size()*sizeof(decoder.logits[0]));
+                        memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
                     }
 
-                    seek_delta = seek_delta_new;
-                    result_len = i + 1;
-                    has_ts = true;
+                    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
                 }
+            }
 
-                // add it to the context
-                prompt.push_back(token.id);
-                tokens_cur.push_back(token);
+            for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
+                const int64_t t_start_sample_us = ggml_time_us();
 
-                //{
-                //    const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
-                //    printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
-                //}
+                // store the KV caches of all decoders when doing beam-search
+                if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
+                    kv_bufs.resize(n_decoders_cur);
+                    for (int j = 0; j < n_decoders_cur; ++j) {
+                        auto & decoder = ctx->decoders[j];
 
-                // end of segment
-                if (token.id == whisper_token_eot(ctx) ||                // end of text token
-                    (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
-                    (has_ts && seek + seek_delta + 100 >= seek_end)      // end of audio reached
-                    ) {
-                    if (result_len == 0) {
-                        if (seek + seek_delta + 100 >= seek_end) {
-                            result_len = i + 1;
-                        } else {
-                            failed = true;
-                            break;
+                        if (decoder.completed || decoder.failed) {
+                            continue;
                         }
-                    }
 
-                    if (params.single_segment) {
-                        result_len = i + 1;
-                        seek_delta = 100*WHISPER_CHUNK_SIZE;
+                        kv_bufs[j].k.resize(ggml_nbytes(decoder.kv_self.k));
+                        kv_bufs[j].v.resize(ggml_nbytes(decoder.kv_self.v));
+
+                        memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size());
+                        memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size());
                     }
 
-                    break;
+                    beam_candidates.clear();
                 }
 
-                // TESTS: if no tensors are loaded, it means we are running tests
-                if (ctx->model.n_loaded == 0) {
-                    seek_delta = 100*WHISPER_CHUNK_SIZE;
-                    break;
+                // generate new sequence candidates for each decoder
+                for (int j = 0; j < n_decoders_cur; ++j) {
+                    auto & decoder = ctx->decoders[j];
+
+                    if (decoder.completed || decoder.failed) {
+                        continue;
+                    }
+
+                    switch (params.strategy) {
+                        case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
+                            {
+                                if (t_cur < 1e-6f) {
+                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
+                                } else {
+                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
+                                }
+
+                                decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
+                            } break;
+                        case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
+                            {
+                                const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
+
+                                for (const auto & token : tokens_new) {
+                                    beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
+                                    beam_candidates.back().sequence.tokens.push_back(token);
+                                    beam_candidates.back().sequence.sum_logprobs_all += token.plog;
+
+                                    //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all);
+                                }
+                            } break;
+                    };
                 }
-            }
 
-            // sometimes, the decoding can get stuck in a repetition loop
-            // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance
-            // the sliding window by 1 second
-            if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
-                failed = true;
-                break;
-            }
-        }
+                // for beam-search, choose the top candidates and update the KV caches
+                if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
+                    std::sort(
+                            beam_candidates.begin(),
+                            beam_candidates.end(),
+                            [](const beam_candidate & a, const beam_candidate & b) {
+                        return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
+                    });
 
-        if (failed) {
-            // when we fail to sample timestamp token, retry by clearing the past prompt
-            // if it fails again, then we advance the window by 1 second
-            if (!prompt_past.empty()) {
-                prompt_past.clear();
-            } else {
-                fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__);
-                seek += 100;
-            }
-            continue;
-        }
+                    int cur_c = 0;
 
-        // shrink down to result_len
-        tokens_cur.resize(result_len);
+                    for (int j = 0; j < n_decoders_cur; ++j) {
+                        auto & decoder = ctx->decoders[j];
 
-        for (const auto & r : tokens_cur) {
-            prompt_past.push_back(r.id);
-        }
+                        if (decoder.completed || decoder.failed) {
+                            continue;
+                        }
 
-        // store the text from this iteration
-        if (!tokens_cur.empty()) {
-            int  i0 = 0;
-            auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
+                        auto & cur = beam_candidates[cur_c++];
 
-            std::string text;
+                        while (beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
+                            ++cur_c;
+                        }
 
-            for (int i = 0; i < (int) tokens_cur.size(); i++) {
-                //printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
-                //        ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
-                //        ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
+                        decoder.sequence   = cur.sequence;
+                        decoder.seek_delta = cur.seek_delta;
+                        decoder.has_ts     = cur.has_ts;
 
-                if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
-                } else {
-                    text += whisper_token_to_str(ctx, tokens_cur[i].id);
+                        memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size());
+                        memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size());
+
+                        WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
+                                __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
+                    }
                 }
-                if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
-                    const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
-                    if (!text.empty()) {
-                        const auto tt0 = params.speed_up ? 2*t0 : t0;
-                        const auto tt1 = params.speed_up ? 2*t1 : t1;
-
-                        if (params.print_realtime) {
-                            if (params.print_timestamps) {
-                                printf("[%s --> %s]  %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
-                            } else {
-                                printf("%s", text.c_str());
-                                fflush(stdout);
+
+                // update the decoder state
+                // - check if the sequence is completed
+                // - check if the sequence is failed
+                // - update sliding window based on timestamp tokens
+                for (int j = 0; j < n_decoders_cur; ++j) {
+                    auto & decoder = ctx->decoders[j];
+
+                    if (decoder.completed || decoder.failed) {
+                        continue;
+                    }
+
+                    auto & has_ts     = decoder.has_ts;
+                    auto & failed     = decoder.failed;
+                    auto & completed  = decoder.completed;
+                    auto & seek_delta = decoder.seek_delta;
+                    auto & result_len = decoder.sequence.result_len;
+
+                    {
+                        const auto & token = decoder.sequence.tokens.back();
+
+                        // timestamp token - update sliding window
+                        if (token.id > whisper_token_beg(ctx)) {
+                            const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
+
+                            // do not allow to go back in time
+                            if (has_ts && seek_delta > seek_delta_new && result_len < i) {
+                                failed = true; // TODO: maybe this is not a failure ?
+                                continue;
                             }
-                        }
 
-                        result_all.push_back({ tt0, tt1, text, {} });
-                        for (int j = i0; j <= i; j++) {
-                            result_all.back().tokens.push_back(tokens_cur[j]);
+                            seek_delta = seek_delta_new;
+                            result_len = i + 1;
+                            has_ts = true;
                         }
 
-                        int n_new = 1;
+#ifdef WHISPER_DEBUG
+                        {
+                            const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
+                            WHISPER_PRINT_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
+                                    __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str());
+                        }
+#endif
 
-                        if (params.token_timestamps) {
-                            whisper_exp_compute_token_level_timestamps(
-                                    ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+                        // end of segment
+                        if (token.id == whisper_token_eot(ctx) ||               // end of text token
+                           (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
+                           (has_ts && seek + seek_delta + 100 >= seek_end)      // end of audio reached
+                           ) {
+                            if (result_len == 0) {
+                                if (seek + seek_delta + 100 >= seek_end) {
+                                    result_len = i + 1;
+                                } else {
+                                    failed = true;
+                                    continue;
+                                }
+                            }
 
-                            if (params.max_len > 0) {
-                                n_new = whisper_wrap_segment(ctx, params.max_len);
+                            if (params.single_segment) {
+                                result_len = i + 1;
+                                seek_delta = 100*WHISPER_CHUNK_SIZE;
                             }
+
+                            completed = true;
+                            continue;
                         }
-                        if (params.new_segment_callback) {
-                            params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
+
+                        // TESTS: if no tensors are loaded, it means we are running tests
+                        if (ctx->model.n_loaded == 0) {
+                            seek_delta = 100*WHISPER_CHUNK_SIZE;
+                            completed = true;
+                            continue;
                         }
                     }
-                    text = "";
-                    while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
-                        i++;
+
+                    // sometimes, the decoding can get stuck in a repetition loop
+                    // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
+                    if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
+                        failed = true;
+                        continue;
                     }
-                    i--;
-                    t0 = t1;
-                    i0 = i + 1;
                 }
-            }
 
-            if (!text.empty()) {
-                const auto t1 = seek + seek_delta;
+                // check if all decoders have finished (i.e. completed or failed)
+                {
+                    bool completed_all = true;
 
-                const auto tt0 = params.speed_up ? 2*t0 : t0;
-                const auto tt1 = params.speed_up ? 2*t1 : t1;
+                    for (int j = 0; j < n_decoders_cur; ++j) {
+                        auto & decoder = ctx->decoders[j];
 
-                if (params.print_realtime) {
-                    if (params.print_timestamps) {
-                        printf("[%s --> %s]  %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
-                    } else {
-                        printf("%s", text.c_str());
-                        fflush(stdout);
+                        if (decoder.completed || decoder.failed) {
+                            continue;
+                        }
+
+                        completed_all = false;
+                    }
+
+                    if (completed_all) {
+                        break;
                     }
                 }
 
-                result_all.push_back({ tt0, tt1, text, {} });
-                for (int j = i0; j < (int) tokens_cur.size(); j++) {
-                    result_all.back().tokens.push_back(tokens_cur[j]);
+                ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+
+                // obtain logits for the next token
+                for (int j = 0; j < n_decoders_cur; ++j) {
+                    auto & decoder = ctx->decoders[j];
+
+                    if (decoder.failed || decoder.completed) {
+                        continue;
+                    }
+
+                    decoder.tokens_tmp.resize(1);
+                    decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
+
+                    //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
+
+                    if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
+                        fprintf(stderr, "%s: failed to decode\n", __func__);
+                        return -8;
+                    }
+
+                    {
+                        const int64_t t_start_sample_us = ggml_time_us();
+
+                        whisper_process_logits(*ctx, params, decoder, t_cur);
+
+                        ++decoder.kv_self.n;
+
+                        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+                    }
                 }
+            }
+
+            // rank the resulting sequences and select the best one
+            {
+                double best_score = -INFINITY;
+
+                for (int j = 0; j < n_decoders_cur; ++j) {
+                    auto & decoder = ctx->decoders[j];
+
+                    if (decoder.failed) {
+                        continue;
+                    }
+
+                    decoder.sequence.tokens.resize(decoder.sequence.result_len);
+                    whisper_sequence_score(params, decoder.sequence);
 
-                int n_new = 1;
+                    WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
+                            __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
 
-                if (params.token_timestamps) {
-                    whisper_exp_compute_token_level_timestamps(
-                            ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+                    if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_thold) {
+                        WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
+                                __func__, j, decoder.sequence.entropy, params.entropy_thold);
 
-                    if (params.max_len > 0) {
-                        n_new = whisper_wrap_segment(ctx, params.max_len);
+                        decoder.failed = true;
+
+                        continue;
+                    }
+
+                    if (best_score < decoder.sequence.score) {
+                        best_score = decoder.sequence.score;
+                        best_decoder_id = j;
                     }
                 }
-                if (params.new_segment_callback) {
-                    params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
+
+                WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
+            }
+
+            // was the decoding successful for the current temperature?
+            {
+                bool success = true;
+
+                const auto & decoder = ctx->decoders[best_decoder_id];
+
+                if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
+                    success = false;
+                }
+
+                if (success) {
+                    //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
+                    //    WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
+                    //}
+
+                    break;
                 }
             }
+
+            WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
         }
 
-        seek += seek_delta;
+        // output results through a user-provided callback
+        {
+            const auto & best_decoder = ctx->decoders[best_decoder_id];
+
+            const auto seek_delta = best_decoder.seek_delta;
+            const auto result_len = best_decoder.sequence.result_len;
+
+            const auto & tokens_cur = best_decoder.sequence.tokens;
+
+            //WHISPER_PRINT_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
+
+            // update prompt_past
+            prompt_past.clear();
+            if (prompt.front() == whisper_token_prev(ctx)) {
+                prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
+            }
+
+            for (int i = 0; i < result_len; ++i) {
+                prompt_past.push_back(tokens_cur[i].id);
+            }
+
+            // store the text from this iteration
+            if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
+                int  i0 = 0;
+                auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
+
+                std::string text;
+
+                for (int i = 0; i < (int) tokens_cur.size(); i++) {
+                    //printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
+                    //        ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
+                    //        ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
+
+                    if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
+                    } else {
+                        text += whisper_token_to_str(ctx, tokens_cur[i].id);
+                    }
+
+                    if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
+                        const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
+                        if (!text.empty()) {
+                            const auto tt0 = params.speed_up ? 2*t0 : t0;
+                            const auto tt1 = params.speed_up ? 2*t1 : t1;
+
+                            if (params.print_realtime) {
+                                if (params.print_timestamps) {
+                                    printf("[%s --> %s]  %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
+                                } else {
+                                    printf("%s", text.c_str());
+                                    fflush(stdout);
+                                }
+                            }
+
+                            //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
+
+                            result_all.push_back({ tt0, tt1, text, {} });
+                            for (int j = i0; j <= i; j++) {
+                                result_all.back().tokens.push_back(tokens_cur[j]);
+                            }
+
+                            int n_new = 1;
+
+                            if (params.token_timestamps) {
+                                whisper_exp_compute_token_level_timestamps(
+                                        *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+
+                                if (params.max_len > 0) {
+                                    n_new = whisper_wrap_segment(*ctx, params.max_len);
+                                }
+                            }
+                            if (params.new_segment_callback) {
+                                params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
+                            }
+                        }
+                        text = "";
+                        while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
+                            i++;
+                        }
+                        i--;
+                        t0 = t1;
+                        i0 = i + 1;
+                    }
+                }
+
+                if (!text.empty()) {
+                    const auto t1 = seek + seek_delta;
+
+                    const auto tt0 = params.speed_up ? 2*t0 : t0;
+                    const auto tt1 = params.speed_up ? 2*t1 : t1;
+
+                    if (params.print_realtime) {
+                        if (params.print_timestamps) {
+                            printf("[%s --> %s]  %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
+                        } else {
+                            printf("%s", text.c_str());
+                            fflush(stdout);
+                        }
+                    }
+
+                    result_all.push_back({ tt0, tt1, text, {} });
+                    for (int j = i0; j < (int) tokens_cur.size(); j++) {
+                        result_all.back().tokens.push_back(tokens_cur[j]);
+                    }
+
+                    int n_new = 1;
+
+                    if (params.token_timestamps) {
+                        whisper_exp_compute_token_level_timestamps(
+                                *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+
+                        if (params.max_len > 0) {
+                            n_new = whisper_wrap_segment(*ctx, params.max_len);
+                        }
+                    }
+                    if (params.new_segment_callback) {
+                        params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
+                    }
+                }
+            }
+
+            // update audio window
+            seek += seek_delta;
+
+            WHISPER_PRINT_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta);
+        }
     }
 
     return 0;
@@ -3204,52 +3907,31 @@ int whisper_full_parallel(
     std::vector<struct whisper_context> ctxs(n_processors - 1);
 
     for (int i = 0; i < n_processors - 1; ++i) {
-        ctxs[i] = *ctx;
-
-        auto & model = ctxs[i].model;
-
-        // create the ggml memory context
-        {
-            struct ggml_init_params params;
-            params.mem_size   = ctxs[i].buf_memory.size();
-            params.mem_buffer = ctxs[i].buf_memory.data();
+        auto & ctx_p = ctxs[i];
 
-            model.ctx_mem = ggml_init(params);
-            if (!model.ctx_mem) {
-                fprintf(stderr, "%s: ggml_init() failed\n", __func__);
-                return false;
-            }
-        }
+        ctx_p = *ctx;
 
-        // separate key + value memory for each processor
-        {
-            auto & mctx = model.ctx_mem;
-
-            const auto & hparams = model.hparams;
+        ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
 
-            const int n_text_state = hparams.n_text_state;
-            const int n_text_layer = hparams.n_text_layer;
-            const int n_text_ctx   = hparams.n_text_ctx;
+        ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab);
 
-            // key/value memory for the self-attention layer
-            {
-                const int n_mem      = n_text_layer*n_text_ctx;
-                const int n_elements = n_text_state*n_mem;
+        if (!kv_cache_reinit(ctx_p.kv_cross)) {
+            fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i);
+            return false;
+        }
 
-                model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
-                model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
+        // TAGS: WHISPER_DECODER_INIT
+        for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
+            if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
+                fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i);
+                return false;
             }
 
-            // key/value memory for the cross-attention layer
-            {
-                const int n_audio_ctx = hparams.n_audio_ctx;
-
-                const int n_mem      = n_text_layer*n_audio_ctx;
-                const int n_elements = n_text_state*n_mem;
+            ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
 
-                model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
-                model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
-            }
+            ctx_p.decoders[j].probs.reserve   (ctx_p.vocab.n_vocab);
+            ctx_p.decoders[j].logits.reserve  (ctx_p.vocab.n_vocab);
+            ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab);
         }
     }
 
@@ -3314,6 +3996,12 @@ int whisper_full_parallel(
         ctx->t_sample_us += ctxs[i].t_sample_us;
         ctx->t_encode_us += ctxs[i].t_encode_us;
         ctx->t_decode_us += ctxs[i].t_decode_us;
+
+        kv_cache_free(ctx->kv_cross);
+
+        for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
+            kv_cache_free(ctx->decoders[j].kv_self);
+        }
     }
 
     // average the timings
@@ -3438,14 +4126,14 @@ static std::vector<float> get_signal_energy(const float * signal, int n_samples,
 }
 
 static void whisper_exp_compute_token_level_timestamps(
-        struct whisper_context * ctx,
-        int   i_segment,
-        float thold_pt,
-        float thold_ptsum) {
-    auto & segment = ctx->result_all[i_segment];
+        struct whisper_context & ctx,
+                           int   i_segment,
+                         float   thold_pt,
+                         float   thold_ptsum) {
+    auto & segment = ctx.result_all[i_segment];
     auto & tokens  = segment.tokens;
 
-    const int n_samples = ctx->energy.size();
+    const int n_samples = ctx.energy.size();
 
     if (n_samples == 0) {
         fprintf(stderr, "%s: no signal data available\n", __func__);
@@ -3468,28 +4156,28 @@ static void whisper_exp_compute_token_level_timestamps(
         return;
     }
 
-    auto & t_beg    = ctx->t_beg;
-    auto & t_last   = ctx->t_last;
-    auto & tid_last = ctx->tid_last;
+    auto & t_beg    = ctx.t_beg;
+    auto & t_last   = ctx.t_last;
+    auto & tid_last = ctx.tid_last;
 
     for (int j = 0; j < n; ++j) {
         auto & token = tokens[j];
 
         if (j == 0) {
-            if (token.id == whisper_token_beg(ctx)) {
+            if (token.id == whisper_token_beg(&ctx)) {
                 tokens[j    ].t0 = t0;
                 tokens[j    ].t1 = t0;
                 tokens[j + 1].t0 = t0;
 
                 t_beg    = t0;
                 t_last   = t0;
-                tid_last = whisper_token_beg(ctx);
+                tid_last = whisper_token_beg(&ctx);
             } else {
                 tokens[j    ].t0 = t_last;
             }
         }
 
-        const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
+        const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx));
 
         tokens[j].id    = token.id;
         tokens[j].tid   = token.tid;
@@ -3497,7 +4185,7 @@ static void whisper_exp_compute_token_level_timestamps(
         tokens[j].pt    = token.pt;
         tokens[j].ptsum = token.ptsum;
 
-        tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id));
+        tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id));
 
         if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
             if (j > 0) {
@@ -3529,6 +4217,8 @@ static void whisper_exp_compute_token_level_timestamps(
                 p1--;
             }
 
+            //printf("p0=%d p1=%d t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1);
+
             if (p1 > p0) {
                 double psum = 0.0;
                 for (int j = p0; j <= p1; j++) {
@@ -3576,7 +4266,7 @@ static void whisper_exp_compute_token_level_timestamps(
         const int hw = WHISPER_SAMPLE_RATE/8;
 
         for (int j = 0; j < n; j++) {
-            if (tokens[j].id >= whisper_token_eot(ctx)) {
+            if (tokens[j].id >= whisper_token_eot(&ctx)) {
                 continue;
             }
 
@@ -3591,15 +4281,15 @@ static void whisper_exp_compute_token_level_timestamps(
             float sum = 0.0f;
 
             for (int k = ss0; k < ss1; k++) {
-                sum += ctx->energy[k];
+                sum += ctx.energy[k];
             }
 
             const float thold = 0.5*sum/ns;
 
             {
                 int k = s0;
-                if (ctx->energy[k] > thold && j > 0) {
-                    while (k > 0 && ctx->energy[k] > thold) {
+                if (ctx.energy[k] > thold && j > 0) {
+                    while (k > 0 && ctx.energy[k] > thold) {
                         k--;
                     }
                     tokens[j].t0 = sample_to_timestamp(k);
@@ -3609,7 +4299,7 @@ static void whisper_exp_compute_token_level_timestamps(
                         s0 = k;
                     }
                 } else {
-                    while (ctx->energy[k] < thold && k < s1) {
+                    while (ctx.energy[k] < thold && k < s1) {
                         k++;
                     }
                     s0 = k;
@@ -3619,8 +4309,8 @@ static void whisper_exp_compute_token_level_timestamps(
 
             {
                 int k = s1;
-                if (ctx->energy[k] > thold) {
-                    while (k < n_samples - 1 && ctx->energy[k] > thold) {
+                if (ctx.energy[k] > thold) {
+                    while (k < n_samples - 1 && ctx.energy[k] > thold) {
                         k++;
                     }
                     tokens[j].t1 = sample_to_timestamp(k);
@@ -3630,7 +4320,7 @@ static void whisper_exp_compute_token_level_timestamps(
                         s1 = k;
                     }
                 } else {
-                    while (ctx->energy[k] < thold && k > s0) {
+                    while (ctx.energy[k] < thold && k > s0) {
                         k--;
                     }
                     s1 = k;
@@ -3657,11 +4347,11 @@ static void whisper_exp_compute_token_level_timestamps(
     // debug info
     //for (int j = 0; j < n; ++j) {
     //    const auto & token = tokens[j];
-    //    const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
+    //    const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(&ctx, token.tid) : "[?]";
     //    printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
-    //            tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id));
+    //            tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(&ctx, token.id));
 
-    //    if (tokens[j].id >= whisper_token_eot(ctx)) {
+    //    if (tokens[j].id >= whisper_token_eot(&ctx)) {
     //        continue;
     //    }
     //}
index 63f61af5114f2cb8864219b9b9615bea64126c2d..84504b7b23f9d9e70a26397ba23469a5e4382070 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -74,6 +74,7 @@ extern "C" {
         whisper_token tid; // forced timestamp token id
 
         float p;           // probability of the token
+        float plog;        // log probability of the token
         float pt;          // probability of the timestamp token
         float ptsum;       // sum of probabilities of all timestamp tokens
 
@@ -136,6 +137,7 @@ extern "C" {
     // tokens + n_tokens is the provided context for the decoder.
     // n_past is the number of tokens to use from previous decoder calls.
     // Returns 0 on success
+    // TODO: add support for multiple decoders
     WHISPER_API int whisper_decode(
             struct whisper_context * ctx,
                const whisper_token * tokens,
@@ -143,14 +145,6 @@ extern "C" {
                                int   n_past,
                                int   n_threads);
 
-    // Token sampling methods.
-    // These are provided for convenience and can be used after each call to whisper_decode().
-    // You can also implement your own sampling method using the whisper_get_probs() function.
-    // whisper_sample_best() returns the token with the highest probability
-    // whisper_sample_timestamp() returns the most probable timestamp token
-    WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
-    WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
-
     // Convert the provided text into tokens.
     // The tokens pointer must be large enough to hold the resulting tokens.
     // Returns the number of tokens on success, no more than n_max_tokens
@@ -192,8 +186,11 @@ extern "C" {
     WHISPER_API int whisper_n_audio_ctx    (struct whisper_context * ctx);
     WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
 
-    // The probabilities for the next token
-    WHISPER_API float * whisper_get_probs(struct whisper_context * ctx);
+    // Token logits obtained from the last call to whisper_decode()
+    // The logits for the last token are stored in the last row
+    // Rows: n_tokens
+    // Cols: n_vocab
+    WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
 
     // Token Id -> String. Uses the vocabulary in the provided context
     WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
@@ -222,8 +219,8 @@ extern "C" {
 
     // Available sampling strategies
     enum whisper_sampling_strategy {
-        WHISPER_SAMPLING_GREEDY,      // Always select the most probable token
-        WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
+        WHISPER_SAMPLING_GREEDY,      // similar to OpenAI's GreefyDecoder
+        WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder
     };
 
     // Text segment callback
@@ -243,17 +240,17 @@ extern "C" {
         enum whisper_sampling_strategy strategy;
 
         int n_threads;
-        int n_max_text_ctx;
+        int n_max_text_ctx;     // max tokens to use from past text as prompt for the decoder
         int offset_ms;          // start offset in ms
         int duration_ms;        // audio duration to process in ms
 
         bool translate;
-        bool no_context;
+        bool no_context;        // do not use initial prompt for the decoder (if any)
         bool single_segment;    // force single segment output (useful for streaming)
-        bool print_special;
-        bool print_progress;
-        bool print_realtime;
-        bool print_timestamps;
+        bool print_special;     // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
+        bool print_progress;    // print progress information
+        bool print_realtime;    // print results from within whisper.cpp (avoid it, use callback instead)
+        bool print_timestamps;  // print timestamps for each text segment when printing realtime
 
         // [EXPERIMENTAL] token-level timestamps
         bool  token_timestamps; // enable token-level timestamps
@@ -263,10 +260,11 @@ extern "C" {
         int   max_tokens;       // max tokens per segment (0 = no limit)
 
         // [EXPERIMENTAL] speed-up techniques
+        // note: these can significantly reduce the quality of the output
         bool speed_up;          // speed-up the audio by 2x using Phase Vocoder
         int  audio_ctx;         // overwrite the audio context size (0 = use default)
 
-        // tokens to provide the whisper model as initial prompt
+        // tokens to provide to the whisper decoder as initial prompt
         // these are prepended to any existing text context from a previous call
         const whisper_token * prompt_tokens;
         int prompt_n_tokens;
@@ -274,19 +272,35 @@ extern "C" {
         // for auto-detection, set to nullptr, "" or "auto"
         const char * language;
 
+        // common decoding parameters:
+        bool suppress_blank;    // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
+
+        float temperature;      // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
+        float max_initial_ts;   // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
+        float length_penalty;   // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267
+
+        // fallback parameters
+        // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278
+        float temperature_inc;
+        float entropy_thold;    // similar to OpenAI's "compression_ratio_threshold"
+        float logprob_thold;
+        float no_speech_thold;  // TODO: not implemented
+
         struct {
-            int n_past;
+            int best_of;    // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
         } greedy;
 
         struct {
-            int n_past;
-            int beam_width;
-            int n_best;
+            int beam_size;  // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265
+
+            float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf
         } beam_search;
 
+        // called for every newly generated text segment
         whisper_new_segment_callback new_segment_callback;
         void * new_segment_callback_user_data;
 
+        // called each time before the encoder starts
         whisper_encoder_begin_callback encoder_begin_callback;
         void * encoder_begin_callback_user_data;
     };