]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync with latest code from whisper.cpp
authorGeorgi Gerganov <redacted>
Sun, 4 Dec 2022 09:06:13 +0000 (11:06 +0200)
committerGeorgi Gerganov <redacted>
Sun, 4 Dec 2022 09:06:13 +0000 (11:06 +0200)
examples/whisper/main.cpp
examples/whisper/whisper.cpp
examples/whisper/whisper.h
include/ggml/ggml.h
src/ggml.c

index 70580315769a5110b1dc67ee90433c3dd4a5f057..465d43fb0796455bdfce509ab1a5b661500bb58d 100644 (file)
@@ -36,6 +36,10 @@ std::string to_timestamp(int64_t t, bool comma = false) {
     return std::string(buf);
 }
 
+int timestamp_to_sample(int64_t t, int n_samples) {
+    return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
+}
+
 // helper function to replace substrings
 void replace_all(std::string & s, const std::string & search, const std::string & replace) {
     for (size_t pos = 0; ; pos += replace.length()) {
@@ -48,7 +52,6 @@ void replace_all(std::string & s, const std::string & search, const std::string
 
 // command-line parameters
 struct whisper_params {
-    int32_t seed         = -1; // RNG seed, not used currently
     int32_t n_threads    = std::min(4, (int32_t) std::thread::hardware_concurrency());
     int32_t n_processors = 1;
     int32_t offset_t_ms  = 0;
@@ -59,15 +62,16 @@ struct whisper_params {
 
     float word_thold = 0.01f;
 
-    bool verbose              = false;
-    bool translate            = false;
-    bool output_txt           = false;
-    bool output_vtt           = false;
-    bool output_srt           = false;
-    bool output_wts           = false;
-    bool print_special_tokens = false;
-    bool print_colors         = false;
-    bool no_timestamps        = false;
+    bool speed_up      = false;
+    bool translate     = false;
+    bool diarize       = false;
+    bool output_txt    = false;
+    bool output_vtt    = false;
+    bool output_srt    = false;
+    bool output_wts    = false;
+    bool print_special = false;
+    bool print_colors  = false;
+    bool no_timestamps = false;
 
     std::string language  = "en";
     std::string model     = "models/ggml-base.en.bin";
@@ -86,57 +90,32 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             continue;
         }
 
-        if (arg == "-s" || arg == "--seed") {
-            params.seed = std::stoi(argv[++i]);
-        } else if (arg == "-t" || arg == "--threads") {
-            params.n_threads = std::stoi(argv[++i]);
-        } else if (arg == "-p" || arg == "--processors") {
-            params.n_processors = std::stoi(argv[++i]);
-        } else if (arg == "-ot" || arg == "--offset-t") {
-            params.offset_t_ms = std::stoi(argv[++i]);
-        } else if (arg == "-on" || arg == "--offset-n") {
-            params.offset_n = std::stoi(argv[++i]);
-        } else if (arg == "-d" || arg == "--duration") {
-            params.duration_ms = std::stoi(argv[++i]);
-        } else if (arg == "-mc" || arg == "--max-context") {
-            params.max_context = std::stoi(argv[++i]);
-        } else if (arg == "-ml" || arg == "--max-len") {
-            params.max_len = std::stoi(argv[++i]);
-        } else if (arg == "-wt" || arg == "--word-thold") {
-            params.word_thold = std::stof(argv[++i]);
-        } else if (arg == "-v" || arg == "--verbose") {
-            params.verbose = true;
-        } else if (arg == "--translate") {
-            params.translate = true;
-        } else if (arg == "-l" || arg == "--language") {
-            params.language = argv[++i];
-            if (whisper_lang_id(params.language.c_str()) == -1) {
-                fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
-                whisper_print_usage(argc, argv, params);
-                exit(0);
-            }
-        } else if (arg == "-otxt" || arg == "--output-txt") {
-            params.output_txt = true;
-        } else if (arg == "-ovtt" || arg == "--output-vtt") {
-            params.output_vtt = true;
-        } else if (arg == "-osrt" || arg == "--output-srt") {
-            params.output_srt = true;
-        } else if (arg == "-owts" || arg == "--output-words") {
-            params.output_wts = true;
-        } else if (arg == "-ps" || arg == "--print_special") {
-            params.print_special_tokens = true;
-        } else if (arg == "-pc" || arg == "--print_colors") {
-            params.print_colors = true;
-        } else if (arg == "-nt" || arg == "--no_timestamps") {
-            params.no_timestamps = true;
-        } else if (arg == "-m" || arg == "--model") {
-            params.model = argv[++i];
-        } else if (arg == "-f" || arg == "--file") {
-            params.fname_inp.push_back(argv[++i]);
-        } else if (arg == "-h" || arg == "--help") {
+        if (arg == "-h" || arg == "--help") {
             whisper_print_usage(argc, argv, params);
             exit(0);
-        } else {
+        }
+        else if (arg == "-t"    || arg == "--threads")       { params.n_threads     = std::stoi(argv[++i]); }
+        else if (arg == "-p"    || arg == "--processors")    { params.n_processors  = std::stoi(argv[++i]); }
+        else if (arg == "-ot"   || arg == "--offset-t")      { params.offset_t_ms   = std::stoi(argv[++i]); }
+        else if (arg == "-on"   || arg == "--offset-n")      { params.offset_n      = std::stoi(argv[++i]); }
+        else if (arg == "-d"    || arg == "--duration")      { params.duration_ms   = std::stoi(argv[++i]); }
+        else if (arg == "-mc"   || arg == "--max-context")   { params.max_context   = std::stoi(argv[++i]); }
+        else if (arg == "-ml"   || arg == "--max-len")       { params.max_len       = std::stoi(argv[++i]); }
+        else if (arg == "-wt"   || arg == "--word-thold")    { params.word_thold    = std::stof(argv[++i]); }
+        else if (arg == "-su"   || arg == "--speed-up")      { params.speed_up      = true; }
+        else if (arg == "-tr"   || arg == "--translate")     { params.translate     = true; }
+        else if (arg == "-di"   || arg == "--diarize")       { params.diarize       = true; }
+        else if (arg == "-otxt" || arg == "--output-txt")    { params.output_txt    = true; }
+        else if (arg == "-ovtt" || arg == "--output-vtt")    { params.output_vtt    = true; }
+        else if (arg == "-osrt" || arg == "--output-srt")    { params.output_srt    = true; }
+        else if (arg == "-owts" || arg == "--output-words")  { params.output_wts    = true; }
+        else if (arg == "-ps"   || arg == "--print-special") { params.print_special = true; }
+        else if (arg == "-pc"   || arg == "--print-colors")  { params.print_colors  = true; }
+        else if (arg == "-nt"   || arg == "--no-timestamps") { params.no_timestamps = true; }
+        else if (arg == "-l"    || arg == "--language")      { params.language      = argv[++i]; }
+        else if (arg == "-m"    || arg == "--model")         { params.model         = argv[++i]; }
+        else if (arg == "-f"    || arg == "--file")          { params.fname_inp.push_back(argv[++i]); }
+        else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             whisper_print_usage(argc, argv, params);
             exit(0);
@@ -151,33 +130,40 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
-    fprintf(stderr, "  -h,       --help           show this help message and exit\n");
-    fprintf(stderr, "  -s SEED,  --seed SEED      RNG seed (default: -1)\n");
-    fprintf(stderr, "  -t N,     --threads N      number of threads to use during computation (default: %d)\n", params.n_threads);
-    fprintf(stderr, "  -p N,     --processors N   number of processors to use during computation (default: %d)\n", params.n_processors);
-    fprintf(stderr, "  -ot N,    --offset-t N     time offset in milliseconds (default: %d)\n", params.offset_t_ms);
-    fprintf(stderr, "  -on N,    --offset-n N     segment index offset (default: %d)\n", params.offset_n);
-    fprintf(stderr, "  -d  N,    --duration N     duration of audio to process in milliseconds (default: %d)\n", params.duration_ms);
-    fprintf(stderr, "  -mc N,    --max-context N  maximum number of text context tokens to store (default: max)\n");
-    fprintf(stderr, "  -ml N,    --max-len N      maximum segment length in characters (default: %d)\n", params.max_len);
-    fprintf(stderr, "  -wt N,    --word-thold N   word timestamp probability threshold (default: %f)\n", params.word_thold);
-    fprintf(stderr, "  -v,       --verbose        verbose output\n");
-    fprintf(stderr, "            --translate      translate from source language to english\n");
-    fprintf(stderr, "  -otxt,    --output-txt     output result in a text file\n");
-    fprintf(stderr, "  -ovtt,    --output-vtt     output result in a vtt file\n");
-    fprintf(stderr, "  -osrt,    --output-srt     output result in a srt file\n");
-    fprintf(stderr, "  -owts,    --output-words   output script for generating karaoke video\n");
-    fprintf(stderr, "  -ps,      --print_special  print special tokens\n");
-    fprintf(stderr, "  -pc,      --print_colors   print colors\n");
-    fprintf(stderr, "  -nt,      --no_timestamps  do not print timestamps\n");
-    fprintf(stderr, "  -l LANG,  --language LANG  spoken language (default: %s)\n", params.language.c_str());
-    fprintf(stderr, "  -m FNAME, --model FNAME    model path (default: %s)\n", params.model.c_str());
-    fprintf(stderr, "  -f FNAME, --file FNAME     input WAV file path\n");
+    fprintf(stderr, "  -h,       --help          [default] show this help message and exit\n");
+    fprintf(stderr, "  -t N,     --threads N     [%-7d] number of threads to use during computation\n",    params.n_threads);
+    fprintf(stderr, "  -p N,     --processors N  [%-7d] number of processors to use during computation\n", params.n_processors);
+    fprintf(stderr, "  -ot N,    --offset-t N    [%-7d] time offset in milliseconds\n",                    params.offset_t_ms);
+    fprintf(stderr, "  -on N,    --offset-n N    [%-7d] segment index offset\n",                           params.offset_n);
+    fprintf(stderr, "  -d  N,    --duration N    [%-7d] duration of audio to process in milliseconds\n",   params.duration_ms);
+    fprintf(stderr, "  -mc N,    --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
+    fprintf(stderr, "  -ml N,    --max-len N     [%-7d] maximum segment length in characters\n",           params.max_len);
+    fprintf(stderr, "  -wt N,    --word-thold N  [%-7.2f] word timestamp probability threshold\n",         params.word_thold);
+    fprintf(stderr, "  -su,      --speed-up      [%-7s] speed up audio by x2 (reduced accuracy)\n",        params.speed_up ? "true" : "false");
+    fprintf(stderr, "  -tr,      --translate     [%-7s] translate from source language to english\n",      params.translate ? "true" : "false");
+    fprintf(stderr, "  -di,      --diarize       [%-7s] stereo audio diarization\n",                       params.diarize ? "true" : "false");
+    fprintf(stderr, "  -otxt,    --output-txt    [%-7s] output result in a text file\n",                   params.output_txt ? "true" : "false");
+    fprintf(stderr, "  -ovtt,    --output-vtt    [%-7s] output result in a vtt file\n",                    params.output_vtt ? "true" : "false");
+    fprintf(stderr, "  -osrt,    --output-srt    [%-7s] output result in a srt file\n",                    params.output_srt ? "true" : "false");
+    fprintf(stderr, "  -owts,    --output-words  [%-7s] output script for generating karaoke video\n",     params.output_wts ? "true" : "false");
+    fprintf(stderr, "  -ps,      --print-special [%-7s] print special tokens\n",                           params.print_special ? "true" : "false");
+    fprintf(stderr, "  -pc,      --print-colors  [%-7s] print colors\n",                                   params.print_colors ? "true" : "false");
+    fprintf(stderr, "  -nt,      --no-timestamps [%-7s] do not print timestamps\n",                        params.no_timestamps ? "false" : "true");
+    fprintf(stderr, "  -l LANG,  --language LANG [%-7s] spoken language\n",                                params.language.c_str());
+    fprintf(stderr, "  -m FNAME, --model FNAME   [%-7s] model path\n",                                     params.model.c_str());
+    fprintf(stderr, "  -f FNAME, --file FNAME    [%-7s] input WAV file path\n",                            "");
     fprintf(stderr, "\n");
 }
 
+struct whisper_print_user_data {
+    const whisper_params * params;
+
+    const std::vector<std::vector<float>> * pcmf32s;
+};
+
 void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
-    const whisper_params & params = *(whisper_params *) user_data;
+    const auto & params  = *((whisper_print_user_data *) user_data)->params;
+    const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
 
     const int n_segments = whisper_full_n_segments(ctx);
 
@@ -191,7 +177,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
         if (params.no_timestamps) {
             if (params.print_colors) {
                 for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
-                    if (params.print_special_tokens == false) {
+                    if (params.print_special == false) {
                         const whisper_token id = whisper_full_get_token_id(ctx, i, j);
                         if (id >= whisper_token_eot(ctx)) {
                             continue;
@@ -214,10 +200,37 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
             const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
             const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
+            std::string speaker = "";
+
+            if (params.diarize && pcmf32s.size() == 2) {
+                const int64_t n_samples = pcmf32s[0].size();
+
+                const int64_t is0 = timestamp_to_sample(t0, n_samples);
+                const int64_t is1 = timestamp_to_sample(t1, n_samples);
+
+                double energy0 = 0.0f;
+                double energy1 = 0.0f;
+
+                for (int64_t j = is0; j < is1; j++) {
+                    energy0 += fabs(pcmf32s[0][j]);
+                    energy1 += fabs(pcmf32s[1][j]);
+                }
+
+                if (energy0 > 1.1*energy1) {
+                    speaker = "(speaker 0)";
+                } else if (energy1 > 1.1*energy0) {
+                    speaker = "(speaker 1)";
+                } else {
+                    speaker = "(speaker ?)";
+                }
+
+                //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
+            }
+
             if (params.print_colors) {
                 printf("[%s --> %s]  ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
                 for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
-                    if (params.print_special_tokens == false) {
+                    if (params.print_special == false) {
                         const whisper_token id = whisper_full_get_token_id(ctx, i, j);
                         if (id >= whisper_token_eot(ctx)) {
                             continue;
@@ -229,13 +242,13 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
 
                     const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
 
-                    printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
+                    printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
                 }
                 printf("\n");
             } else {
                 const char * text = whisper_full_get_segment_text(ctx, i);
 
-                printf("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+                printf("[%s --> %s]  %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
             }
         }
     }
@@ -263,7 +276,7 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) {
     std::ofstream fout(fname);
     if (!fout.is_open()) {
         fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
-        return 9;
+        return false;
     }
 
     fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
@@ -386,9 +399,9 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
                     ncnt += txt.size();
                 }
 
-                ::replace_all(txt_bg, "'", "");
+                ::replace_all(txt_bg, "'", "\u2019");
                 ::replace_all(txt_bg, "\"", "\\\"");
-                ::replace_all(txt_fg, "'", "");
+                ::replace_all(txt_fg, "'", "\u2019");
                 ::replace_all(txt_fg, "\"", "\\\"");
             }
 
@@ -428,16 +441,18 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    if (params.seed < 0) {
-        params.seed = time(NULL);
-    }
-
     if (params.fname_inp.empty()) {
         fprintf(stderr, "error: no input files specified\n");
         whisper_print_usage(argc, argv, params);
         return 2;
     }
 
+    if (whisper_lang_id(params.language.c_str()) == -1) {
+        fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
+        whisper_print_usage(argc, argv, params);
+        exit(0);
+    }
+
     // whisper init
 
     struct whisper_context * ctx = whisper_init(params.model.c_str());
@@ -450,53 +465,60 @@ int main(int argc, char ** argv) {
     for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
         const auto fname_inp = params.fname_inp[f];
 
+        std::vector<float> pcmf32; // mono-channel F32 PCM
+        std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
+
         // WAV input
-        std::vector<float> pcmf32;
         {
             drwav wav;
-            
+            std::vector<uint8_t> wav_data; // used for pipe input from stdin
+
             if (fname_inp == "-") {
-                std::vector<uint8_t> wav_data;
                 {
                     uint8_t buf[1024];
                     while (true)
                     {
                         const size_t n = fread(buf, 1, sizeof(buf), stdin);
-                        if (n == 0)
-                        {
+                        if (n == 0) {
                             break;
                         }
                         wav_data.insert(wav_data.end(), buf, buf + n);
                     }
                 }
 
-                if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false)
-                {
+                if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false) {
                     fprintf(stderr, "error: failed to open WAV file from stdin\n");
                     return 4;
                 }
+
+                fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
             }
             else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
                 fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
-                return 4;
+                return 5;
             }
 
             if (wav.channels != 1 && wav.channels != 2) {
                 fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
-                return 5;
+                return 6;
+            }
+
+            if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
+                fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str());
+                return 6;
             }
 
             if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
                 fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
-                return 6;
+                return 8;
             }
 
             if (wav.bitsPerSample != 16) {
                 fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
-                return 7;
+                return 9;
             }
 
-            int n = wav.totalPCMFrameCount;
+            const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
 
             std::vector<int16_t> pcm16;
             pcm16.resize(n*wav.channels);
@@ -514,6 +536,18 @@ int main(int argc, char ** argv) {
                     pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
                 }
             }
+
+            if (params.diarize) {
+                // convert to stereo, float
+                pcmf32s.resize(2);
+
+                pcmf32s[0].resize(n);
+                pcmf32s[1].resize(n);
+                for (int i = 0; i < n; i++) {
+                    pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
+                    pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
+                }
+            }
         }
 
         // print system information
@@ -548,30 +582,47 @@ int main(int argc, char ** argv) {
         {
             whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 
-            wparams.print_realtime       = false;
-            wparams.print_progress       = false;
-            wparams.print_timestamps     = !params.no_timestamps;
-            wparams.print_special_tokens = params.print_special_tokens;
-            wparams.translate            = params.translate;
-            wparams.language             = params.language.c_str();
-            wparams.n_threads            = params.n_threads;
-            wparams.n_max_text_ctx       = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
-            wparams.offset_ms            = params.offset_t_ms;
-            wparams.duration_ms          = params.duration_ms;
-
-            wparams.token_timestamps     = params.output_wts || params.max_len > 0;
-            wparams.thold_pt             = params.word_thold;
-            wparams.max_len              = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
+            wparams.print_realtime   = false;
+            wparams.print_progress   = false;
+            wparams.print_timestamps = !params.no_timestamps;
+            wparams.print_special    = params.print_special;
+            wparams.translate        = params.translate;
+            wparams.language         = params.language.c_str();
+            wparams.n_threads        = params.n_threads;
+            wparams.n_max_text_ctx   = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
+            wparams.offset_ms        = params.offset_t_ms;
+            wparams.duration_ms      = params.duration_ms;
+
+            wparams.token_timestamps = params.output_wts || params.max_len > 0;
+            wparams.thold_pt         = params.word_thold;
+            wparams.max_len          = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
+
+            wparams.speed_up         = params.speed_up;
+
+            whisper_print_user_data user_data = { &params, &pcmf32s };
 
             // this callback is called on each new segment
             if (!wparams.print_realtime) {
                 wparams.new_segment_callback           = whisper_print_segment_callback;
-                wparams.new_segment_callback_user_data = &params;
+                wparams.new_segment_callback_user_data = &user_data;
+            }
+
+            // example for abort mechanism
+            // in this example, we do not abort the processing, but we could if the flag is set to true
+            // the callback is called before every encoder run - if it returns false, the processing is aborted
+            {
+                static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
+
+                wparams.encoder_begin_callback = [](struct whisper_context * ctx, void * user_data) {
+                    bool is_aborted = *(bool*)user_data;
+                    return !is_aborted;
+                };
+                wparams.encoder_begin_callback_user_data = &is_aborted;
             }
 
             if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
                 fprintf(stderr, "%s: failed to process audio\n", argv[0]);
-                return 8;
+                return 10;
             }
         }
 
index 7078863aa3e28e68c2bdda802aa43d038bd31081..42467efe693fa62e2c420e438ceba8a6eed56966 100644 (file)
@@ -424,6 +424,9 @@ struct whisper_context {
     int64_t t_last;
     whisper_token tid_last;
     std::vector<float> energy; // PCM signal energy
+
+    // [EXPERIMENTAL] speed-up techniques
+    int32_t exp_n_audio_ctx; // 0 - use default
 };
 
 // load the model from a ggml file
@@ -515,15 +518,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
         wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
         wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
         wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
-
-        // this is the total memory required to run the inference
-        const size_t mem_required =
-                   wctx.buf_model->size() +
-                   wctx.buf_memory.size() +
-                   wctx.buf_compute.size() +
-                   wctx.buf_compute_layer.size();
-
-        fprintf(stderr, "%s: mem_required  = %.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
     }
 
     // load mel filters
@@ -596,11 +590,21 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
         }
     }
 
+    {
+        // this is the total memory required to run the inference
+        const size_t mem_required =
+                   wctx.buf_model->size() +
+                   wctx.buf_memory.size() +
+                   wctx.buf_compute.size() +
+                   wctx.buf_compute_layer.size();
+
+        fprintf(stderr, "%s: mem_required  = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
+    }
+
     // for the big tensors, we have the option to store the data in 16-bit floats
     // in order to save memory and also to speed up the computation
     const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
-
     size_t ctx_size = 0;
     size_t ctx_mem_size = 0;
 
@@ -613,7 +617,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
         const int n_audio_state = hparams.n_audio_state;
         const int n_audio_layer = hparams.n_audio_layer;
 
-        const int n_text_ctx = hparams.n_text_ctx;
+        const int n_text_ctx   = hparams.n_text_ctx;
         const int n_text_state = hparams.n_text_state;
         const int n_text_layer = hparams.n_text_layer;
 
@@ -719,7 +723,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
         ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
 
-        fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+        fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
     }
 
     // create the ggml context
@@ -748,7 +752,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
         const int n_audio_state = hparams.n_audio_state;
         const int n_audio_layer = hparams.n_audio_layer;
 
-        const int n_text_ctx = hparams.n_text_ctx;
+        const int n_text_ctx   = hparams.n_text_ctx;
         const int n_text_state = hparams.n_text_state;
         const int n_text_layer = hparams.n_text_layer;
 
@@ -967,7 +971,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
         // key/value memory for the cross-attention layer
         {
-            const int n_audio_ctx   = hparams.n_audio_ctx;
+            const int n_audio_ctx = hparams.n_audio_ctx;
 
             const int n_mem      = n_text_layer*n_audio_ctx;
             const int n_elements = n_text_state*n_mem;
@@ -980,7 +984,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             ggml_nbytes(model.memory_k)       + ggml_nbytes(model.memory_v) +
             ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
 
-        fprintf(stderr, "%s: memory size = %8.2f MB\n", __func__, memory_size/1024.0/1024.0);
+        fprintf(stderr, "%s: memory size   = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
     }
 
     // load weights
@@ -1039,12 +1043,12 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
             fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
 
-            //printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
+            //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
             total_size += ggml_nbytes(tensor);
             model.n_loaded++;
         }
 
-        fprintf(stderr, "%s: model size  = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
+        fprintf(stderr, "%s: model size    = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
 
         if (model.n_loaded == 0) {
             fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
@@ -1076,13 +1080,11 @@ static bool whisper_encode(
     const auto & mel_inp = wctx.mel;
     const auto & hparams = model.hparams;
 
-    const int n_ctx   = hparams.n_audio_ctx;
+    const int n_ctx   = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
     const int n_state = hparams.n_audio_state;
     const int n_head  = hparams.n_audio_head;
     const int n_layer = hparams.n_audio_layer;
 
-    const int N = n_ctx;
-
     const int n_mels = hparams.n_mels;
     assert(mel_inp.n_mel == n_mels);
 
@@ -1132,7 +1134,30 @@ static bool whisper_encode(
         cur = ggml_gelu(ctx0, cur);
     }
 
-    cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
+    // ===================================================================
+    // NOTE: experimenting with partial evaluation of the encoder (ignore)
+    //static int iter = -1;
+    //const int n_iter = 1500/n_ctx;
+
+    //iter = (iter + 1) % n_iter;
+
+    //if (iter == 0) {
+    //    memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
+    //    memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
+    //}
+
+    static int iter = 0;
+
+    const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
+    const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
+
+    struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
+
+    cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
+    // ===================================================================
+
+    // original:
+    //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
 
     struct ggml_tensor * inpL = cur;
 
@@ -1198,14 +1223,14 @@ static bool whisper_encode(
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
                             Qcur,
-                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
                             Kcur,
-                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             struct ggml_tensor * V =
@@ -1213,9 +1238,9 @@ static bool whisper_encode(
                         ggml_permute(ctxL,
                             ggml_reshape_3d(ctxL,
                                 Vcur,
-                                n_state/n_head, n_head, N),
+                                n_state/n_head, n_head, n_ctx),
                             1, 2, 0, 3),
-                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head)
+                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
                         );
 
             struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
@@ -1224,14 +1249,14 @@ static bool whisper_encode(
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
                             Qcur,
-                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
                             Kcur,
-                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             // K * Q
@@ -1249,7 +1274,7 @@ static bool whisper_encode(
             //    ggml_permute(ctxL,
             //            ggml_cpy(ctxL,
             //                Vcur,
-            //                ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+            //                ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
             //            1, 2, 0, 3);
 
             //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
@@ -1259,9 +1284,9 @@ static bool whisper_encode(
                         ggml_permute(ctxL,
                             ggml_reshape_3d(ctxL,
                                 Vcur,
-                                n_state/n_head, n_head, N),
+                                n_state/n_head, n_head, n_ctx),
                             0, 2, 1, 3),
-                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head)
+                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
                         );
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
@@ -1271,7 +1296,7 @@ static bool whisper_encode(
 
             cur = ggml_cpy(ctxL,
                     KQV_merged,
-                    ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
+                    ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, n_ctx));
         }
 
         // projection
@@ -1425,6 +1450,8 @@ static bool whisper_encode(
                         Vcross),
                     Vcross);
 
+            //struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
+            //struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
             struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
             struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
 
@@ -1474,7 +1501,7 @@ static bool whisper_decode(
     const int n_layer = hparams.n_text_layer;
 
     const int N = n_tokens;
-    const int M = hparams.n_audio_ctx;
+    const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
 
     struct ggml_init_params params = {
             .mem_size   = wctx.buf_compute.size(),
@@ -1819,7 +1846,9 @@ static bool whisper_decode(
 // the most basic sampling scheme - select the top token
 static whisper_token_data whisper_sample_best(
         const whisper_vocab & vocab,
-        const float * probs) {
+        const float * probs,
+              bool force_timestamp,
+              bool is_initial) {
     whisper_token_data result = {
         0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
     };
@@ -1842,7 +1871,18 @@ static whisper_token_data whisper_sample_best(
             max_tx = std::max(max_tx, probs_id[i].first);
         }
 
-        for (int i = vocab.token_beg; i < n_logits; i++) {
+        const auto i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg;
+        const auto i1 = is_initial ? vocab.token_beg + 101 : n_logits;
+
+        // the initial timestamp cannot be larger than 100
+        // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
+        if (is_initial) {
+            for (int i = i0; i < n_logits; ++ i) {
+                probs_id[i].first = -INFINITY;
+            }
+        }
+
+        for (int i = vocab.token_beg; i < i1; i++) {
             sum_ts += probs_id[i].first;
             if  (probs_id[i].first > max_ts) {
                 max_ts = probs_id[i].first;
@@ -1852,7 +1892,7 @@ static whisper_token_data whisper_sample_best(
 
         // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
         // timestamp token
-        if (sum_ts > max_tx) {
+        if (sum_ts > max_tx || force_timestamp) {
             // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
             for (int i = 0; i < vocab.token_beg; i++) {
                 probs_id[i].first = -INFINITY;
@@ -1894,39 +1934,6 @@ static whisper_token_data whisper_sample_best(
     return result;
 }
 
-// samples only from the timestamps tokens
-static whisper_vocab::id whisper_sample_timestamp(
-        const whisper_vocab & vocab,
-        const float * probs) {
-    int n_logits = vocab.id_to_token.size();
-
-    std::vector<std::pair<double, whisper_vocab::id>> probs_id;
-    probs_id.reserve(n_logits);
-
-    for (int i = vocab.token_beg + 1; i < n_logits; i++) {
-        probs_id.push_back(std::make_pair(probs[i], i));
-    }
-
-    const int top_k = 10;
-
-    // find the top K tokens
-    std::partial_sort(
-            probs_id.begin(),
-            probs_id.begin() + top_k, probs_id.end(),
-            [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
-        return a.first > b.first;
-    });
-
-    probs_id.resize(top_k);
-
-    //printf("\n");
-    //for (int i = 0; i < (int) probs_id.size(); i++) {
-    //    printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
-    //}
-
-    return probs_id[0].second;
-}
-
 //  500 -> 00:05.000
 // 6000 -> 01:00.000
 static std::string to_timestamp(int64_t t, bool comma = false) {
@@ -2031,6 +2038,7 @@ static bool log_mel_spectrogram(
     const int n_mel,
     const int n_threads,
     const whisper_filters & filters,
+    const bool speed_up,
     whisper_mel & mel) {
 
     // Hanning window
@@ -2044,7 +2052,7 @@ static bool log_mel_spectrogram(
     mel.n_len = (n_samples)/fft_step;
     mel.data.resize(mel.n_mel*mel.n_len);
 
-    const int n_fft = 1 + fft_size/2;
+    const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2);
 
     //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
     //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
@@ -2091,6 +2099,13 @@ static bool log_mel_spectrogram(
                     //}
                 }
 
+                if (speed_up) {
+                    // scale down in the frequency domain results in a speed up in the time domain
+                    for (int j = 0; j < n_fft; j++) {
+                        fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]);
+                    }
+                }
+
                 // mel spectrogram
                 for (int j = 0; j < mel.n_mel; j++) {
                     double sum = 0.0;
@@ -2161,6 +2176,12 @@ struct whisper_context * whisper_init(const char * path_model) {
 
 void whisper_free(struct whisper_context * ctx) {
     if (ctx) {
+        if (ctx->model.ctx) {
+            ggml_free(ctx->model.ctx);
+        }
+        if (ctx->model.ctx_mem) {
+            ggml_free(ctx->model.ctx_mem);
+        }
         if (ctx->buf_model) {
             delete ctx->buf_model;
         }
@@ -2171,7 +2192,21 @@ void whisper_free(struct whisper_context * ctx) {
 int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
     const int64_t t_start_us = ggml_time_us();
 
-    if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, ctx->mel)) {
+    if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
+        fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
+        return -1;
+    }
+
+    ctx->t_mel_us = ggml_time_us() - t_start_us;
+
+    return 0;
+}
+
+// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
+int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
+    const int64_t t_start_us = ggml_time_us();
+
+    if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
         fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }
@@ -2229,19 +2264,17 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
 struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
     const int64_t t_start_sample_us = ggml_time_us();
 
-    // TODO: simplify
-    auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
+    const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false);
 
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
 
     return res;
 }
 
-whisper_token whisper_sample_timestamp(struct whisper_context * ctx) {
+struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) {
     const int64_t t_start_sample_us = ggml_time_us();
 
-    // TODO: simplify
-    auto res = whisper_sample_timestamp(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
+    const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial);
 
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
 
@@ -2305,11 +2338,11 @@ whisper_token whisper_token_beg(struct whisper_context * ctx) {
     return ctx->vocab.token_beg;
 }
 
-whisper_token whisper_token_translate() {
+whisper_token whisper_token_translate(void) {
     return whisper_vocab::token_translate;
 }
 
-whisper_token whisper_token_transcribe() {
+whisper_token whisper_token_transcribe(void) {
     return whisper_vocab::token_transcribe;
 }
 
@@ -2325,6 +2358,27 @@ void whisper_print_timings(struct whisper_context * ctx) {
     fprintf(stderr, "%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
 }
 
+void whisper_reset_timings(struct whisper_context * ctx) {
+    ctx->t_sample_us = 0;
+    ctx->t_encode_us = 0;
+    ctx->t_decode_us = 0;
+}
+
+const char * whisper_print_system_info(void) {
+    static std::string s;
+
+    s  = "";
+    s += "AVX = "       + std::to_string(ggml_cpu_has_avx())       + " | ";
+    s += "AVX2 = "      + std::to_string(ggml_cpu_has_avx2())      + " | ";
+    s += "AVX512 = "    + std::to_string(ggml_cpu_has_avx512())    + " | ";
+    s += "NEON = "      + std::to_string(ggml_cpu_has_neon())      + " | ";
+    s += "FP16_VA = "   + std::to_string(ggml_cpu_has_fp16_va())   + " | ";
+    s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
+    s += "BLAS = "      + std::to_string(ggml_cpu_has_blas())      + " | ";
+
+    return s.c_str();
+}
+
 ////////////////////////////////////////////////////////////////////////////
 
 struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
@@ -2334,77 +2388,99 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         case WHISPER_SAMPLING_GREEDY:
             {
                 result = {
-                    /*.strategy             =*/ WHISPER_SAMPLING_GREEDY,
+                    /*.strategy         =*/ WHISPER_SAMPLING_GREEDY,
+
+                    /*.n_threads        =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
+                    /*.n_max_text_ctx   =*/ 16384,
+                    /*.offset_ms        =*/ 0,
+                    /*.duration_ms      =*/ 0,
 
-                    /*.n_threads            =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
-                    /*.n_max_text_ctx       =*/ 16384,
-                    /*.offset_ms            =*/ 0,
-                    /*.duration_ms          =*/ 0,
+                    /*.translate        =*/ false,
+                    /*.no_context       =*/ false,
+                    /*.single_segment   =*/ false,
+                    /*.print_special    =*/ false,
+                    /*.print_progress   =*/ true,
+                    /*.print_realtime   =*/ false,
+                    /*.print_timestamps =*/ true,
 
-                    /*.translate            =*/ false,
-                    /*.no_context           =*/ false,
-                    /*.print_special_tokens =*/ false,
-                    /*.print_progress       =*/ true,
-                    /*.print_realtime       =*/ false,
-                    /*.print_timestamps     =*/ true,
+                    /*.token_timestamps =*/ false,
+                    /*.thold_pt         =*/ 0.01f,
+                    /*.thold_ptsum      =*/ 0.01f,
+                    /*.max_len          =*/ 0,
+                    /*.max_tokens       =*/ 0,
 
-                    /*.token_timestamps     =*/ false,
-                    /*.thold_pt             =*/ 0.01f,
-                    /*.thold_ptsum          =*/ 0.01f,
-                    /*.max_len              =*/ 0,
+                    /*.speed_up         =*/ false,
+                    /*.audio_ctx        =*/ 0,
 
-                    /*.language             =*/ "en",
+                    /*.prompt_tokens    =*/ nullptr,
+                    /*.prompt_n_tokens  =*/ 0,
 
-                    /*.greedy               =*/ {
+                    /*.language         =*/ "en",
+
+                    /*.greedy           =*/ {
                         /*.n_past =*/ 0,
                     },
 
-                    /*.beam_search          =*/ {
+                    /*.beam_search      =*/ {
                         /*.n_past     =*/ -1,
                         /*.beam_width =*/ -1,
                         /*.n_best     =*/ -1,
                     },
 
-                    /*.new_segment_callback =*/ nullptr,
+                    /*.new_segment_callback           =*/ nullptr,
                     /*.new_segment_callback_user_data =*/ nullptr,
+
+                    /*.encoder_begin_callback           =*/ nullptr,
+                    /*.encoder_begin_callback_user_data =*/ nullptr,
                 };
             } break;
         case WHISPER_SAMPLING_BEAM_SEARCH:
             {
                 result = {
-                    /*.strategy             =*/ WHISPER_SAMPLING_BEAM_SEARCH,
+                    /*.strategy         =*/ WHISPER_SAMPLING_BEAM_SEARCH,
 
-                    /*.n_threads            =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
-                    /*.n_max_text_ctx       =*/ 16384,
-                    /*.offset_ms            =*/ 0,
-                    /*.duration_ms          =*/ 0,
+                    /*.n_threads        =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
+                    /*.n_max_text_ctx   =*/ 16384,
+                    /*.offset_ms        =*/ 0,
+                    /*.duration_ms      =*/ 0,
 
-                    /*.translate            =*/ false,
-                    /*.no_context           =*/ false,
-                    /*.print_special_tokens =*/ false,
-                    /*.print_progress       =*/ true,
-                    /*.print_realtime       =*/ false,
-                    /*.print_timestamps     =*/ true,
+                    /*.translate        =*/ false,
+                    /*.no_context       =*/ false,
+                    /*.single_segment   =*/ false,
+                    /*.print_special    =*/ false,
+                    /*.print_progress   =*/ true,
+                    /*.print_realtime   =*/ false,
+                    /*.print_timestamps =*/ true,
 
-                    /*.token_timestamps     =*/ false,
-                    /*.thold_pt             =*/ 0.01f,
-                    /*.thold_ptsum          =*/ 0.01f,
-                    /*.max_len              =*/ 0,
+                    /*.token_timestamps =*/ false,
+                    /*.thold_pt         =*/ 0.01f,
+                    /*.thold_ptsum      =*/ 0.01f,
+                    /*.max_len          =*/ 0,
+                    /*.max_tokens       =*/ 0,
 
-                    /*.language             =*/ "en",
+                    /*.speed_up         =*/ false,
+                    /*.audio_ctx        =*/ 0,
 
-                    /*.greedy               =*/ {
+                    /*.prompt_tokens    =*/ nullptr,
+                    /*.prompt_n_tokens  =*/ 0,
+
+                    /*.language         =*/ "en",
+
+                    /*.greedy           =*/ {
                         /*.n_past =*/ -1,
                     },
 
-                    /*.beam_search          =*/ {
+                    /*.beam_search      =*/ {
                         /*.n_past     =*/ 0,
                         /*.beam_width =*/ 10,
                         /*.n_best     =*/ 5,
                     },
 
-                    /*.new_segment_callback =*/ nullptr,
+                    /*.new_segment_callback           =*/ nullptr,
                     /*.new_segment_callback_user_data =*/ nullptr,
+
+                    /*.encoder_begin_callback           =*/ nullptr,
+                    /*.encoder_begin_callback_user_data =*/ nullptr,
                 };
             } break;
     }
@@ -2485,9 +2561,16 @@ int whisper_full(
     result_all.clear();
 
     // compute log mel spectrogram
-    if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
-        fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
-        return -1;
+    if (params.speed_up) {
+        if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
+            fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
+            return -1;
+        }
+    } else {
+        if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
+            fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
+            return -1;
+        }
     }
 
     if (params.token_timestamps) {
@@ -2513,6 +2596,18 @@ int whisper_full(
         prompt_past.clear();
     }
 
+    // prepend the prompt tokens to the prompt_past
+    if (params.prompt_tokens && params.prompt_n_tokens > 0) {
+        // parse tokens from the pointer
+        for (int i = 0; i < params.prompt_n_tokens; i++) {
+            prompt_past.push_back(params.prompt_tokens[i]);
+        }
+        std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
+    }
+
+    // overwrite audio_ctx
+    ctx->exp_n_audio_ctx = params.audio_ctx;
+
     // these tokens determine the task that will be performed
     std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
     if (whisper_is_multilingual(ctx)) {
@@ -2548,6 +2643,13 @@ int whisper_full(
             break;
         }
 
+        if (params.encoder_begin_callback) {
+            if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
+                fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
+                break;
+            }
+        }
+
         // encode audio features starting at offset seek
         if (whisper_encode(ctx, seek, params.n_threads) != 0) {
             fprintf(stderr, "%s: failed to encode\n", __func__);
@@ -2570,7 +2672,6 @@ int whisper_full(
 
         prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
 
-        bool done = false;
         int seek_delta = 100*WHISPER_CHUNK_SIZE;
 
         // print the prompt
@@ -2584,7 +2685,9 @@ int whisper_full(
         int result_len = 0;
         tokens_cur.clear();
 
-        for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) {
+        bool failed = false;
+
+        for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
             if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
                 fprintf(stderr, "%s: failed to decode\n", __func__);
                 return 8;
@@ -2601,15 +2704,19 @@ int whisper_full(
             // feel free to experiment!
             //
             {
-                auto token = whisper_sample_best(ctx);
-
-                if (i == 0) {
-                    token.tid = whisper_token_beg(ctx);
-                }
+                const auto token = (i == 0) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
 
                 // timestamp token - update sliding window
                 if (token.id > whisper_token_beg(ctx)) {
-                    seek_delta = 2*(token.id - whisper_token_beg(ctx));
+                    const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
+
+                    // do not allow to go back in time
+                    if (seek_delta != 100*WHISPER_CHUNK_SIZE &&
+                        seek_delta > seek_delta_new && result_len < i) {
+                        break;
+                    }
+
+                    seek_delta = seek_delta_new;
                     result_len = i + 1;
                 }
 
@@ -2619,19 +2726,25 @@ int whisper_full(
 
                 //{
                 //    const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
-                //    printf("%s: %10s %6.3f '%s'\n", __func__, tt.c_str(), token.pt, ctx->vocab.id_to_token[token.id].c_str());
+                //    printf("%s: %10s %6d %6.3f '%s'\n", __func__, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
                 //}
 
                 // end of text token
-                if (token.id == whisper_token_eot(ctx)) {
+                if (token.id == whisper_token_eot(ctx) || (params.max_tokens > 0 && i > params.max_tokens)) {
                     if (result_len == 0) {
                         if (seek + seek_delta + 100 >= seek_end) {
                             result_len = i + 1;
                         } else {
-                            // TODO: figure out how to resolve this
-                            fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__);
+                            failed = true;
+                            break;
                         }
                     }
+
+                    if (params.single_segment) {
+                        result_len = i + 1;
+                        seek_delta = 100*WHISPER_CHUNK_SIZE;
+                    }
+
                     break;
                 }
 
@@ -2642,11 +2755,21 @@ int whisper_full(
                 }
             }
 
-            if (done) {
+            // sometimes, the decoding can get stuck in a repetition loop
+            // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance
+            // the sliding window by 1 second
+            if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
+                failed = true;
                 break;
             }
         }
 
+        if (failed) {
+            fprintf(stderr, "\n%s: failed to generate timestamp token - using fallback strategy\n\n", __func__);
+            seek += 100;
+            continue;
+        }
+
         // shrink down to result_len
         tokens_cur.resize(result_len);
 
@@ -2666,23 +2789,26 @@ int whisper_full(
                 //        ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
                 //        ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
 
-                if (params.print_special_tokens == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
+                if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
                 } else {
                     text += whisper_token_to_str(ctx, tokens_cur[i].id);
                 }
-                if (tokens_cur[i].id > whisper_token_beg(ctx)) {
+                if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
                     const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
                     if (!text.empty()) {
+                        const auto tt0 = params.speed_up ? 2*t0 : t0;
+                        const auto tt1 = params.speed_up ? 2*t1 : t1;
+
                         if (params.print_realtime) {
                             if (params.print_timestamps) {
-                                printf("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str());
+                                printf("[%s --> %s]  %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
                             } else {
                                 printf("%s", text.c_str());
                                 fflush(stdout);
                             }
                         }
 
-                        result_all.push_back({ t0, t1, text, {} });
+                        result_all.push_back({ tt0, tt1, text, {} });
                         for (int j = i0; j <= i; j++) {
                             result_all.back().tokens.push_back(tokens_cur[j]);
                         }
@@ -2714,16 +2840,19 @@ int whisper_full(
             if (!text.empty()) {
                 const auto t1 = seek + seek_delta;
 
+                const auto tt0 = params.speed_up ? 2*t0 : t0;
+                const auto tt1 = params.speed_up ? 2*t1 : t1;
+
                 if (params.print_realtime) {
                     if (params.print_timestamps) {
-                        printf("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str());
+                        printf("[%s --> %s]  %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
                     } else {
                         printf("%s", text.c_str());
                         fflush(stdout);
                     }
                 }
 
-                result_all.push_back({ t0, t1, text, {} });
+                result_all.push_back({ tt0, tt1, text, {} });
                 for (int j = i0; j < (int) tokens_cur.size(); j++) {
                     result_all.back().tokens.push_back(tokens_cur[j]);
                 }
@@ -2755,7 +2884,7 @@ int whisper_full_parallel(
         struct whisper_full_params params,
         const float * samples,
         int n_samples,
-        const int n_processors) {
+        int n_processors) {
     if (n_processors == 1) {
         return whisper_full(ctx, params, samples, n_samples);
     }
@@ -2805,7 +2934,7 @@ int whisper_full_parallel(
 
             // key/value memory for the cross-attention layer
             {
-                const int n_audio_ctx   = hparams.n_audio_ctx;
+                const int n_audio_ctx = hparams.n_audio_ctx;
 
                 const int n_mem      = n_text_layer*n_audio_ctx;
                 const int n_elements = n_text_state*n_mem;
@@ -2813,10 +2942,6 @@ int whisper_full_parallel(
                 model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
                 model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
             }
-
-            const size_t memory_size =
-                ggml_nbytes(model.memory_k)       + ggml_nbytes(model.memory_v) +
-                ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
         }
     }
 
@@ -2936,20 +3061,6 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
     return ctx->result_all[i_segment].tokens[i_token].p;
 }
 
-const char * whisper_print_system_info() {
-    static std::string s;
-
-    s  = "";
-    s += "AVX2 = "      + std::to_string(ggml_cpu_has_avx2())      + " | ";
-    s += "AVX512 = "    + std::to_string(ggml_cpu_has_avx512())    + " | ";
-    s += "NEON = "      + std::to_string(ggml_cpu_has_neon())      + " | ";
-    s += "FP16_VA = "   + std::to_string(ggml_cpu_has_fp16_va())   + " | ";
-    s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
-    s += "BLAS = "      + std::to_string(ggml_cpu_has_blas())      + " | ";
-
-    return s.c_str();
-}
-
 // =================================================================================================
 
 //
@@ -3036,9 +3147,6 @@ static void whisper_exp_compute_token_level_timestamps(
     const int64_t t0 = segment.t0;
     const int64_t t1 = segment.t1;
 
-    const int s0 = timestamp_to_sample(t0, n_samples);
-    const int s1 = timestamp_to_sample(t1, n_samples);
-
     const int n = tokens.size();
 
     if (n == 0) {
index 4c112f49c0ca8ca6fa8653ae9f3ff0ee852867a3..def77d4c3c2a76751c48996d4fc5f0f49c44bc52 100644 (file)
@@ -72,16 +72,16 @@ extern "C" {
         whisper_token id;  // token id
         whisper_token tid; // forced timestamp token id
 
-        float p;     // probability of the token
-        float pt;    // probability of the timestamp token
-        float ptsum; // sum of probabilities of all timestamp tokens
+        float p;           // probability of the token
+        float pt;          // probability of the timestamp token
+        float ptsum;       // sum of probabilities of all timestamp tokens
 
         // token-level timestamp data
         // do not use if you haven't computed token-level timestamps
-        int64_t t0; // start time of the token
-        int64_t t1; //   end time of the token
+        int64_t t0;        // start time of the token
+        int64_t t1;        //   end time of the token
 
-        float vlen; // voice length of the token
+        float vlen;        // voice length of the token
     } whisper_token_data;
 
     // Allocates all memory needed for the model and loads the model from the given file.
@@ -96,9 +96,9 @@ extern "C" {
     // Returns 0 on success
     WHISPER_API int whisper_pcm_to_mel(
             struct whisper_context * ctx,
-            const float * samples,
-            int n_samples,
-            int n_threads);
+                       const float * samples,
+                               int   n_samples,
+                               int   n_threads);
 
     // This can be used to set a custom log mel spectrogram inside the provided whisper context.
     // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
@@ -106,9 +106,9 @@ extern "C" {
     // Returns 0 on success
     WHISPER_API int whisper_set_mel(
             struct whisper_context * ctx,
-            const float * data,
-            int n_len,
-            int n_mel);
+                       const float * data,
+                               int   n_len,
+                               int   n_mel);
 
     // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
     // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
@@ -116,8 +116,8 @@ extern "C" {
     // Returns 0 on success
     WHISPER_API int whisper_encode(
             struct whisper_context * ctx,
-            int offset,
-            int n_threads);
+                               int   offset,
+                               int   n_threads);
 
     // Run the Whisper decoder to obtain the logits and probabilities for the next token.
     // Make sure to call whisper_encode() first.
@@ -126,10 +126,10 @@ extern "C" {
     // Returns 0 on success
     WHISPER_API int whisper_decode(
             struct whisper_context * ctx,
-            const whisper_token * tokens,
-            int n_tokens,
-            int n_past,
-            int n_threads);
+               const whisper_token * tokens,
+                               int   n_tokens,
+                               int   n_past,
+                               int   n_threads);
 
     // Token sampling methods.
     // These are provided for convenience and can be used after each call to whisper_decode().
@@ -137,7 +137,7 @@ extern "C" {
     // whisper_sample_best() returns the token with the highest probability
     // whisper_sample_timestamp() returns the most probable timestamp token
     WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
-    WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
+    WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
 
     // Return the id of the specified language, returns -1 if not found
     WHISPER_API int whisper_lang_id(const char * lang);
@@ -162,11 +162,15 @@ extern "C" {
     WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
 
     // Task tokens
-    WHISPER_API whisper_token whisper_token_translate ();
-    WHISPER_API whisper_token whisper_token_transcribe();
+    WHISPER_API whisper_token whisper_token_translate (void);
+    WHISPER_API whisper_token whisper_token_transcribe(void);
 
     // Performance information
     WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
+    WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
+
+    // Print system information
+    WHISPER_API const char * whisper_print_system_info(void);
 
     ////////////////////////////////////////////////////////////////////////////
 
@@ -181,17 +185,26 @@ extern "C" {
     // Use the whisper_full_...() functions to obtain the text segments
     typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
 
+    // Encoder begin callback
+    // If not NULL, called before the encoder starts
+    // If it returns false, the computation is aborted
+    typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
+
+    // Parameters for the whisper_full() function
+    // If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp:
+    // whisper_full_default_params()
     struct whisper_full_params {
         enum whisper_sampling_strategy strategy;
 
         int n_threads;
         int n_max_text_ctx;
-        int offset_ms;      // start offset in ms
-        int duration_ms;    // audio duration to process in ms
+        int offset_ms;          // start offset in ms
+        int duration_ms;        // audio duration to process in ms
 
         bool translate;
         bool no_context;
-        bool print_special_tokens;
+        bool single_segment;    // force single segment output (useful for streaming)
+        bool print_special;
         bool print_progress;
         bool print_realtime;
         bool print_timestamps;
@@ -201,6 +214,16 @@ extern "C" {
         float thold_pt;         // timestamp token probability threshold (~0.01)
         float thold_ptsum;      // timestamp token sum probability threshold (~0.01)
         int   max_len;          // max segment length in characters
+        int   max_tokens;       // max tokens per segment (0 = no limit)
+
+        // [EXPERIMENTAL] speed-up techniques
+        bool speed_up;          // speed-up the audio by 2x using Phase Vocoder
+        int  audio_ctx;         // overwrite the audio context size (0 = use default)
+
+        // tokens to provide the whisper model as initial prompt
+        // these are prepended to any existing text context from a previous call
+        const whisper_token * prompt_tokens;
+        int prompt_n_tokens;
 
         const char * language;
 
@@ -216,6 +239,9 @@ extern "C" {
 
         whisper_new_segment_callback new_segment_callback;
         void * new_segment_callback_user_data;
+
+        whisper_encoder_begin_callback encoder_begin_callback;
+        void * encoder_begin_callback_user_data;
     };
 
     WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
@@ -223,20 +249,20 @@ extern "C" {
     // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
     // Uses the specified decoding strategy to obtain the text.
     WHISPER_API int whisper_full(
-            struct whisper_context * ctx,
-            struct whisper_full_params params,
-            const float * samples,
-            int n_samples);
+                struct whisper_context * ctx,
+            struct whisper_full_params   params,
+                           const float * samples,
+                                   int   n_samples);
 
     // Split the input audio in chunks and process each chunk separately using whisper_full()
     // It seems this approach can offer some speedup in some cases.
     // However, the transcription accuracy can be worse at the beginning and end of each chunk.
     WHISPER_API int whisper_full_parallel(
-            struct whisper_context * ctx,
-            struct whisper_full_params params,
-            const float * samples,
-            int n_samples,
-            const int n_processors);
+                struct whisper_context * ctx,
+            struct whisper_full_params   params,
+                           const float * samples,
+                                   int   n_samples,
+                                   int   n_processors);
 
     // Number of generated text segments.
     // A segment can be a few words, a sentence, or even a paragraph.
@@ -263,9 +289,6 @@ extern "C" {
     // Get the probability of the specified token in the specified segment.
     WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
 
-    // Print system information
-    WHISPER_API const char * whisper_print_system_info();
-
 #ifdef __cplusplus
 }
 #endif
index f352e716350c5581e8877fc2b7efea4edcf4acf4..3e4e962a69ec166538a817876f2cbe0c391a89f3 100644 (file)
@@ -723,6 +723,7 @@ enum ggml_opt_result ggml_opt(
 // system info
 //
 
+int ggml_cpu_has_avx(void);
 int ggml_cpu_has_avx2(void);
 int ggml_cpu_has_avx512(void);
 int ggml_cpu_has_neon(void);
index 484b6dcce1262133cf7510c59fd822877a87b6d3..b6d528d9fd6dd414351808cb648c5a774b95c70e 100644 (file)
 #include <stdio.h>
 
 #if defined _MSC_VER || defined(__MINGW32__)
+
+#if !defined(__MINGW32__)
 #include <Windows.h>
+#else
+// ref: https://github.com/ggerganov/whisper.cpp/issues/168
+#include <windows.h>
+#include <errno.h>
+#endif
 
 typedef volatile LONG atomic_int;
 typedef atomic_int atomic_bool;
@@ -37,7 +44,7 @@ typedef HANDLE pthread_t;
 
 typedef DWORD thread_ret_t;
 static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
-    HANDLE handle = CreateThread(NULL, 0, func, arg, 0, NULL);
+    HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
     if (handle == NULL)
     {
         return EAGAIN;
@@ -372,6 +379,49 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
 
     sumf = _mm_cvtss_f32(r1);
 
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        sumf += x[i]*y[i];
+    }
+#elif defined(__AVX__)
+    // AVX 256-bit
+    const int n32 = (n & ~31);
+
+    __m256 sum0 = _mm256_setzero_ps();
+    __m256 sum1 = _mm256_setzero_ps();
+    __m256 sum2 = _mm256_setzero_ps();
+    __m256 sum3 = _mm256_setzero_ps();
+
+    __m256 x0, x1, x2, x3;
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        x0 = _mm256_loadu_ps(x + i + 0);
+        x1 = _mm256_loadu_ps(x + i + 8);
+        x2 = _mm256_loadu_ps(x + i + 16);
+        x3 = _mm256_loadu_ps(x + i + 24);
+
+        y0 = _mm256_loadu_ps(y + i + 0);
+        y1 = _mm256_loadu_ps(y + i + 8);
+        y2 = _mm256_loadu_ps(y + i + 16);
+        y3 = _mm256_loadu_ps(y + i + 24);
+
+       sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
+       sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
+       sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
+       sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
+    }
+
+    sum0 = _mm256_add_ps(sum0, sum1);
+    sum2 = _mm256_add_ps(sum2, sum3);
+    sum0 = _mm256_add_ps(sum0, sum2);
+
+    const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0), _mm256_extractf128_ps(sum0, 1));
+    const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
+    const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
+
+    sumf = _mm_cvtss_f32(r1);
+
     // leftovers
     for (int i = n32; i < n; ++i) {
         sumf += x[i]*y[i];
@@ -569,6 +619,50 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 
     sumf = _mm_cvtss_f32(r1);
 
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        //GGML_ASSERT(false);
+        sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
+    }
+#elif defined(__AVX__)
+    // AVX 256-bit
+    const int n32 = (n & ~31);
+
+    __m256 sum0 = _mm256_setzero_ps();
+    __m256 sum1 = _mm256_setzero_ps();
+    __m256 sum2 = _mm256_setzero_ps();
+    __m256 sum3 = _mm256_setzero_ps();
+
+    __m256 x0, x1, x2, x3;
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
+        x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
+        x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
+        x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
+
+        y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
+        y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
+        y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
+        y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
+
+       sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
+       sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
+       sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
+       sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
+    }
+
+    const __m256 sum01 = _mm256_add_ps(sum0, sum1);
+    const __m256 sum23 = _mm256_add_ps(sum2, sum3);
+    const __m256 sum0123 = _mm256_add_ps(sum01, sum23);
+
+    const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0123), _mm256_extractf128_ps(sum0123, 1));
+    const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
+    const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
+
+    sumf = _mm_cvtss_f32(r1);
+
     // leftovers
     for (int i = n32; i < n; ++i) {
         //GGML_ASSERT(false);
@@ -698,6 +792,41 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
         _mm256_storeu_ps(y + i + 24, y3);
     }
 
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        y[i] += x[i]*v;
+    }
+#elif defined(__AVX__)
+    // AVX 256-bit
+    const int n32 = (n & ~31);
+
+    const __m256 v4 = _mm256_set1_ps(v);
+
+    __m256 x0, x1, x2, x3;
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        x0 = _mm256_loadu_ps(x + i + 0);
+        x1 = _mm256_loadu_ps(x + i + 8);
+        x2 = _mm256_loadu_ps(x + i + 16);
+        x3 = _mm256_loadu_ps(x + i + 24);
+
+        y0 = _mm256_loadu_ps(y + i + 0);
+        y1 = _mm256_loadu_ps(y + i + 8);
+        y2 = _mm256_loadu_ps(y + i + 16);
+        y3 = _mm256_loadu_ps(y + i + 24);
+
+       y0 = _mm256_add_ps(_mm256_mul_ps(x0, v4), y0);
+       y1 = _mm256_add_ps(_mm256_mul_ps(x1, v4), y1);
+       y2 = _mm256_add_ps(_mm256_mul_ps(x2, v4), y2);
+       y3 = _mm256_add_ps(_mm256_mul_ps(x3, v4), y3);
+
+        _mm256_storeu_ps(y + i + 0, y0);
+        _mm256_storeu_ps(y + i + 8, y1);
+        _mm256_storeu_ps(y + i + 16, y2);
+        _mm256_storeu_ps(y + i + 24, y3);
+    }
+
     // leftovers
     for (int i = n32; i < n; ++i) {
         y[i] += x[i]*v;
@@ -859,6 +988,42 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
         _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0));
     }
 
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        GGML_ASSERT(false);
+        y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
+    }
+#elif defined(__AVX__)
+    // AVX 256-bit
+    const int n32 = (n & ~31);
+
+    const __m256 v8 = _mm256_set1_ps(v);
+
+    __m256 x0, x1, x2, x3;
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
+        y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
+        y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
+        y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
+
+        x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
+        x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
+        x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
+        x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
+
+       y0 = _mm256_add_ps(_mm256_mul_ps(x0, v8), y0);
+       y1 = _mm256_add_ps(_mm256_mul_ps(x1, v8), y1);
+       y2 = _mm256_add_ps(_mm256_mul_ps(x2, v8), y2);
+       y3 = _mm256_add_ps(_mm256_mul_ps(x3, v8), y3);
+
+        _mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0));
+        _mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0));
+        _mm_storeu_si128((__m128i*)(y + i + 16), _mm256_cvtps_ph(y2, 0));
+        _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0));
+    }
+
     // leftovers
     for (int i = n32; i < n; ++i) {
         GGML_ASSERT(false);
@@ -8081,6 +8246,14 @@ enum ggml_opt_result ggml_opt(
 
 ////////////////////////////////////////////////////////////////////////////////
 
+int ggml_cpu_has_avx(void) {
+#if defined(__AVX__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_avx2(void) {
 #if defined(__AVX2__)
     return 1;