]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sync : latest whisper.cpp (scratch buffers in ggml)
authorGeorgi Gerganov <redacted>
Wed, 15 Feb 2023 18:59:36 +0000 (20:59 +0200)
committerGeorgi Gerganov <redacted>
Wed, 15 Feb 2023 18:59:48 +0000 (20:59 +0200)
examples/whisper/CMakeLists.txt
examples/whisper/common.cpp [new file with mode: 0644]
examples/whisper/common.h [new file with mode: 0644]
examples/whisper/main.cpp
examples/whisper/whisper.cpp
examples/whisper/whisper.h
include/ggml/ggml.h
src/ggml.c

index 55dd1b46eaaed769b1dea63fc834cf7b71aaa852..c8fa83a831d3e71ac190042d70a5633d6355d4dd 100644 (file)
@@ -10,6 +10,6 @@ target_link_libraries(whisper-cpp PRIVATE
     )
 
 set(TEST_TARGET whisper)
-add_executable(${TEST_TARGET} main.cpp)
+add_executable(${TEST_TARGET} main.cpp common.cpp)
 target_link_libraries(${TEST_TARGET} PRIVATE whisper-cpp)
 target_include_directories(${TEST_TARGET} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..)
diff --git a/examples/whisper/common.cpp b/examples/whisper/common.cpp
new file mode 100644 (file)
index 0000000..194ef0e
--- /dev/null
@@ -0,0 +1,162 @@
+#include "common.h"
+
+// third-party utilities
+// use your favorite implementations
+#define DR_WAV_IMPLEMENTATION
+#include "dr_wav.h"
+
+#include <cmath>
+#include <regex>
+
+#ifndef M_PI
+#define M_PI 3.14159265358979323846
+#endif
+
+std::string trim(const std::string & s) {
+    std::regex e("^\\s+|\\s+$");
+    return std::regex_replace(s, e, "");
+}
+
+std::string replace(const std::string & s, const std::string & from, const std::string & to) {
+    std::string result = s;
+    size_t pos = 0;
+    while ((pos = result.find(from, pos)) != std::string::npos) {
+        result.replace(pos, from.length(), to);
+        pos += to.length();
+    }
+    return result;
+}
+
+bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
+    drwav wav;
+    std::vector<uint8_t> wav_data; // used for pipe input from stdin
+
+    if (fname == "-") {
+        {
+            uint8_t buf[1024];
+            while (true)
+            {
+                const size_t n = fread(buf, 1, sizeof(buf), stdin);
+                if (n == 0) {
+                    break;
+                }
+                wav_data.insert(wav_data.end(), buf, buf + n);
+            }
+        }
+
+        if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
+            fprintf(stderr, "error: failed to open WAV file from stdin\n");
+            return false;
+        }
+
+        fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
+    }
+    else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) {
+        fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str());
+        return false;
+    }
+
+    if (wav.channels != 1 && wav.channels != 2) {
+        fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str());
+        return false;
+    }
+
+    if (stereo && wav.channels != 2) {
+        fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", __func__, fname.c_str());
+        return false;
+    }
+
+    if (wav.sampleRate != COMMON_SAMPLE_RATE) {
+        fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE/1000);
+        return false;
+    }
+
+    if (wav.bitsPerSample != 16) {
+        fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, fname.c_str());
+        return false;
+    }
+
+    const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
+
+    std::vector<int16_t> pcm16;
+    pcm16.resize(n*wav.channels);
+    drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
+    drwav_uninit(&wav);
+
+    // convert to mono, float
+    pcmf32.resize(n);
+    if (wav.channels == 1) {
+        for (uint64_t i = 0; i < n; i++) {
+            pcmf32[i] = float(pcm16[i])/32768.0f;
+        }
+    } else {
+        for (uint64_t i = 0; i < n; i++) {
+            pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
+        }
+    }
+
+    if (stereo) {
+        // convert to stereo, float
+        pcmf32s.resize(2);
+
+        pcmf32s[0].resize(n);
+        pcmf32s[1].resize(n);
+        for (uint64_t i = 0; i < n; i++) {
+            pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
+            pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
+        }
+    }
+
+    return true;
+}
+
+void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
+    const float rc = 1.0f / (2.0f * M_PI * cutoff);
+    const float dt = 1.0f / sample_rate;
+    const float alpha = dt / (rc + dt);
+
+    float y = data[0];
+
+    for (size_t i = 1; i < data.size(); i++) {
+        y = alpha * (y + data[i] - data[i - 1]);
+        data[i] = y;
+    }
+}
+
+bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
+    const int n_samples      = pcmf32.size();
+    const int n_samples_last = (sample_rate * last_ms) / 1000;
+
+    if (n_samples_last >= n_samples) {
+        // not enough samples - assume no speech
+        return false;
+    }
+
+    if (freq_thold > 0.0f) {
+        high_pass_filter(pcmf32, freq_thold, sample_rate);
+    }
+
+    float energy_all  = 0.0f;
+    float energy_last = 0.0f;
+
+    for (int i = 0; i < n_samples; i++) {
+        energy_all += fabsf(pcmf32[i]);
+
+        if (i >= n_samples - n_samples_last) {
+            energy_last += fabsf(pcmf32[i]);
+        }
+    }
+
+    energy_all  /= n_samples;
+    energy_last /= n_samples_last;
+
+    if (verbose) {
+        fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
+    }
+
+    if (energy_last > vad_thold*energy_all) {
+        return false;
+    }
+
+    return true;
+}
diff --git a/examples/whisper/common.h b/examples/whisper/common.h
new file mode 100644 (file)
index 0000000..04dd7cb
--- /dev/null
@@ -0,0 +1,40 @@
+#pragma once
+
+// needs to match WHISPER_SAMPLE_RATE
+#define COMMON_SAMPLE_RATE 16000
+
+#include <vector>
+#include <string>
+
+std::string trim(const std::string & s);
+
+std::string replace(
+        const std::string & s,
+        const std::string & from,
+        const std::string & to);
+
+// Read WAV audio file and store the PCM data into pcmf32
+// The sample rate of the audio must be equal to COMMON_SAMPLE_RATE
+// If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM
+bool read_wav(
+        const std::string & fname,
+        std::vector<float> & pcmf32,
+        std::vector<std::vector<float>> & pcmf32s,
+        bool stereo);
+
+// Apply a high-pass frequency filter to PCM audio
+// Suppresses frequencies below cutoff Hz
+void high_pass_filter(
+        std::vector<float> & data,
+        float cutoff,
+        float sample_rate);
+
+// Basic voice activity detection (VAD) using audio energy adaptive threshold
+bool vad_simple(
+        std::vector<float> & pcmf32,
+        int   sample_rate,
+        int   last_ms,
+        float vad_thold,
+        float freq_thold,
+        bool  verbose);
+
index 65b06ca516ae0aadf7694b8c274b7018b17baa6b..5bd7e424c61c85e4514c475b895d7c18b8f5b3b5 100644 (file)
@@ -1,9 +1,6 @@
-#include "whisper.h"
+#include "common.h"
 
-// third-party utilities
-// use your favorite implementations
-#define DR_WAV_IMPLEMENTATION
-#include "dr_wav.h"
+#include "whisper.h"
 
 #include <cmath>
 #include <fstream>
@@ -53,22 +50,24 @@ void replace_all(std::string & s, const std::string & search, const std::string
 // command-line parameters
 struct whisper_params {
     int32_t n_threads    = std::min(4, (int32_t) std::thread::hardware_concurrency());
-    int32_t n_processors = 1;
-    int32_t offset_t_ms  = 0;
-    int32_t offset_n     = 0;
-    int32_t duration_ms  = 0;
+    int32_t n_processors =  1;
+    int32_t offset_t_ms  =  0;
+    int32_t offset_n     =  0;
+    int32_t duration_ms  =  0;
     int32_t max_context  = -1;
-    int32_t max_len      = 0;
-    int32_t best_of      = 5;
+    int32_t max_len      =  0;
+    int32_t best_of      =  5;
     int32_t beam_size    = -1;
 
-    float word_thold    = 0.01f;
-    float entropy_thold = 2.4f;
-    float logprob_thold = -1.0f;
+    float word_thold    =  0.01f;
+    float entropy_thold =  2.40f;
+    float logprob_thold = -1.00f;
 
     bool speed_up       = false;
     bool translate      = false;
     bool diarize        = false;
+    bool split_on_word  = false;
+    bool no_fallback    = false;
     bool output_txt     = false;
     bool output_vtt     = false;
     bool output_srt     = false;
@@ -84,6 +83,7 @@ struct whisper_params {
     std::string model    = "models/ggml-base.en.bin";
 
     std::vector<std::string> fname_inp = {};
+    std::vector<std::string> fname_out = {};
 };
 
 void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -91,7 +91,12 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
 bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
     for (int i = 1; i < argc; i++) {
         std::string arg = argv[i];
-
+           
+        if (arg == "-"){
+            params.fname_inp.push_back(arg);
+            continue;
+        }
+       
         if (arg[0] != '-') {
             params.fname_inp.push_back(arg);
             continue;
@@ -116,11 +121,14 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-su"   || arg == "--speed-up")       { params.speed_up       = true; }
         else if (arg == "-tr"   || arg == "--translate")      { params.translate      = true; }
         else if (arg == "-di"   || arg == "--diarize")        { params.diarize        = true; }
+        else if (arg == "-sow"  || arg == "--split-on-word")  { params.split_on_word  = true; }
+        else if (arg == "-nf"   || arg == "--no-fallback")    { params.no_fallback    = true; }
         else if (arg == "-otxt" || arg == "--output-txt")     { params.output_txt     = true; }
         else if (arg == "-ovtt" || arg == "--output-vtt")     { params.output_vtt     = true; }
         else if (arg == "-osrt" || arg == "--output-srt")     { params.output_srt     = true; }
         else if (arg == "-owts" || arg == "--output-words")   { params.output_wts     = true; }
         else if (arg == "-ocsv" || arg == "--output-csv")     { params.output_csv     = true; }
+        else if (arg == "-of"   || arg == "--output-file")    { params.fname_out.emplace_back(argv[++i]); }
         else if (arg == "-ps"   || arg == "--print-special")  { params.print_special  = true; }
         else if (arg == "-pc"   || arg == "--print-colors")   { params.print_colors   = true; }
         else if (arg == "-pp"   || arg == "--print-progress") { params.print_progress = true; }
@@ -144,35 +152,38 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
-    fprintf(stderr, "  -h,       --help            [default] show this help message and exit\n");
-    fprintf(stderr, "  -t N,     --threads N       [%-7d] number of threads to use during computation\n",    params.n_threads);
-    fprintf(stderr, "  -p N,     --processors N    [%-7d] number of processors to use during computation\n", params.n_processors);
-    fprintf(stderr, "  -ot N,    --offset-t N      [%-7d] time offset in milliseconds\n",                    params.offset_t_ms);
-    fprintf(stderr, "  -on N,    --offset-n N      [%-7d] segment index offset\n",                           params.offset_n);
-    fprintf(stderr, "  -d  N,    --duration N      [%-7d] duration of audio to process in milliseconds\n",   params.duration_ms);
-    fprintf(stderr, "  -mc N,    --max-context N   [%-7d] maximum number of text context tokens to store\n", params.max_context);
-    fprintf(stderr, "  -ml N,    --max-len N       [%-7d] maximum segment length in characters\n",           params.max_len);
-    fprintf(stderr, "  -bo N,    --best-of N       [%-7d] number of best candidates to keep\n",              params.best_of);
-    fprintf(stderr, "  -bs N,    --beam-size N     [%-7d] beam size for beam search\n",                      params.beam_size);
-    fprintf(stderr, "  -wt N,    --word-thold N    [%-7.2f] word timestamp probability threshold\n",         params.word_thold);
-    fprintf(stderr, "  -et N,    --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n",           params.entropy_thold);
-    fprintf(stderr, "  -lpt N,   --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n",   params.logprob_thold);
-    fprintf(stderr, "  -su,      --speed-up        [%-7s] speed up audio by x2 (reduced accuracy)\n",        params.speed_up ? "true" : "false");
-    fprintf(stderr, "  -tr,      --translate       [%-7s] translate from source language to english\n",      params.translate ? "true" : "false");
-    fprintf(stderr, "  -di,      --diarize         [%-7s] stereo audio diarization\n",                       params.diarize ? "true" : "false");
-    fprintf(stderr, "  -otxt,    --output-txt      [%-7s] output result in a text file\n",                   params.output_txt ? "true" : "false");
-    fprintf(stderr, "  -ovtt,    --output-vtt      [%-7s] output result in a vtt file\n",                    params.output_vtt ? "true" : "false");
-    fprintf(stderr, "  -osrt,    --output-srt      [%-7s] output result in a srt file\n",                    params.output_srt ? "true" : "false");
-    fprintf(stderr, "  -owts,    --output-words    [%-7s] output script for generating karaoke video\n",     params.output_wts ? "true" : "false");
-    fprintf(stderr, "  -ocsv,    --output-csv      [%-7s] output result in a CSV file\n",                    params.output_csv ? "true" : "false");
-    fprintf(stderr, "  -ps,      --print-special   [%-7s] print special tokens\n",                           params.print_special ? "true" : "false");
-    fprintf(stderr, "  -pc,      --print-colors    [%-7s] print colors\n",                                   params.print_colors ? "true" : "false");
-    fprintf(stderr, "  -pp,      --print-progress  [%-7s] print progress\n",                                 params.print_progress ? "true" : "false");
-    fprintf(stderr, "  -nt,      --no-timestamps   [%-7s] do not print timestamps\n",                        params.no_timestamps ? "false" : "true");
-    fprintf(stderr, "  -l LANG,  --language LANG   [%-7s] spoken language ('auto' for auto-detect)\n",       params.language.c_str());
-    fprintf(stderr, "            --prompt PROMPT   [%-7s] initial prompt\n",                                 params.prompt.c_str());
-    fprintf(stderr, "  -m FNAME, --model FNAME     [%-7s] model path\n",                                     params.model.c_str());
-    fprintf(stderr, "  -f FNAME, --file FNAME      [%-7s] input WAV file path\n",                            "");
+    fprintf(stderr, "  -h,        --help              [default] show this help message and exit\n");
+    fprintf(stderr, "  -t N,      --threads N         [%-7d] number of threads to use during computation\n",    params.n_threads);
+    fprintf(stderr, "  -p N,      --processors N      [%-7d] number of processors to use during computation\n", params.n_processors);
+    fprintf(stderr, "  -ot N,     --offset-t N        [%-7d] time offset in milliseconds\n",                    params.offset_t_ms);
+    fprintf(stderr, "  -on N,     --offset-n N        [%-7d] segment index offset\n",                           params.offset_n);
+    fprintf(stderr, "  -d  N,     --duration N        [%-7d] duration of audio to process in milliseconds\n",   params.duration_ms);
+    fprintf(stderr, "  -mc N,     --max-context N     [%-7d] maximum number of text context tokens to store\n", params.max_context);
+    fprintf(stderr, "  -ml N,     --max-len N         [%-7d] maximum segment length in characters\n",           params.max_len);
+    fprintf(stderr, "  -sow,      --split-on-word     [%-7s] split on word rather than on token\n",             params.split_on_word ? "true" : "false");
+    fprintf(stderr, "  -bo N,     --best-of N         [%-7d] number of best candidates to keep\n",              params.best_of);
+    fprintf(stderr, "  -bs N,     --beam-size N       [%-7d] beam size for beam search\n",                      params.beam_size);
+    fprintf(stderr, "  -wt N,     --word-thold N      [%-7.2f] word timestamp probability threshold\n",         params.word_thold);
+    fprintf(stderr, "  -et N,     --entropy-thold N   [%-7.2f] entropy threshold for decoder fail\n",           params.entropy_thold);
+    fprintf(stderr, "  -lpt N,    --logprob-thold N   [%-7.2f] log probability threshold for decoder fail\n",   params.logprob_thold);
+    fprintf(stderr, "  -su,       --speed-up          [%-7s] speed up audio by x2 (reduced accuracy)\n",        params.speed_up ? "true" : "false");
+    fprintf(stderr, "  -tr,       --translate         [%-7s] translate from source language to english\n",      params.translate ? "true" : "false");
+    fprintf(stderr, "  -di,       --diarize           [%-7s] stereo audio diarization\n",                       params.diarize ? "true" : "false");
+    fprintf(stderr, "  -nf,       --no-fallback       [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
+    fprintf(stderr, "  -otxt,     --output-txt        [%-7s] output result in a text file\n",                   params.output_txt ? "true" : "false");
+    fprintf(stderr, "  -ovtt,     --output-vtt        [%-7s] output result in a vtt file\n",                    params.output_vtt ? "true" : "false");
+    fprintf(stderr, "  -osrt,     --output-srt        [%-7s] output result in a srt file\n",                    params.output_srt ? "true" : "false");
+    fprintf(stderr, "  -owts,     --output-words      [%-7s] output script for generating karaoke video\n",     params.output_wts ? "true" : "false");
+    fprintf(stderr, "  -ocsv,     --output-csv        [%-7s] output result in a CSV file\n",                    params.output_csv ? "true" : "false");
+    fprintf(stderr, "  -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n",      "");
+    fprintf(stderr, "  -ps,       --print-special     [%-7s] print special tokens\n",                           params.print_special ? "true" : "false");
+    fprintf(stderr, "  -pc,       --print-colors      [%-7s] print colors\n",                                   params.print_colors ? "true" : "false");
+    fprintf(stderr, "  -pp,       --print-progress    [%-7s] print progress\n",                                 params.print_progress ? "true" : "false");
+    fprintf(stderr, "  -nt,       --no-timestamps     [%-7s] do not print timestamps\n",                        params.no_timestamps ? "false" : "true");
+    fprintf(stderr, "  -l LANG,   --language LANG     [%-7s] spoken language ('auto' for auto-detect)\n",       params.language.c_str());
+    fprintf(stderr, "             --prompt PROMPT     [%-7s] initial prompt\n",                                 params.prompt.c_str());
+    fprintf(stderr, "  -m FNAME,  --model FNAME       [%-7s] model path\n",                                     params.model.c_str());
+    fprintf(stderr, "  -f FNAME,  --file FNAME        [%-7s] input WAV file path\n",                            "");
     fprintf(stderr, "\n");
 }
 
@@ -343,9 +354,6 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
     const int n_segments = whisper_full_n_segments(ctx);
     for (int i = 0; i < n_segments; ++i) {
         const char * text = whisper_full_get_segment_text(ctx, i);
-        if (text[0] == ' ') {
-            text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
-        }
         const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
         const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
@@ -514,90 +522,14 @@ int main(int argc, char ** argv) {
 
     for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
         const auto fname_inp = params.fname_inp[f];
+               const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
 
-        std::vector<float> pcmf32; // mono-channel F32 PCM
+        std::vector<float> pcmf32;               // mono-channel F32 PCM
         std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
 
-        // WAV input
-        {
-            drwav wav;
-            std::vector<uint8_t> wav_data; // used for pipe input from stdin
-
-            if (fname_inp == "-") {
-                {
-                    uint8_t buf[1024];
-                    while (true)
-                    {
-                        const size_t n = fread(buf, 1, sizeof(buf), stdin);
-                        if (n == 0) {
-                            break;
-                        }
-                        wav_data.insert(wav_data.end(), buf, buf + n);
-                    }
-                }
-
-                if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
-                    fprintf(stderr, "error: failed to open WAV file from stdin\n");
-                    return 4;
-                }
-
-                fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
-            }
-            else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
-                fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
-                return 5;
-            }
-
-            if (wav.channels != 1 && wav.channels != 2) {
-                fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
-                return 6;
-            }
-
-            if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
-                fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str());
-                return 6;
-            }
-
-            if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
-                fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", argv[0], fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
-                return 8;
-            }
-
-            if (wav.bitsPerSample != 16) {
-                fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
-                return 9;
-            }
-
-            const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
-
-            std::vector<int16_t> pcm16;
-            pcm16.resize(n*wav.channels);
-            drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
-            drwav_uninit(&wav);
-
-            // convert to mono, float
-            pcmf32.resize(n);
-            if (wav.channels == 1) {
-                for (uint64_t i = 0; i < n; i++) {
-                    pcmf32[i] = float(pcm16[i])/32768.0f;
-                }
-            } else {
-                for (uint64_t i = 0; i < n; i++) {
-                    pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
-                }
-            }
-
-            if (params.diarize) {
-                // convert to stereo, float
-                pcmf32s.resize(2);
-
-                pcmf32s[0].resize(n);
-                pcmf32s[1].resize(n);
-                for (uint64_t i = 0; i < n; i++) {
-                    pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
-                    pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
-                }
-            }
+        if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) {
+            fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
+            continue;
         }
 
         // print system information
@@ -646,18 +578,20 @@ int main(int argc, char ** argv) {
 
             wparams.token_timestamps = params.output_wts || params.max_len > 0;
             wparams.thold_pt         = params.word_thold;
-            wparams.entropy_thold    = params.entropy_thold;
-            wparams.logprob_thold    = params.logprob_thold;
             wparams.max_len          = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
+            wparams.split_on_word    = params.split_on_word;
 
             wparams.speed_up         = params.speed_up;
 
+            wparams.prompt_tokens     = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
+            wparams.prompt_n_tokens   = prompt_tokens.empty() ? 0       : prompt_tokens.size();
+
             wparams.greedy.best_of        = params.best_of;
             wparams.beam_search.beam_size = params.beam_size;
-            wparams.temperature_inc = -1;
 
-            wparams.prompt_tokens     = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
-            wparams.prompt_n_tokens   = prompt_tokens.empty() ? 0       : prompt_tokens.size();
+            wparams.temperature_inc  = params.no_fallback ? 0.0f : wparams.temperature_inc;
+            wparams.entropy_thold    = params.entropy_thold;
+            wparams.logprob_thold    = params.logprob_thold;
 
             whisper_print_user_data user_data = { &params, &pcmf32s };
 
@@ -692,34 +626,33 @@ int main(int argc, char ** argv) {
 
             // output to text file
             if (params.output_txt) {
-                const auto fname_txt = fname_inp + ".txt";
+                const auto fname_txt = fname_out + ".txt";
                 output_txt(ctx, fname_txt.c_str());
             }
 
             // output to VTT file
             if (params.output_vtt) {
-                const auto fname_vtt = fname_inp + ".vtt";
+                const auto fname_vtt = fname_out + ".vtt";
                 output_vtt(ctx, fname_vtt.c_str());
             }
 
             // output to SRT file
             if (params.output_srt) {
-                const auto fname_srt = fname_inp + ".srt";
+                const auto fname_srt = fname_out + ".srt";
                 output_srt(ctx, fname_srt.c_str(), params);
             }
 
             // output to WTS file
             if (params.output_wts) {
-                const auto fname_wts = fname_inp + ".wts";
+                const auto fname_wts = fname_out + ".wts";
                 output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
             }
 
-           // output to CSV file
+            // output to CSV file
             if (params.output_csv) {
-                const auto fname_csv = fname_inp + ".csv";
+                const auto fname_csv = fname_out + ".csv";
                 output_csv(ctx, fname_csv.c_str());
             }
-
         }
     }
 
index 05bf58e16e3f77d0031aa0611e21aec0f5721e06..331d4084c6b7ed7a17bfb6462c2bb8fc2f735055 100644 (file)
 #include <regex>
 #include <random>
 
+#if defined(GGML_BIG_ENDIAN)
+#include <bit>
+
+template<typename T>
+static T byteswap(T value) {
+    return std::byteswap(value);
+}
+
+template<>
+float byteswap(float value) {
+    return std::bit_cast<float>(byteswap(std::bit_cast<std::uint32_t>(value)));
+}
+
+template<typename T>
+static void byteswap_tensor_data(ggml_tensor * tensor) {
+    T * datum = reinterpret_cast<T *>(tensor->data);
+    for (int i = 0; i < ggml_nelements(tensor); i++) {
+        datum[i] = byteswap(datum[i]);
+    }
+}
+
+static void byteswap_tensor(ggml_tensor * tensor) {
+    switch (tensor->type) {
+        case GGML_TYPE_I16: {
+            byteswap_tensor_data<int16_t>(tensor);
+            break;
+        }
+        case GGML_TYPE_F16: {
+            byteswap_tensor_data<ggml_fp16_t>(tensor);
+            break;
+        }
+        case GGML_TYPE_I32: {
+            byteswap_tensor_data<int32_t>(tensor);
+            break;
+        }
+        case GGML_TYPE_F32: {
+            byteswap_tensor_data<float>(tensor);
+            break;
+        }
+        default: { // GML_TYPE_I8
+            break;
+        }
+    }
+}
+
+#define BYTESWAP_VALUE(d) d = byteswap(d)
+#define BYTESWAP_FILTERS(f)            \
+    do {                              \
+        for (auto & datum : f.data) { \
+            datum = byteswap(datum);  \
+        }                             \
+    } while (0)
+#define BYTESWAP_TENSOR(t)       \
+    do {                         \
+        byteswap_tensor(tensor); \
+    } while (0)
+#else
+#define BYTESWAP_VALUE(d) do {} while (0)
+#define BYTESWAP_FILTERS(f) do {} while (0)
+#define BYTESWAP_TENSOR(t) do {} while (0)
+#endif
+
 #define WHISPER_ASSERT(x) \
     do { \
         if (!(x)) { \
 //#define WHISPER_USE_FLASH_FF
 #define WHISPER_MAX_DECODERS 16
 
+#define WHISPER_USE_SCRATCH
+#define WHISPER_MAX_SCRATCH_BUFFERS 16
+
 // available whisper models
 enum e_model {
     MODEL_UNKNOWN,
@@ -155,6 +220,38 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
 
 static const size_t MB = 1024*1024;
 
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
+    { MODEL_TINY,     12ull*MB },
+    { MODEL_BASE,     15ull*MB },
+    { MODEL_SMALL,    23ull*MB },
+    { MODEL_MEDIUM,   31ull*MB },
+    { MODEL_LARGE,    38ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
+    { MODEL_TINY,     18ull*MB },
+    { MODEL_BASE,     24ull*MB },
+    { MODEL_SMALL,    36ull*MB },
+    { MODEL_MEDIUM,   48ull*MB },
+    { MODEL_LARGE,    60ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
+    { MODEL_TINY,      4ull*MB },
+    { MODEL_BASE,      4ull*MB },
+    { MODEL_SMALL,     6ull*MB },
+    { MODEL_MEDIUM,    7ull*MB },
+    { MODEL_LARGE,     9ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
+    { MODEL_TINY,      4ull*MB },
+    { MODEL_BASE,      4ull*MB },
+    { MODEL_SMALL,     6ull*MB },
+    { MODEL_MEDIUM,    7ull*MB },
+    { MODEL_LARGE,     9ull*MB },
+};
+
 static const std::map<e_model, size_t> MEM_REQ_MODEL = {
     { MODEL_TINY,     74ull*MB },
     { MODEL_BASE,    142ull*MB },
@@ -180,35 +277,19 @@ static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
 };
 
 static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
-    { MODEL_TINY,     80ull*MB },
-    { MODEL_BASE,    128ull*MB },
-    { MODEL_SMALL,   300ull*MB },
-    { MODEL_MEDIUM,  680ull*MB },
-    { MODEL_LARGE,  1100ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
-    { MODEL_TINY,    104ull*MB },
-    { MODEL_BASE,    138ull*MB },
-    { MODEL_SMALL,   208ull*MB },
-    { MODEL_MEDIUM,  280ull*MB },
-    { MODEL_LARGE,   354ull*MB },
+    { MODEL_TINY,      6ull*MB },
+    { MODEL_BASE,      8ull*MB },
+    { MODEL_SMALL,    13ull*MB },
+    { MODEL_MEDIUM,   22ull*MB },
+    { MODEL_LARGE,    33ull*MB },
 };
 
 static const std::map<e_model, size_t> MEM_REQ_DECODE = {
-    { MODEL_TINY,    200ull*MB },
-    { MODEL_BASE,    202ull*MB },
-    { MODEL_SMALL,   204ull*MB },
-    { MODEL_MEDIUM,  206ull*MB },
-    { MODEL_LARGE,   208ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
-    { MODEL_TINY,     32ull*MB },
-    { MODEL_BASE,     44ull*MB },
-    { MODEL_SMALL,    64ull*MB },
-    { MODEL_MEDIUM,   84ull*MB },
-    { MODEL_LARGE,   110ull*MB },
+    { MODEL_TINY,      3ull*MB },
+    { MODEL_BASE,      5ull*MB },
+    { MODEL_SMALL,    10ull*MB },
+    { MODEL_MEDIUM,   18ull*MB },
+    { MODEL_LARGE,    27ull*MB },
 };
 
 struct whisper_mel {
@@ -474,6 +555,12 @@ struct whisper_context {
     int64_t t_decode_us = 0;
     int64_t t_start_us  = 0;
 
+    int32_t n_sample = 0; // number of tokens sampled
+    int32_t n_encode = 0; // number of encoder calls
+    int32_t n_decode = 0; // number of decoder calls
+    int32_t n_fail_p = 0; // number of logprob threshold failures
+    int32_t n_fail_h = 0; // number of entropy threshold failures
+
     ggml_type wtype; // weight type (FP32 or FP16)
 
     whisper_mel mel;
@@ -489,7 +576,10 @@ struct whisper_context {
 
     // memory buffers used by encode / decode contexts
     std::vector<uint8_t> buf_compute;
-    std::vector<uint8_t> buf_compute_layer;
+    std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
+
+    int    buf_last = 0;
+    size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 };
 
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
@@ -502,6 +592,8 @@ struct whisper_context {
 
     mutable std::mt19937 rng; // used for sampling at t > 0.0
 
+    int lang_id;
+
     // [EXPERIMENTAL] token-level timestamps data
     int64_t t_beg;
     int64_t t_last;
@@ -510,11 +602,43 @@ struct whisper_context {
 
     // [EXPERIMENTAL] speed-up techniques
     int32_t exp_n_audio_ctx; // 0 - use default
+
+    void use_buf(struct ggml_context * ctx, int i) {
+#if defined(WHISPER_USE_SCRATCH)
+        size_t last_size = 0;
+
+        if (i == -1) {
+            last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
+        } else {
+            auto & buf = buf_scratch[i];
+            last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
+        }
+
+        if (buf_last >= 0) {
+            buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
+        }
+
+        buf_last = i;
+#else
+        (void) i;
+        (void) ctx;
+#endif
+    }
+
+    size_t get_buf_max_mem(int i) const {
+#if defined(WHISPER_USE_SCRATCH)
+        return buf_max_size[i];
+#else
+        (void) i;
+        return 0;
+#endif
+    }
 };
 
 template<typename T>
 static void read_safe(whisper_model_loader * loader, T & dest) {
     loader->read(loader->context, &dest, sizeof(T));
+    BYTESWAP_VALUE(dest);
 }
 
 static bool kv_cache_init(
@@ -675,10 +799,13 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         {
             // this is the total memory required to run the inference
             const size_t mem_required =
-                scale*MEM_REQ_MODEL.at       (model.type) +
-                scale*MEM_REQ_KV_CROSS.at    (model.type) +
-                scale*std::max(MEM_REQ_ENCODE.at(model.type),       MEM_REQ_DECODE.at(model.type)) +
-                scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type));
+                     MEM_REQ_SCRATCH0.at (model.type) +
+                     MEM_REQ_SCRATCH1.at (model.type) +
+                     MEM_REQ_SCRATCH2.at (model.type) +
+                     MEM_REQ_SCRATCH3.at (model.type) +
+                scale*MEM_REQ_MODEL.at   (model.type) +
+                scale*MEM_REQ_KV_CROSS.at(model.type) +
+                scale*std::max(MEM_REQ_ENCODE.at(model.type),       MEM_REQ_DECODE.at(model.type));
 
             // this is the memory required by one decoder
             const size_t mem_required_decoder =
@@ -714,8 +841,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
         }
 
-        wctx.buf_compute.resize      (scale*std::max(MEM_REQ_ENCODE.at(model.type),       MEM_REQ_DECODE.at(model.type)));
-        wctx.buf_compute_layer.resize(scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
+        wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
+
+        wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type));
+        wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type));
+        wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type));
+        wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type));
     }
 
     // load mel filters
@@ -727,6 +858,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
         filters.data.resize(filters.n_mel * filters.n_fft);
         loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float));
+        BYTESWAP_FILTERS(filters);
     }
 
     // load vocab
@@ -1190,6 +1322,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             }
 
             loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
+            BYTESWAP_TENSOR(tensor);
 
             //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
             total_size += ggml_nbytes(tensor);
@@ -1246,6 +1379,8 @@ static bool whisper_encode(
 
     struct ggml_context * ctx0 = ggml_init(params);
 
+    wctx.use_buf(ctx0, 0);
+
     struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
     assert(mel->type == GGML_TYPE_F32);
     {
@@ -1266,6 +1401,8 @@ static bool whisper_encode(
 
     // convolution + gelu
     {
+        wctx.use_buf(ctx0, 1);
+
         cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
         cur = ggml_add(ctx0,
                 ggml_repeat(ctx0,
@@ -1275,6 +1412,8 @@ static bool whisper_encode(
 
         cur = ggml_gelu(ctx0, cur);
 
+        wctx.use_buf(ctx0, 0);
+
         cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
         cur = ggml_add(ctx0,
                 ggml_repeat(ctx0,
@@ -1285,6 +1424,8 @@ static bool whisper_encode(
         cur = ggml_gelu(ctx0, cur);
     }
 
+    wctx.use_buf(ctx0, 3);
+
     // ===================================================================
     // NOTE: experimenting with partial evaluation of the encoder (ignore)
     //static int iter = -1;
@@ -1305,6 +1446,7 @@ static bool whisper_encode(
     struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
 
     cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
+
     // ===================================================================
 
     // original:
@@ -1315,153 +1457,158 @@ static bool whisper_encode(
     for (int il = 0; il < n_layer; ++il) {
         const auto & layer = model.layers_encoder[il];
 
-        // create separate context for each layer to reduce memory usage
-
-        struct ggml_init_params paramsL;
-        paramsL.mem_size   = wctx.buf_compute_layer.size();
-        paramsL.mem_buffer = wctx.buf_compute_layer.data();
-
-        struct ggml_context * ctxL = ggml_init(paramsL);
-
         // norm
         {
-            cur = ggml_norm(ctxL, inpL);
+            wctx.use_buf(ctx0, 0);
+
+            cur = ggml_norm(ctx0, inpL);
 
             // cur = ln_0_w*cur + ln_0_b
-            cur = ggml_add(ctxL,
-                    ggml_mul(ctxL,
-                        ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
+            cur = ggml_add(ctx0,
+                    ggml_mul(ctx0,
+                        ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
                         cur),
-                    ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
+                    ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
         }
 
         // self-attention
         {
-            struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
+            wctx.use_buf(ctx0, 1);
+
+            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                     layer.attn_q_w,
                     cur);
 
-            Qcur = ggml_add(ctxL,
-                    ggml_repeat(ctxL,
+            Qcur = ggml_add(ctx0,
+                    ggml_repeat(ctx0,
                         layer.attn_q_b,
                         Qcur),
                     Qcur);
 
-            //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+            //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
             // note: no bias for Key
-            struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
+            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
                     layer.attn_k_w,
                     cur);
 
-            //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+            //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
-            struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
+            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
                     layer.attn_v_w,
                     cur);
 
-            Vcur = ggml_add(ctxL,
-                    ggml_repeat(ctxL,
+            Vcur = ggml_add(ctx0,
+                    ggml_repeat(ctx0,
                         layer.attn_v_b,
                         Vcur),
                     Vcur);
 
             // ------
 
+            wctx.use_buf(ctx0, 0);
+
 #ifdef WHISPER_USE_FLASH_ATTN
             struct ggml_tensor * Q =
-                ggml_permute(ctxL,
-                        ggml_cpy(ctxL,
+                ggml_permute(ctx0,
+                        ggml_cpy(ctx0,
                             Qcur,
-                            ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+                            ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
-                ggml_permute(ctxL,
-                        ggml_cpy(ctxL,
+                ggml_permute(ctx0,
+                        ggml_cpy(ctx0,
                             Kcur,
-                            ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+                            ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             struct ggml_tensor * V =
-                ggml_cpy(ctxL,
-                        ggml_permute(ctxL,
-                            ggml_reshape_3d(ctxL,
+                ggml_cpy(ctx0,
+                        ggml_permute(ctx0,
+                            ggml_reshape_3d(ctx0,
                                 Vcur,
                                 n_state/n_head, n_head, n_ctx),
                             1, 2, 0, 3),
-                        ggml_new_tensor_3d(ctxL, wctx.wtype, n_ctx, n_state/n_head, n_head)
+                        ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head)
                         );
 
-            struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
+            struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
 #else
             struct ggml_tensor * Q =
-                ggml_permute(ctxL,
-                        ggml_cpy(ctxL,
+                ggml_permute(ctx0,
+                        ggml_cpy(ctx0,
                             Qcur,
-                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
+                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
-                ggml_permute(ctxL,
-                        ggml_cpy(ctxL,
+                ggml_permute(ctx0,
+                        ggml_cpy(ctx0,
                             Kcur,
-                            ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+                            ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
             struct ggml_tensor * KQ_scaled =
-                ggml_scale(ctxL,
+                ggml_scale(ctx0,
                         KQ,
-                        ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
                         );
 
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_scaled);
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
 
             //struct ggml_tensor * V_trans =
-            //    ggml_permute(ctxL,
-            //            ggml_cpy(ctxL,
+            //    ggml_permute(ctx0,
+            //            ggml_cpy(ctx0,
             //                Vcur,
-            //                ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+            //                ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
             //            1, 2, 0, 3);
 
-            //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+            //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
 
             struct ggml_tensor * V =
-                ggml_cpy(ctxL,
-                        ggml_permute(ctxL,
-                            ggml_reshape_3d(ctxL,
+                ggml_cpy(ctx0,
+                        ggml_permute(ctx0,
+                            ggml_reshape_3d(ctx0,
                                 Vcur,
                                 n_state/n_head, n_head, n_ctx),
                             0, 2, 1, 3),
-                        ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_ctx, n_head)
+                        ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head)
                         );
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max);
 #endif
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
+            wctx.use_buf(ctx0, 1);
 
-            cur = ggml_cpy(ctxL,
+            cur = ggml_cpy(ctx0,
                     KQV_merged,
-                    ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, n_ctx));
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
         }
 
         // projection
         {
-            cur = ggml_mul_mat(ctxL,
+            wctx.use_buf(ctx0, 0);
+
+            cur = ggml_mul_mat(ctx0,
                     layer.attn_ln_1_w,
                     cur);
 
-            cur = ggml_add(ctxL,
-                    ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
+            wctx.use_buf(ctx0, 1);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
                     cur);
         }
 
+        wctx.use_buf(ctx0, 2);
+
         // add the input
-        cur = ggml_add(ctxL, cur, inpL);
+        cur = ggml_add(ctx0, cur, inpL);
 
         struct ggml_tensor * inpFF = cur;
 
@@ -1469,75 +1616,75 @@ static bool whisper_encode(
         {
             // norm
             {
-                cur = ggml_norm(ctxL, inpFF);
+                wctx.use_buf(ctx0, 0);
+
+                cur = ggml_norm(ctx0, inpFF);
+
+                wctx.use_buf(ctx0, 1);
 
                 // cur = mlp_ln_w*cur + mlp_ln_b
-                cur = ggml_add(ctxL,
-                        ggml_mul(ctxL,
-                            ggml_repeat(ctxL, layer.mlp_ln_w, cur),
+                cur = ggml_add(ctx0,
+                        ggml_mul(ctx0,
+                            ggml_repeat(ctx0, layer.mlp_ln_w, cur),
                             cur),
-                        ggml_repeat(ctxL, layer.mlp_ln_b, cur));
+                        ggml_repeat(ctx0, layer.mlp_ln_b, cur));
             }
 
 #ifdef WHISPER_USE_FLASH_FF
-            cur = ggml_flash_ff(ctxL,
-                    ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)),
+            wctx.use_buf(ctx0, 0);
+
+            cur = ggml_flash_ff(ctx0,
+                    ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)),
                     layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
 #else
+            wctx.use_buf(ctx0, 0);
+
             // fully connected
-            cur = ggml_mul_mat(ctxL,
+            cur = ggml_mul_mat(ctx0,
                     layer.mlp_0_w,
                     cur);
 
-            cur = ggml_add(ctxL,
-                    ggml_repeat(ctxL, layer.mlp_0_b, cur),
+            wctx.use_buf(ctx0, 1);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, layer.mlp_0_b, cur),
                     cur);
 
+            wctx.use_buf(ctx0, 0);
+
             // GELU activation
-            cur = ggml_gelu(ctxL, cur);
+            cur = ggml_gelu(ctx0, cur);
+
+            wctx.use_buf(ctx0, 1);
 
             // projection
-            cur = ggml_mul_mat(ctxL,
+            cur = ggml_mul_mat(ctx0,
                     layer.mlp_1_w,
                     cur);
 
-            cur = ggml_add(ctxL,
-                    ggml_repeat(ctxL, layer.mlp_1_b, cur),
+            wctx.use_buf(ctx0, 0);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, layer.mlp_1_b, cur),
                     cur);
 #endif
         }
 
-        // output from this layer
-        struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
-
-        {
-            struct ggml_cgraph gf = {};
-            gf.n_threads = n_threads;
-
-            ggml_build_forward_expand(&gf, inpO);
-            ggml_graph_compute       (ctxL, &gf);
-
-            //ggml_graph_print(&gf);
-        }
-
-        // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
-        // input for next layer (inpO -> inpL)
-        memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
-        inpL->op = GGML_OP_NONE;
-        inpL->src0 = nullptr;
-        inpL->src1 = nullptr;
-
-        //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
+        wctx.use_buf(ctx0, 3);
 
-        ggml_free(ctxL);
+        inpL = ggml_add(ctx0, cur, inpFF);
     }
 
     cur = inpL;
 
     // norm
     {
+        wctx.use_buf(ctx0, 0);
+
         cur = ggml_norm(ctx0, cur);
 
+        wctx.use_buf(ctx0, 1);
+
         // cur = ln_f_g*cur + ln_f_b
         cur = ggml_add(ctx0,
                 ggml_mul(ctx0,
@@ -1546,6 +1693,8 @@ static bool whisper_encode(
                 ggml_repeat(ctx0, model.e_ln_b, cur));
     }
 
+    wctx.use_buf(ctx0, -1);
+
     // run the computation
     {
         struct ggml_cgraph gf = {};
@@ -1584,12 +1733,16 @@ static bool whisper_encode(
         for (int il = 0; il < model.hparams.n_text_layer; ++il) {
             auto & layer = model.layers_decoder[il];
 
+            wctx.use_buf(ctx0, 0);
+
             struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
                     layer.cross_attn_k_w,
                     cur);
 
             Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
+            wctx.use_buf(ctx0, 1);
+
             struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
                     layer.cross_attn_v_w,
                     cur);
@@ -1600,6 +1753,8 @@ static bool whisper_encode(
                         Vcross),
                     Vcross);
 
+            wctx.use_buf(ctx0, -1);
+
             //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
             //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
             struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx));
@@ -1615,18 +1770,24 @@ static bool whisper_encode(
 
     ////////////////////////////////////////////////////////////////////////////
 
-    //printf("%s: used_mem = %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0);
+    //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
+    //        ggml_used_mem(ctx0)/1024.0/1024.0,
+    //        wctx.get_buf_max_mem(0)/1024.0/1024.0,
+    //        wctx.get_buf_max_mem(1)/1024.0/1024.0,
+    //        wctx.get_buf_max_mem(2)/1024.0/1024.0,
+    //        wctx.get_buf_max_mem(3)/1024.0/1024.0);
 
     ggml_free(ctx0);
 
     wctx.t_encode_us += ggml_time_us() - t_start_us;
+    wctx.n_encode++;
 
     return true;
 }
 
 // evaluate the decoder
 //
-// given text prompt + audio features -> predicts the probabilities for the next token
+// given text prompt + audio features -> computes the logits for the next token
 //
 //   - model:      the model
 //   - n_threads:  number of threads to use
@@ -1670,6 +1831,9 @@ static bool whisper_decode(
 
     struct ggml_context * ctx0 = ggml_init(params);
 
+    struct ggml_cgraph gf = {};
+    gf.n_threads = n_threads;
+
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, tokens, N*ggml_element_size(embd));
 
@@ -1678,6 +1842,8 @@ static bool whisper_decode(
         ((int32_t *) position->data)[i] = n_past + i;
     }
 
+    wctx.use_buf(ctx0, 3);
+
     // token encoding + position encoding
     struct ggml_tensor * cur =
         ggml_add(ctx0,
@@ -1689,211 +1855,248 @@ static bool whisper_decode(
     for (int il = 0; il < n_layer; ++il) {
         const auto & layer = model.layers_decoder[il];
 
-        struct ggml_init_params paramsL;
-        paramsL.mem_size   = wctx.buf_compute_layer.size();
-        paramsL.mem_buffer = wctx.buf_compute_layer.data();
-
-        struct ggml_context * ctxL = ggml_init(paramsL);
-        struct ggml_cgraph gf = {};
-        gf.n_threads = n_threads;
-
         // norm
         {
-            cur = ggml_norm(ctxL, inpL);
+            wctx.use_buf(ctx0, 0);
+
+            cur = ggml_norm(ctx0, inpL);
 
             // cur = ln_0_w*cur + ln_0_b
-            cur = ggml_add(ctxL,
-                    ggml_mul(ctxL,
-                        ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
+            cur = ggml_add(ctx0,
+                    ggml_mul(ctx0,
+                        ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
                         cur),
-                    ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
+                    ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
         }
 
         // self-attention
         {
-            struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
+            wctx.use_buf(ctx0, 1);
+
+            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                     layer.attn_q_w,
                     cur);
 
-            Qcur = ggml_add(ctxL,
-                    ggml_repeat(ctxL,
+            Qcur = ggml_add(ctx0,
+                    ggml_repeat(ctx0,
                         layer.attn_q_b,
                         Qcur),
                     Qcur);
 
-            Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+            Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
             // note: no bias for Key
-            struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
+            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
                     layer.attn_k_w,
                     cur);
 
-            Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+            Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
-            struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
+            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
                     layer.attn_v_w,
                     cur);
 
-            Vcur = ggml_add(ctxL,
-                    ggml_repeat(ctxL,
+            Vcur = ggml_add(ctx0,
+                    ggml_repeat(ctx0,
                         layer.attn_v_b,
                         Vcur),
                     Vcur);
 
             // store key and value to memory
             {
-                struct ggml_tensor * k = ggml_view_1d(ctxL, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
-                struct ggml_tensor * v = ggml_view_1d(ctxL, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past));
+                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past));
 
-                ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k));
-                ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v));
+                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
             }
 
             // ------
 
+            wctx.use_buf(ctx0, 0);
+
             struct ggml_tensor * Q =
-                ggml_permute(ctxL,
-                        ggml_cpy(ctxL,
+                ggml_permute(ctx0,
+                        ggml_cpy(ctx0,
                             Qcur,
-                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
-                ggml_permute(ctxL,
-                        ggml_reshape_3d(ctxL,
-                            ggml_view_1d(ctxL, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state),
+                ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0,
+                            ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state),
                             n_state/n_head, n_head, n_past + N),
                         0, 2, 1, 3);
 
+            wctx.use_buf(ctx0, 1);
+
             // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+            wctx.use_buf(ctx0, 0);
 
             //struct ggml_tensor * KQ_scaled =
-            //    ggml_scale(ctxL,
+            //    ggml_scale(ctx0,
             //            KQ,
-            //            ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+            //            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
             //            );
 
-            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
+
+            wctx.use_buf(ctx0, 1);
+
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
+            wctx.use_buf(ctx0, 0);
 
             struct ggml_tensor * V_trans =
-                ggml_permute(ctxL,
-                        ggml_reshape_3d(ctxL,
-                            ggml_view_1d(ctxL, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
+                ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0,
+                            ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
                             n_state/n_head, n_head, n_past + N),
                         1, 2, 0, 3);
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+            wctx.use_buf(ctx0, 1);
 
-            struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
 
-            cur = ggml_cpy(ctxL,
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            cur = ggml_cpy(ctx0,
                     KQV_merged,
-                    ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
         }
 
+        // projection
         {
-            cur = ggml_mul_mat(ctxL,
+            wctx.use_buf(ctx0, 0);
+
+            cur = ggml_mul_mat(ctx0,
                     layer.attn_ln_1_w,
                     cur);
 
-            cur = ggml_add(ctxL,
-                    ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
+            wctx.use_buf(ctx0, 1);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
                     cur);
         }
 
+        wctx.use_buf(ctx0, 2);
+
         // add the input
-        struct ggml_tensor * inpCA = ggml_add(ctxL, cur, inpL);
+        struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
 
         // norm
         {
-            cur = ggml_norm(ctxL, inpCA); // note: we use inpCA here
+            wctx.use_buf(ctx0, 0);
+
+            cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
+
+            wctx.use_buf(ctx0, 1);
 
             // cur = ln_0_w*cur + ln_0_b
-            cur = ggml_add(ctxL,
-                    ggml_mul(ctxL,
-                        ggml_repeat(ctxL, layer.cross_attn_ln_0_w, cur),
+            cur = ggml_add(ctx0,
+                    ggml_mul(ctx0,
+                        ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur),
                         cur),
-                    ggml_repeat(ctxL, layer.cross_attn_ln_0_b, cur));
+                    ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur));
         }
 
         // cross-attention
         {
-            struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
+            wctx.use_buf(ctx0, 0);
+
+            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                     layer.cross_attn_q_w,
                     cur);
 
-            Qcur = ggml_add(ctxL,
-                    ggml_repeat(ctxL,
+            Qcur = ggml_add(ctx0,
+                    ggml_repeat(ctx0,
                         layer.cross_attn_q_b,
                         Qcur),
                     Qcur);
 
-            Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+            Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
             // Kcross is already scaled
             struct ggml_tensor * Kcross =
-                ggml_reshape_3d(ctxL,
-                        ggml_view_1d(ctxL, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state),
+                ggml_reshape_3d(ctx0,
+                        ggml_view_1d(ctx0, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state),
                         n_state/n_head, n_head, M);
 
             struct ggml_tensor * Vcross =
-                ggml_reshape_3d(ctxL,
-                        ggml_view_1d(ctxL, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state),
+                ggml_reshape_3d(ctx0,
+                        ggml_view_1d(ctx0, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state),
                         n_state/n_head, n_head, M);
 
+            struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3);
+
             // ------
 
+            wctx.use_buf(ctx0, 1);
+
             struct ggml_tensor * Q =
-                ggml_permute(ctxL,
-                        ggml_cpy(ctxL,
+                ggml_permute(ctx0,
+                        ggml_cpy(ctx0,
                             Qcur,
-                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
                         0, 2, 1, 3);
 
-            struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
+            struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
+
+            wctx.use_buf(ctx0, 0);
 
             // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
             //struct ggml_tensor * KQ_scaled =
-            //    ggml_scale(ctxL,
+            //    ggml_scale(ctx0,
             //            KQ,
-            //            ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+            //            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
             //            );
 
             // no masking for cross-attention
-            //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
+            //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
 
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ);
+            wctx.use_buf(ctx0, 1);
 
-            struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+            wctx.use_buf(ctx0, 0);
 
-            struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+
+            wctx.use_buf(ctx0, 1);
+
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
             // cur = KQV_merged.contiguous().view(n_state, N)
-            cur = ggml_cpy(ctxL,
+            cur = ggml_cpy(ctx0,
                     KQV_merged,
-                    ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
         }
 
         // projection
         {
-            cur = ggml_mul_mat(ctxL,
+            wctx.use_buf(ctx0, 0);
+
+            cur = ggml_mul_mat(ctx0,
                     layer.cross_attn_ln_1_w,
                     cur);
 
-            cur = ggml_add(ctxL,
-                    ggml_repeat(ctxL, layer.cross_attn_ln_1_b, cur),
+            wctx.use_buf(ctx0, 1);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
                     cur);
         }
 
+        wctx.use_buf(ctx0, 2);
+
         // add the input
-        cur = ggml_add(ctxL, cur, inpCA);
+        cur = ggml_add(ctx0, cur, inpCA);
 
         struct ggml_tensor * inpFF = cur;
 
@@ -1901,68 +2104,67 @@ static bool whisper_decode(
         {
             // norm
             {
-                cur = ggml_norm(ctxL, inpFF);
+                wctx.use_buf(ctx0, 0);
+
+                cur = ggml_norm(ctx0, inpFF);
+
+                wctx.use_buf(ctx0, 1);
 
                 // cur = mlp_ln_w*cur + mlp_ln_b
-                cur = ggml_add(ctxL,
-                        ggml_mul(ctxL,
-                            ggml_repeat(ctxL, layer.mlp_ln_w, cur),
+                cur = ggml_add(ctx0,
+                        ggml_mul(ctx0,
+                            ggml_repeat(ctx0, layer.mlp_ln_w, cur),
                             cur),
-                        ggml_repeat(ctxL, layer.mlp_ln_b, cur));
+                        ggml_repeat(ctx0, layer.mlp_ln_b, cur));
             }
 
+            wctx.use_buf(ctx0, 0);
+
             // fully connected
-            cur = ggml_mul_mat(ctxL,
+            cur = ggml_mul_mat(ctx0,
                     layer.mlp_0_w,
                     cur);
 
-            cur = ggml_add(ctxL,
-                    ggml_repeat(ctxL, layer.mlp_0_b, cur),
+            wctx.use_buf(ctx0, 1);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, layer.mlp_0_b, cur),
                     cur);
 
+            wctx.use_buf(ctx0, 0);
+
             // GELU activation
-            cur = ggml_gelu(ctxL, cur);
+            cur = ggml_gelu(ctx0, cur);
+
+            wctx.use_buf(ctx0, 1);
 
             // projection
-            cur = ggml_mul_mat(ctxL,
+            cur = ggml_mul_mat(ctx0,
                     layer.mlp_1_w,
                     cur);
 
-            cur = ggml_add(ctxL,
-                    ggml_repeat(ctxL, layer.mlp_1_b, cur),
-                    cur);
-        }
+            wctx.use_buf(ctx0, 0);
 
-        // output from this layer
-        struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
-
-        {
-            ggml_build_forward_expand(&gf, inpO);
-            ggml_graph_compute       (ctxL, &gf);
-
-            //ggml_graph_print(&gf);
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, layer.mlp_1_b, cur),
+                    cur);
         }
 
-        // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
-        // input for next layer (inpO -> inpL)
-        memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
-        inpL->op = GGML_OP_NONE;
-        inpL->src0 = nullptr;
-        inpL->src1 = nullptr;
+        wctx.use_buf(ctx0, 3);
 
-        if (N > 1) {
-            //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
-        }
-
-        ggml_free(ctxL);
+        inpL = ggml_add(ctx0, cur, inpFF);
     }
 
     cur = inpL;
 
     // norm
     {
+        wctx.use_buf(ctx0, 0);
+
         cur = ggml_norm(ctx0, cur);
 
+        wctx.use_buf(ctx0, 1);
+
         cur = ggml_add(ctx0,
                 ggml_mul(ctx0,
                     ggml_repeat(ctx0, model.d_ln_w, cur),
@@ -1970,29 +2172,44 @@ static bool whisper_decode(
                 ggml_repeat(ctx0, model.d_ln_b, cur));
     }
 
+    wctx.use_buf(ctx0, 0);
+
+    // compute logits only for the last token
+    // comment this line to compute logits for all N tokens
+    // might be useful in the future
+    cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
+
     struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
 
+    wctx.use_buf(ctx0, -1);
+
     // run the computation
     {
-        struct ggml_cgraph gf = {};
-        gf.n_threads = n_threads;
-
         ggml_build_forward_expand(&gf, logits);
         ggml_graph_compute       (ctx0, &gf);
     }
 
-    logits_out.resize(N*n_vocab);
-    memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
+    // extract logits for all N tokens
+    //logits_out.resize(N*n_vocab);
+    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
+
+    // extract logits only for the last token
+    logits_out.resize(n_vocab);
+    memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
 
     if (N > 1) {
-        //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
-        //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
-        //printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx);
+        //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
+        //        ggml_used_mem(ctx0)/1024.0/1024.0,
+        //        wctx.get_buf_max_mem(0)/1024.0/1024.0,
+        //        wctx.get_buf_max_mem(1)/1024.0/1024.0,
+        //        wctx.get_buf_max_mem(2)/1024.0/1024.0,
+        //        wctx.get_buf_max_mem(3)/1024.0/1024.0);
     }
 
     ggml_free(ctx0);
 
     wctx.t_decode_us += ggml_time_us() - t_start_us;
+    wctx.n_decode++;
 
     return true;
 }
@@ -2644,12 +2861,17 @@ whisper_token whisper_token_transcribe(void) {
 void whisper_print_timings(struct whisper_context * ctx) {
     const int64_t t_end_us = ggml_time_us();
 
+    const int32_t n_sample = std::max(1, ctx->n_sample);
+    const int32_t n_encode = std::max(1, ctx->n_encode);
+    const int32_t n_decode = std::max(1, ctx->n_decode);
+
     fprintf(stderr, "\n");
+    fprintf(stderr, "%s:     fallbacks = %3d p / %3d h\n", __func__, ctx->n_fail_p, ctx->n_fail_h);
     fprintf(stderr, "%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f);
     fprintf(stderr, "%s:      mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f);
-    fprintf(stderr, "%s:   sample time = %8.2f ms\n", __func__, ctx->t_sample_us/1000.0f);
-    fprintf(stderr, "%s:   encode time = %8.2f ms / %.2f ms per layer\n", __func__, ctx->t_encode_us/1000.0f, ctx->t_encode_us/1000.0f/ctx->model.hparams.n_audio_layer);
-    fprintf(stderr, "%s:   decode time = %8.2f ms / %.2f ms per layer\n", __func__, ctx->t_decode_us/1000.0f, ctx->t_decode_us/1000.0f/ctx->model.hparams.n_text_layer);
+    fprintf(stderr, "%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_sample_us, n_sample, 1e-3f*ctx->t_sample_us/n_sample);
+    fprintf(stderr, "%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_encode_us, n_encode, 1e-3f*ctx->t_encode_us/n_encode);
+    fprintf(stderr, "%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_decode_us, n_decode, 1e-3f*ctx->t_decode_us/n_decode);
     fprintf(stderr, "%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
 }
 
@@ -2683,7 +2905,7 @@ const char * whisper_print_system_info(void) {
 
 struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
     struct whisper_full_params result = {
-        /*.strategy         =*/ WHISPER_SAMPLING_GREEDY,
+        /*.strategy         =*/ strategy,
 
         /*.n_threads        =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
         /*.n_max_text_ctx   =*/ 16384,
@@ -2702,6 +2924,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.thold_pt         =*/ 0.01f,
         /*.thold_ptsum      =*/ 0.01f,
         /*.max_len          =*/ 0,
+        /*.split_on_word    =*/ false,
         /*.max_tokens       =*/ 0,
 
         /*.speed_up         =*/ false,
@@ -2713,6 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.language         =*/ "en",
 
         /*.suppress_blank   =*/ true,
+        /*.suppress_non_speech_tokens =*/true,
 
         /*.temperature      =*/  0.0f,
         /*.max_initial_ts   =*/  1.0f,
@@ -2768,9 +2992,35 @@ static void whisper_exp_compute_token_level_timestamps(
                          float   thold_pt,
                          float   thold_ptsum);
 
+// trim from start (in place)
+static inline void ltrim(std::string &s) {
+    s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
+        return !std::isspace(ch);
+    }));
+}
+
+// trim from end (in place)
+static inline void rtrim(std::string &s) {
+    s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) {
+        return !std::isspace(ch);
+    }).base(), s.end());
+}
+
+// trim from both ends (in place)
+static inline void trim(std::string &s) {
+    rtrim(s);
+    ltrim(s);
+}
+
+static inline bool should_split_on_word(const char * txt, bool split_on_word) {
+    if (!split_on_word) return true;
+
+    return txt[0] == ' ';
+}
+
 // wrap the last segment to max_len characters
 // returns the number of new segments
-static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
+static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) {
     auto segment = ctx.result_all.back();
 
     int res = 1;
@@ -2785,11 +3035,14 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
         }
 
         const auto txt = whisper_token_to_str(&ctx, token.id);
-
         const int cur = strlen(txt);
 
-        if (acc + cur > max_len && i > 0) {
+        if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) {
             // split here
+            if (split_on_word) {
+                trim(text);
+            }
+
             ctx.result_all.back().text = std::move(text);
             ctx.result_all.back().t1 = token.t0;
             ctx.result_all.back().tokens.resize(i);
@@ -2817,11 +3070,22 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
         }
     }
 
+    if (split_on_word) {
+        trim(text);
+    }
     ctx.result_all.back().text = std::move(text);
 
     return res;
 }
 
+static const std::vector<std::string> non_speech_tokens
+{
+    "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
+    "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
+    "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
+    "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
+};
+
 // process the logits for the selected decoder
 // - applies logit filters
 // - computes logprobs and probs
@@ -2878,6 +3142,37 @@ static void whisper_process_logits(
         logits[vocab.token_sot]  = -INFINITY;
         logits[vocab.token_solm] = -INFINITY;
 
+        // suppress task tokens
+        logits[vocab.token_translate]  = -INFINITY;
+        logits[vocab.token_transcribe] = -INFINITY;
+
+
+        // suppress non-speech tokens
+        // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
+        if (params.suppress_non_speech_tokens)
+        {
+            for (const std::string &token : non_speech_tokens)
+            {
+                std::string suppress_tokens[] = {token, " " + token};
+                for (const std::string &suppress_token : suppress_tokens)
+                {
+                    if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end())
+                    {
+                        logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
+                    }
+                }
+            }
+            // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
+            if (vocab.token_to_id.find(" -") != vocab.token_to_id.end())
+            {
+                logits[vocab.token_to_id.at(" -")] = -INFINITY;
+            }
+            if (vocab.token_to_id.find(" '") != vocab.token_to_id.end())
+            {
+                logits[vocab.token_to_id.at(" '")] = -INFINITY;
+            }
+        }
+
         // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
         // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
         {
@@ -2910,6 +3205,16 @@ static void whisper_process_logits(
             }
         }
 
+        // condition timestamp tokens to be increasing
+        // ref: https://github.com/openai/whisper/pull/831#issuecomment-1385910556
+        if (decoder.has_ts) {
+            const int tid0 = decoder.seek_delta/2;
+
+            for (int i = vocab.token_beg; i < vocab.token_beg + tid0; ++i) {
+                logits[i] = -INFINITY;
+            }
+        }
+
         // populate the logprobs array (log_softmax)
         {
             const float logit_max = *std::max_element(logits.begin(), logits.end());
@@ -3004,7 +3309,7 @@ static void whisper_process_logits(
 }
 
 static whisper_token_data whisper_sample_token(
-      const whisper_context & ctx,
+            whisper_context & ctx,
       const whisper_decoder & decoder,
                        bool   best) {
     whisper_token_data result = {
@@ -3059,6 +3364,8 @@ static whisper_token_data whisper_sample_token(
         result.pt  = result.p;
     }
 
+    ctx.n_sample++;
+
     return result;
 }
 
@@ -3091,10 +3398,10 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
     std::vector<whisper_token_data> result;
     result.reserve(k);
 
-    whisper_token tid;
+    whisper_token tid = vocab.token_beg;
 
-    float pt;
-    float ptsum;
+    float pt    = 0.0;
+    float ptsum = 0.0;
 
     {
         double sum_ts = 0.0;
@@ -3127,6 +3434,8 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
         }
     }
 
+    ctx.n_sample++;
+
     return result;
 }
 
@@ -3211,7 +3520,7 @@ int whisper_full(
             fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
             return -3;
         }
-
+        ctx->lang_id = lang_id;
         params.language = whisper_lang_str(lang_id);
 
         fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
@@ -3308,6 +3617,7 @@ int whisper_full(
     std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
     if (whisper_is_multilingual(ctx)) {
         const int lang_id = whisper_lang_id(params.language);
+        ctx->lang_id = lang_id;
         prompt_init.push_back(whisper_token_lang(ctx, lang_id));
         if (params.translate) {
             prompt_init.push_back(whisper_token_translate());
@@ -3432,7 +3742,7 @@ int whisper_full(
                 prompt.clear();
 
                 // if we have already generated some text, use it as a prompt to condition the next generation
-                if (!prompt_past.empty() && t_cur > 0.5f) {
+                if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
                     int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
 
                     prompt = { whisper_token_prev(ctx) };
@@ -3443,11 +3753,11 @@ int whisper_full(
                 prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
 
                 // print the prompt
-                //WHISPER_PRINT_DEBUG("\n\n");
-                //for (int i = 0; i < (int) prompt.size(); i++) {
-                //    WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
-                //}
-                //WHISPER_PRINT_DEBUG("\n\n");
+                WHISPER_PRINT_DEBUG("\n\n");
+                for (int i = 0; i < (int) prompt.size(); i++) {
+                    WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
+                }
+                WHISPER_PRINT_DEBUG("\n\n");
 
                 if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
                     fprintf(stderr, "%s: failed to decode\n", __func__);
@@ -3544,7 +3854,7 @@ int whisper_full(
                         return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
                     });
 
-                    int cur_c = 0;
+                    unsigned int cur_c = 0;
 
                     for (int j = 0; j < n_decoders_cur; ++j) {
                         auto & decoder = ctx->decoders[j];
@@ -3555,7 +3865,7 @@ int whisper_full(
 
                         auto & cur = beam_candidates[cur_c++];
 
-                        while (beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
+                        while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
                             ++cur_c;
                         }
 
@@ -3721,11 +4031,12 @@ int whisper_full(
                     WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
                             __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
 
-                    if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_thold) {
+                    if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) {
                         WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
                                 __func__, j, decoder.sequence.entropy, params.entropy_thold);
 
                         decoder.failed = true;
+                        ctx->n_fail_h++;
 
                         continue;
                     }
@@ -3747,6 +4058,7 @@ int whisper_full(
 
                 if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
                     success = false;
+                    ctx->n_fail_p++;
                 }
 
                 if (success) {
@@ -3801,6 +4113,7 @@ int whisper_full(
 
                     if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
                         const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
+
                         if (!text.empty()) {
                             const auto tt0 = params.speed_up ? 2*t0 : t0;
                             const auto tt1 = params.speed_up ? 2*t1 : t1;
@@ -3828,7 +4141,7 @@ int whisper_full(
                                         *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
 
                                 if (params.max_len > 0) {
-                                    n_new = whisper_wrap_segment(*ctx, params.max_len);
+                                    n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
                                 }
                             }
                             if (params.new_segment_callback) {
@@ -3872,7 +4185,7 @@ int whisper_full(
                                 *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
 
                         if (params.max_len > 0) {
-                            n_new = whisper_wrap_segment(*ctx, params.max_len);
+                            n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
                         }
                     }
                     if (params.new_segment_callback) {
@@ -4025,6 +4338,10 @@ int whisper_full_n_segments(struct whisper_context * ctx) {
     return ctx->result_all.size();
 }
 
+int whisper_full_lang_id(struct whisper_context * ctx) {
+    return ctx->lang_id; 
+}
+
 int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
     return ctx->result_all[i_segment].t0;
 }
@@ -4059,6 +4376,145 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
 
 // =================================================================================================
 
+//
+// Temporary interface needed for exposing ggml interface
+// Will be removed in the future when ggml becomes a separate library
+//
+
+WHISPER_API int whisper_bench_memcpy(int n_threads) {
+    ggml_time_init();
+
+    size_t n    = 50;
+    size_t arr  = n_threads > 0 ? 1024 : n_threads; // trick to avoid compiler optimizations
+
+    // 1 GB array
+    const size_t size = arr*1024llu*1024llu;
+
+    char * src = (char *) malloc(size);
+    char * dst = (char *) malloc(size);
+
+    for (size_t i = 0; i < size; i++) src[i] = i;
+
+    memcpy(dst, src, size); // heat-up
+
+    double tsum = 0.0;
+
+    for (size_t i = 0; i < n; i++) {
+        const int64_t t0 = ggml_time_us();
+
+        memcpy(dst, src, size);
+
+        const int64_t t1 = ggml_time_us();
+
+        tsum += (t1 - t0)*1e-6;
+
+        src[0] = rand();
+    }
+
+    fprintf(stderr, "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
+
+    // needed to prevent the compile from optimizing the memcpy away
+    {
+        double sum = 0.0;
+
+        for (size_t i = 0; i < size; i++) sum += dst[i];
+
+        fprintf(stderr, "sum:    %s %f\n", sum == -536870910.00 ? "ok" : "error", sum);
+    }
+
+    free(src);
+    free(dst);
+
+    return 0;
+}
+
+WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
+    ggml_time_init();
+
+    const int n_max = 128;
+
+    const std::vector<size_t> sizes = {
+        64, 128, 256, 512, 1024, 2048, 4096,
+    };
+
+    const size_t N_max = sizes.back();
+
+    // a: N*N*sizeof(float)
+    // b: N*N*sizeof(float)
+    // c: N*N*sizeof(float)
+    // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
+    std::vector<char> buf(4llu*N_max*N_max*sizeof(float) + 4*256);
+
+    for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
+
+    for (int j = 0; j < (int) sizes.size(); j++) {
+        int n_fp16 = 0;
+        int n_fp32 = 0;
+
+        // GFLOPS/s
+        double s_fp16 = 0.0;
+        double s_fp32 = 0.0;
+
+        const size_t N = sizes[j];
+
+        for (int k = 0; k < 2; ++k) {
+            const ggml_type wtype = k == 0 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+            double & s = k == 0 ? s_fp16 : s_fp32;
+            int    & n = k == 0 ? n_fp16   : n_fp32;
+
+            struct ggml_init_params gparams = {
+                /*.mem_size   =*/ buf.size(),
+                /*.mem_buffer =*/ buf.data(),
+            };
+
+            struct ggml_context * ctx0 = ggml_init(gparams);
+
+            struct ggml_tensor * a = ggml_new_tensor_2d(ctx0, wtype,         N, N);
+            struct ggml_tensor * b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, N);
+
+            struct ggml_tensor * c = ggml_mul_mat(ctx0, a, b);
+
+            struct ggml_cgraph gf = ggml_build_forward(c);
+
+            gf.n_threads = n_threads;
+
+            double tsum = 0.0;
+
+            // heat-up
+            ggml_graph_compute(ctx0, &gf);
+
+            for (int i = 0; i < n_max; ++i) {
+                const int64_t t0 = ggml_time_us();
+
+                ggml_graph_compute(ctx0, &gf);
+
+                const int64_t t1 = ggml_time_us();
+
+                tsum += (t1 - t0)*1e-6;
+                n++;
+
+                if (tsum > 1.0 && n >= 3) {
+                    break;
+                }
+            }
+
+            ggml_free(ctx0);
+
+            s = ((2.0*N*N*N*n)/tsum)*1e-9;
+        }
+
+        fprintf(stderr, "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
+            N, N, s_fp16, n_fp16, s_fp32, n_fp32);
+    }
+
+    return 0;
+}
+
+// =================================================================================================
+
+// =================================================================================================
+
 //
 // Experimental stuff below
 //
index 84504b7b23f9d9e70a26397ba23469a5e4382070..7eece797c16b84f31116289aaa50e84afb7c4fa6 100644 (file)
@@ -113,6 +113,16 @@ extern "C" {
                                int   n_samples,
                                int   n_threads);
 
+    // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2. 
+    // The resulting spectrogram is stored inside the provided whisper context.
+    // Returns 0 on success
+    WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
+        struct whisper_context* ctx,
+        const float* samples,
+        int   n_samples,
+        int   n_threads);
+
+
     // This can be used to set a custom log mel spectrogram inside the provided whisper context.
     // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
     // n_mel must be 80
@@ -245,7 +255,7 @@ extern "C" {
         int duration_ms;        // audio duration to process in ms
 
         bool translate;
-        bool no_context;        // do not use initial prompt for the decoder (if any)
+        bool no_context;        // do not use past transcription (if any) as initial prompt for the decoder
         bool single_segment;    // force single segment output (useful for streaming)
         bool print_special;     // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
         bool print_progress;    // print progress information
@@ -257,6 +267,7 @@ extern "C" {
         float thold_pt;         // timestamp token probability threshold (~0.01)
         float thold_ptsum;      // timestamp token sum probability threshold (~0.01)
         int   max_len;          // max segment length in characters
+        bool  split_on_word;    // split on word rather than on token (when used with max_len)
         int   max_tokens;       // max tokens per segment (0 = no limit)
 
         // [EXPERIMENTAL] speed-up techniques
@@ -274,6 +285,7 @@ extern "C" {
 
         // common decoding parameters:
         bool suppress_blank;    // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
+        bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
 
         float temperature;      // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
         float max_initial_ts;   // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
@@ -329,6 +341,9 @@ extern "C" {
     // A segment can be a few words, a sentence, or even a paragraph.
     WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
 
+    // Language id associated with the current context
+    WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);
+
     // Get the start and end time of the specified segment.
     WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment);
     WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment);
@@ -350,6 +365,13 @@ extern "C" {
     // Get the probability of the specified token in the specified segment.
     WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
 
+    ////////////////////////////////////////////////////////////////////////////
+
+    // Temporary helpers needed for exposing ggml interface
+
+    WHISPER_API int whisper_bench_memcpy(int n_threads);
+    WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads);
+
 #ifdef __cplusplus
 }
 #endif
index f3c9e5a31991d746627fcbd420b94e6f069aa156..18f317bec04dee0df2aede91ea477ad6ff6f621f 100644 (file)
@@ -301,6 +301,13 @@ struct ggml_cgraph {
     int64_t perf_time_us;
 };
 
+// scratch buffer
+struct ggml_scratch {
+    size_t offs;
+    size_t size;
+    void * data;
+};
+
 struct ggml_init_params {
     // memory pool
     size_t mem_size;   // bytes
@@ -327,6 +334,8 @@ void ggml_free(struct ggml_context * ctx);
 
 size_t ggml_used_mem(const struct ggml_context * ctx);
 
+size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
+
 struct ggml_tensor * ggml_new_tensor(
         struct ggml_context * ctx,
         enum   ggml_type type,
index c59ee64af00c48455e49510ca1760d17e2b9169d..d67612c36a38f646c1fba12138bbd66894afb446 100644 (file)
@@ -339,8 +339,12 @@ int64_t ggml_cycles_per_ms(void) {
 #if defined(__cpp_lib_hardware_interference_size)
 #define CACHE_LINE_SIZE hardware_destructive_interference_size
 #else
+#if defined(__POWER9_VECTOR__)
+#define CACHE_LINE_SIZE 128
+#else
 #define CACHE_LINE_SIZE 64
 #endif
+#endif
 
 static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 
@@ -609,9 +613,12 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 #define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ?                   \
   vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \
   vec_extract_fp32_from_shortl(vec_xl(0, p))
-#define GGML_F16_VEC_STORE(p, r, i)                                      \
-  if (i & 0x1)                                                           \
-    vec_xst(vec_pack_to_short_fp32(r[i], r[i - 1]), 0, p - GGML_F16_EPR)
+#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i]
+#define GGML_F16_VEC_STORE(p, r, i)                             \
+  if (i & 0x1)                                                  \
+    vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)],  \
+                                   r[i - GGML_ENDIAN_BYTE(0)]), \
+            0, p - GGML_F16_EPR)
 
 #elif defined(__wasm_simd128__)
 
@@ -1251,7 +1258,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 //
 
 struct ggml_object {
-    size_t offset;
+    size_t offs;
     size_t size;
 
     struct ggml_object * next;
@@ -1277,6 +1284,9 @@ struct ggml_context {
 
     struct ggml_object * objects_begin;
     struct ggml_object * objects_end;
+
+    struct ggml_scratch scratch;
+    struct ggml_scratch scratch_save;
 };
 
 struct ggml_context_container {
@@ -1339,7 +1349,7 @@ inline static void ggml_critical_section_end(void) {
 
 void ggml_print_object(const struct ggml_object * obj) {
     GGML_PRINT(" - ggml_object: offset = %zu, size = %zu, next = %p\n",
-            obj->offset, obj->size, (const void *) obj->next);
+            obj->offs, obj->size, (const void *) obj->next);
 }
 
 void ggml_print_objects(const struct ggml_context * ctx) {
@@ -1535,12 +1545,14 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
     }
 
     *ctx = (struct ggml_context) {
-        .mem_size         = params.mem_size,
-        .mem_buffer       = params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
-        .mem_buffer_owned = params.mem_buffer ? false : true,
-        .n_objects        = 0,
-        .objects_begin    = NULL,
-        .objects_end      = NULL,
+        /*.mem_size         =*/ params.mem_size,
+        /*.mem_buffer       =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
+        /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
+        /*.n_objects        =*/ 0,
+        /*.objects_begin    =*/ NULL,
+        /*.objects_end      =*/ NULL,
+        /*.scratch          =*/ { 0, 0, NULL, },
+        /*.scratch_save     =*/ { 0, 0, NULL, },
     };
 
     ggml_assert_aligned(ctx->mem_buffer);
@@ -1563,7 +1575,7 @@ void ggml_free(struct ggml_context * ctx) {
             g_state.contexts[i].used = false;
 
             GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
-                    __func__, i, ctx->n_objects, ctx->objects_end->offset + ctx->objects_end->size);
+                    __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
 
             if (ctx->mem_buffer_owned) {
                 free(ctx->mem_buffer);
@@ -1582,7 +1594,15 @@ void ggml_free(struct ggml_context * ctx) {
 }
 
 size_t ggml_used_mem(const struct ggml_context * ctx) {
-    return ctx->objects_end->offset + ctx->objects_end->size;
+    return ctx->objects_end->offs + ctx->objects_end->size;
+}
+
+size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) {
+    const size_t result = ctx->scratch.data ? ctx->scratch.offs : 0;
+
+    ctx->scratch = scratch;
+
+    return result;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -1596,9 +1616,9 @@ struct ggml_tensor * ggml_new_tensor_impl(
     // always insert objects at the end of the context's memory pool
     struct ggml_object * obj_cur = ctx->objects_end;
 
-    const size_t cur_offset = obj_cur == NULL ? 0 : obj_cur->offset;
-    const size_t cur_size   = obj_cur == NULL ? 0 : obj_cur->size;
-    const size_t cur_end    = cur_offset + cur_size;
+    const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs;
+    const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
+    const size_t cur_end  = cur_offs + cur_size;
 
     size_t size_needed = 0;
 
@@ -1609,25 +1629,52 @@ struct ggml_tensor * ggml_new_tensor_impl(
         }
         // align to GGML_MEM_ALIGN
         size_needed = ((size_needed + GGML_MEM_ALIGN - 1)/GGML_MEM_ALIGN)*GGML_MEM_ALIGN;
-
-    }
-    size_needed += sizeof(struct ggml_tensor);
-
-    if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
-        GGML_PRINT("%s: not enough space in the context's memory pool\n", __func__);
-        assert(false);
-        return NULL;
     }
 
     char * const mem_buffer = ctx->mem_buffer;
-
     struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
 
-    *obj_new = (struct ggml_object) {
-        .offset = cur_end + GGML_OBJECT_SIZE,
-        .size   = size_needed,
-        .next   = NULL,
-    };
+    if (ctx->scratch.data == NULL || data != NULL) {
+        size_needed += sizeof(struct ggml_tensor);
+
+        if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
+            GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
+                    __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
+            assert(false);
+            return NULL;
+        }
+
+        *obj_new = (struct ggml_object) {
+            .offs = cur_end + GGML_OBJECT_SIZE,
+            .size = size_needed,
+            .next = NULL,
+        };
+    } else {
+        if (ctx->scratch.offs + size_needed > ctx->scratch.size) {
+            GGML_PRINT("%s: not enough space in the scratch memory\n", __func__);
+            assert(false);
+            return NULL;
+        }
+
+        if (cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE > ctx->mem_size) {
+            GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
+                    __func__, cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE, ctx->mem_size);
+            assert(false);
+            return NULL;
+        }
+
+        data = (char * const) ctx->scratch.data + ctx->scratch.offs;
+
+        *obj_new = (struct ggml_object) {
+            .offs = cur_end + GGML_OBJECT_SIZE,
+            .size = sizeof(struct ggml_tensor),
+            .next = NULL,
+        };
+
+        //printf("scratch offs = %zu, size_needed = %zu\n", ctx->scratch.offs, size_needed);
+
+        ctx->scratch.offs += size_needed;
+    }
 
     if (obj_cur != NULL) {
         obj_cur->next = obj_new;
@@ -1638,9 +1685,9 @@ struct ggml_tensor * ggml_new_tensor_impl(
 
     ctx->objects_end = obj_new;
 
-    //GGML_PRINT_DEBUG("%s: inserted new object at %zu\n", __func__, cur_end);
+    //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size);
 
-    struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offset);
+    struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offs);
 
     ggml_assert_aligned(result);
 
@@ -1683,7 +1730,7 @@ struct ggml_tensor * ggml_new_tensor(
         struct ggml_context * ctx,
         enum   ggml_type type,
         int    n_dims,
-        const int* ne) {
+        const int * ne) {
     return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL);
 }
 
@@ -1725,16 +1772,26 @@ struct ggml_tensor * ggml_new_tensor_4d(
 }
 
 struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
+    ctx->scratch_save = ctx->scratch;
+    ctx->scratch.data = NULL;
+
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
 
+    ctx->scratch = ctx->scratch_save;
+
     ggml_set_i32(result, value);
 
     return result;
 }
 
 struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
+    ctx->scratch_save = ctx->scratch;
+    ctx->scratch.data = NULL;
+
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
 
+    ctx->scratch = ctx->scratch_save;
+
     ggml_set_f32(result, value);
 
     return result;
@@ -2343,7 +2400,7 @@ struct ggml_tensor * ggml_repeat(
     result->op   = GGML_OP_REPEAT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
-    result->src1 = NULL;
+    result->src1 = b;
 
     return result;
 }
@@ -2959,9 +3016,7 @@ struct ggml_tensor * ggml_diag_mask_inf(
     // TODO: when implement backward, fix this:
     //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
-    ((int32_t *) b->data)[0] = n_past;
+    struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
 
     result->op   = GGML_OP_DIAG_MASK_INF;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -4293,7 +4348,9 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     const int ne1 = dst->ne[1];
 
     // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ne0 >= 32 && ne1 >= 32 && ne10 >= 32) {
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && (
+             (ne0 >= 32 && ne1  >= 32   && ne10 >= 32)
+            )) {
         //printf("BLAS: %d %d %d\n", ne0, ne1, ne10);
         return true;
     }
@@ -4373,7 +4430,9 @@ static void ggml_compute_forward_mul_mat_f32(
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         GGML_ASSERT(nb10 == sizeof(float));
 
-        if (params->ith != 0) return;
+        if (params->ith != 0) {
+            return;
+        }
 
         if (params->type == GGML_TASK_INIT) {
             return;
@@ -4616,7 +4675,9 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         GGML_ASSERT(nb10 == sizeof(float));
 
-        if (params->ith != 0) return;
+        if (params->ith != 0) {
+            return;
+        }
 
         if (params->type == GGML_TASK_INIT) {
             return;
@@ -7054,7 +7115,7 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg
 #ifdef __APPLE__
 
 //#include <os/lock.h>
-
+//
 //typedef os_unfair_lock ggml_lock_t;
 //
 //#define ggml_lock_init(x)    UNUSED(x)
@@ -7161,6 +7222,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
             if (state->params.ith < state->params.nth) {
                 ggml_compute_forward(&state->params, state->node);
             }
+
             state->node = NULL;
         } else {
             break;
@@ -7205,6 +7267,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 .node   = NULL,
                 .shared = &state_shared,
             };
+
             int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
             assert(rc == 0);
             UNUSED(rc);
@@ -7273,8 +7336,12 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                                 node->src1->type == GGML_TYPE_F32) {
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
                                 if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
-                                    node->n_tasks = 1;
+                                    node->n_tasks = 1; // TODO: this actually is doing nothing
+                                                       //       the threads are still spinning
                                     cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]);
+                                    //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
+                                    //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
+                                    //printf("cur = %zu\n", cur);
                                 } else {
                                     cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1);
                                 }