]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sync : latest changes from whisper.cpp
authorGeorgi Gerganov <redacted>
Tue, 1 Nov 2022 20:13:15 +0000 (22:13 +0200)
committerGeorgi Gerganov <redacted>
Tue, 1 Nov 2022 20:13:15 +0000 (22:13 +0200)
README.md
examples/whisper/main.cpp
examples/whisper/whisper.cpp
examples/whisper/whisper.h
include/ggml/ggml.h
src/ggml.c
src/msvc_thread_atomic.h [deleted file]

index fc065c0a00b89be4b89fe9809b69e158c65ace02..8c3f862a7194b379dbceb2be476844299c879eb3 100644 (file)
--- a/README.md
+++ b/README.md
@@ -13,7 +13,9 @@ Tensor library for machine learning
 - No third-party dependencies
 - Zero memory allocations during runtime
 
-*Note that this project is under development and not ready for production use*
+***Note that this project is under development and not ready for production use.
+More active development is happening in the ***[whisper.cpp](https://github.com/ggerganov/whisper.cpp) ***repo
+so if you are interested in this project, make sure to follow what is happening there***
 
 ## Whisper inference (example)
 
index 995eefc18e73a8c44301033cbbf3f5013ef4d242..1be0032e0da6dcddfedbadacbc47ff4041442398 100644 (file)
@@ -5,15 +5,23 @@
 #define DR_WAV_IMPLEMENTATION
 #include "dr_wav.h"
 
+#include <cmath>
 #include <fstream>
 #include <cstdio>
 #include <string>
 #include <thread>
 #include <vector>
 
+// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
+// Lowest is red, middle is yellow, highest is green.
+const std::vector<std::string> k_colors = {
+    "\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m",
+    "\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m",
+};
+
 //  500 -> 00:05.000
 // 6000 -> 01:00.000
-std::string to_timestamp(int64_t t) {
+std::string to_timestamp(int64_t t, bool comma = false) {
     int64_t msec = t * 10;
     int64_t hr = msec / (1000 * 60 * 60);
     msec = msec - hr * (1000 * 60 * 60);
@@ -23,23 +31,64 @@ std::string to_timestamp(int64_t t) {
     msec = msec - sec * 1000;
 
     char buf[32];
-    snprintf(buf, sizeof(buf), "%02d:%02d:%02d.%03d", (int) hr, (int) min, (int) sec, (int) msec);
+    snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
 
     return std::string(buf);
 }
 
+void replace_all(std::string & s, const std::string & search, const std::string & replace) {
+    for (size_t pos = 0; ; pos += replace.length()) {
+        pos = s.find(search, pos);
+        if (pos == std::string::npos) break;
+        s.erase(pos, search.length());
+        s.insert(pos, replace);
+    }
+}
+
+// a cost-function that is high for text that takes longer to pronounce
+float voice_length(const std::string & text) {
+    float res = 0.0f;
+
+    for (size_t i = 0; i < text.size(); ++i) {
+        if (text[i] == ' ') {
+            res += 0.01f;
+        } else if (text[i] == ',') {
+            res += 2.00f;
+        } else if (text[i] == '.') {
+            res += 3.00f;
+        } else if (text[i] == '!') {
+            res += 3.00f;
+        } else if (text[i] == '?') {
+            res += 3.00f;
+        } else if (text[i] >= '0' && text[i] <= '9') {
+            res += 3.00f;
+        } else {
+            res += 1.00f;
+        }
+    }
+
+    return res;
+}
+
 // command-line parameters
 struct whisper_params {
-    int32_t seed      = -1; // RNG seed, not used currently
-    int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
-    int32_t offset_ms = 0;
+    int32_t seed         = -1; // RNG seed, not used currently
+    int32_t n_threads    = std::min(4, (int32_t) std::thread::hardware_concurrency());
+    int32_t n_processors = 1;
+    int32_t offset_t_ms  = 0;
+    int32_t offset_n     = 0;
+    int32_t max_context  = -1;
+
+    float word_thold = 0.01f;
 
     bool verbose              = false;
     bool translate            = false;
     bool output_txt           = false;
     bool output_vtt           = false;
     bool output_srt           = false;
+    bool output_wts           = false;
     bool print_special_tokens = false;
+    bool print_colors         = false;
     bool no_timestamps        = false;
 
     std::string language  = "en";
@@ -63,8 +112,16 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             params.seed = std::stoi(argv[++i]);
         } else if (arg == "-t" || arg == "--threads") {
             params.n_threads = std::stoi(argv[++i]);
-        } else if (arg == "-o" || arg == "--offset") {
-            params.offset_ms = std::stoi(argv[++i]);
+        } else if (arg == "-p" || arg == "--processors") {
+            params.n_processors = std::stoi(argv[++i]);
+        } else if (arg == "-ot" || arg == "--offset-t") {
+            params.offset_t_ms = std::stoi(argv[++i]);
+        } else if (arg == "-on" || arg == "--offset-n") {
+            params.offset_n = std::stoi(argv[++i]);
+        } else if (arg == "-mc" || arg == "--max-context") {
+            params.max_context = std::stoi(argv[++i]);
+        } else if (arg == "-wt" || arg == "--word-thold") {
+            params.word_thold = std::stof(argv[++i]);
         } else if (arg == "-v" || arg == "--verbose") {
             params.verbose = true;
         } else if (arg == "--translate") {
@@ -82,8 +139,12 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             params.output_vtt = true;
         } else if (arg == "-osrt" || arg == "--output-srt") {
             params.output_srt = true;
+        } else if (arg == "-owts" || arg == "--output-words") {
+            params.output_wts = true;
         } else if (arg == "-ps" || arg == "--print_special") {
             params.print_special_tokens = true;
+        } else if (arg == "-pc" || arg == "--print_colors") {
+            params.print_colors = true;
         } else if (arg == "-nt" || arg == "--no_timestamps") {
             params.no_timestamps = true;
         } else if (arg == "-m" || arg == "--model") {
@@ -111,13 +172,19 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "  -h,       --help           show this help message and exit\n");
     fprintf(stderr, "  -s SEED,  --seed SEED      RNG seed (default: -1)\n");
     fprintf(stderr, "  -t N,     --threads N      number of threads to use during computation (default: %d)\n", params.n_threads);
-    fprintf(stderr, "  -o N,     --offset N       offset in milliseconds (default: %d)\n", params.offset_ms);
+    fprintf(stderr, "  -p N,     --processors N   number of processors to use during computation (default: %d)\n", params.n_processors);
+    fprintf(stderr, "  -ot N,    --offset-t N     time offset in milliseconds (default: %d)\n", params.offset_t_ms);
+    fprintf(stderr, "  -on N,    --offset-n N     segment index offset (default: %d)\n", params.offset_n);
+    fprintf(stderr, "  -mc N,    --max-context N  maximum number of text context tokens to store (default: max)\n");
+    fprintf(stderr, "  -wt N,    --word-thold N   word timestamp probability threshold (default: %f)\n", params.word_thold);
     fprintf(stderr, "  -v,       --verbose        verbose output\n");
     fprintf(stderr, "            --translate      translate from source language to english\n");
     fprintf(stderr, "  -otxt,    --output-txt     output result in a text file\n");
     fprintf(stderr, "  -ovtt,    --output-vtt     output result in a vtt file\n");
     fprintf(stderr, "  -osrt,    --output-srt     output result in a srt file\n");
+    fprintf(stderr, "  -owts,    --output-words   output word-level timestamps to a text file\n");
     fprintf(stderr, "  -ps,      --print_special  print special tokens\n");
+    fprintf(stderr, "  -pc,      --print_colors   print colors\n");
     fprintf(stderr, "  -nt,      --no_timestamps  do not print timestamps\n");
     fprintf(stderr, "  -l LANG,  --language LANG  spoken language (default: %s)\n", params.language.c_str());
     fprintf(stderr, "  -m FNAME, --model FNAME    model path (default: %s)\n", params.model.c_str());
@@ -125,6 +192,505 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "\n");
 }
 
+void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) {
+    const whisper_params & params = *(whisper_params *) user_data;
+
+    const int n_segments = whisper_full_n_segments(ctx);
+
+    // print the last segment
+    const int i = n_segments - 1;
+    if (i == 0) {
+        printf("\n");
+    }
+
+    if (params.no_timestamps) {
+        if (params.print_colors) {
+            for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+                if (params.print_special_tokens == false) {
+                    const whisper_token id = whisper_full_get_token_id(ctx, i, j);
+                    if (id >= whisper_token_eot(ctx)) {
+                        continue;
+                    }
+                }
+
+                const char * text = whisper_full_get_token_text(ctx, i, j);
+                const float  p    = whisper_full_get_token_p   (ctx, i, j);
+
+                const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+
+                printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
+            }
+        } else {
+            const char * text = whisper_full_get_segment_text(ctx, i);
+            printf("%s", text);
+        }
+        fflush(stdout);
+    } else {
+        const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+        const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+        if (params.print_colors) {
+            printf("[%s --> %s]  ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
+            for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+                if (params.print_special_tokens == false) {
+                    const whisper_token id = whisper_full_get_token_id(ctx, i, j);
+                    if (id >= whisper_token_eot(ctx)) {
+                        continue;
+                    }
+                }
+
+                const char * text = whisper_full_get_token_text(ctx, i, j);
+                const float  p    = whisper_full_get_token_p   (ctx, i, j);
+
+                const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+
+                printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
+            }
+            printf("\n");
+        } else {
+            const char * text = whisper_full_get_segment_text(ctx, i);
+
+            printf("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+        }
+    }
+}
+
+bool output_txt(struct whisper_context * ctx, const char * fname) {
+    std::ofstream fout(fname);
+    if (!fout.is_open()) {
+        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
+        return false;
+    }
+
+    fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
+
+    const int n_segments = whisper_full_n_segments(ctx);
+    for (int i = 0; i < n_segments; ++i) {
+        const char * text = whisper_full_get_segment_text(ctx, i);
+        fout << text;
+    }
+
+    return true;
+}
+
+bool output_vtt(struct whisper_context * ctx, const char * fname) {
+    std::ofstream fout(fname);
+    if (!fout.is_open()) {
+        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
+        return 9;
+    }
+
+    fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
+
+    fout << "WEBVTT\n\n";
+
+    const int n_segments = whisper_full_n_segments(ctx);
+    for (int i = 0; i < n_segments; ++i) {
+        const char * text = whisper_full_get_segment_text(ctx, i);
+        const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+        const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+        fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
+        fout << text << "\n\n";
+    }
+
+    return true;
+}
+
+bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
+    std::ofstream fout(fname);
+    if (!fout.is_open()) {
+        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
+        return false;
+    }
+
+    fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
+
+    const int n_segments = whisper_full_n_segments(ctx);
+    for (int i = 0; i < n_segments; ++i) {
+        const char * text = whisper_full_get_segment_text(ctx, i);
+        const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+        const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+        fout << i + 1 + params.offset_n << "\n";
+        fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
+        fout << text << "\n\n";
+    }
+
+    return true;
+}
+
+// word-level timestamps (experimental)
+// TODO: probably still has bugs, needs refactoring, etc..
+// TODO: auto threshold
+// TODO: extra pass to detect unused speech and assign to tokens
+// TODO: font parameter adjustments
+bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, const std::vector<float> & pcmf32) {
+    if (params.output_wts) {
+        std::vector<float> pcm_avg(pcmf32.size(), 0);
+
+        // average the fabs of the signal
+        {
+            const int hw = 32;
+
+            for (int i = 0; i < pcmf32.size(); i++) {
+                float sum = 0;
+                for (int j = -hw; j <= hw; j++) {
+                    if (i + j >= 0 && i + j < pcmf32.size()) {
+                        sum += fabs(pcmf32[i + j]);
+                    }
+                }
+                pcm_avg[i] = sum/(2*hw + 1);
+            }
+        }
+
+        struct token_info {
+            int64_t t0 = -1;
+            int64_t t1 = -1;
+
+            int64_t tt0 = -1;
+            int64_t tt1 = -1;
+
+            whisper_token id;
+            whisper_token tid;
+
+            float p     = 0.0f;
+            float pt    = 0.0f;
+            float ptsum = 0.0f;
+
+            std::string text;
+            float vlen = 0.0f; // voice length of this token
+        };
+
+        int64_t t_beg  = 0;
+        int64_t t_last = 0;
+
+        whisper_token tid_last = 0;
+
+        std::ofstream fout(fname);
+
+        fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
+
+        fout << "!/bin/bash" << "\n";
+        fout << "\n";
+
+        fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE << ":rate=25:color=black -vf \"";
+
+        bool is_first = true;
+
+        for (int i = 0; i < whisper_full_n_segments(ctx); i++) {
+            const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+            const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+            const char *text = whisper_full_get_segment_text(ctx, i);
+
+            const int s0 = std::max(0,                   (int) (t0*WHISPER_SAMPLE_RATE/100));
+            const int s1 = std::min((int) pcmf32.size(), (int) (t1*WHISPER_SAMPLE_RATE/100));
+
+            const int n = whisper_full_n_tokens(ctx, i);
+
+            std::vector<token_info> tokens(n);
+
+            if (n <= 1) {
+                continue;
+            }
+
+            for (int j = 0; j < n; ++j) {
+                struct whisper_token_data token = whisper_full_get_token_data(ctx, i, j);
+
+                if (j == 0) {
+                    if (token.id == whisper_token_beg(ctx)) {
+                        tokens[j    ].t0 = t0;
+                        tokens[j    ].t1 = t0;
+                        tokens[j + 1].t0 = t0;
+
+                        t_beg  = t0;
+                        t_last = t0;
+                        tid_last = whisper_token_beg(ctx);
+                    } else {
+                        tokens[j    ].t0 = t_last;
+                    }
+                }
+
+                const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
+
+                tokens[j].id    = token.id;
+                tokens[j].tid   = token.tid;
+                tokens[j].p     = token.p;
+                tokens[j].pt    = token.pt;
+                tokens[j].ptsum = token.ptsum;
+
+                tokens[j].text = whisper_token_to_str(ctx, token.id);
+                //tokens[j].vlen = tokens[j].pt;
+                tokens[j].vlen = voice_length(tokens[j].text);
+
+                if (token.pt > params.word_thold && token.ptsum > 0.01 && token.tid > tid_last && tt <= t1) {
+                    if (j > 0) {
+                        tokens[j - 1].t1 = tt;
+                    }
+                    tokens[j].t0 = tt;
+                    tid_last = token.tid;
+                }
+            }
+
+            tokens[n - 2].t1 = t1;
+            tokens[n - 1].t0 = t1;
+            tokens[n - 1].t1 = t1;
+
+            t_last = t1;
+
+            int p0 = 0;
+            int p1 = 0;
+            while (true) {
+                while (p1 < n && tokens[p1].t1 < 0) {
+                    p1++;
+                }
+
+                if (p1 >= n) {
+                    p1--;
+                }
+
+                if (p1 > p0) {
+                    double psum = 0.0;
+                    for (int j = p0; j <= p1; j++) {
+                        psum += tokens[j].vlen;
+                    }
+
+                    //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
+
+                    const double dt = tokens[p1].t1 - tokens[p0].t0;
+
+                    for (int j = p0 + 1; j <= p1; j++) {
+                        const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
+                        //const double ct = tokens[j - 1].t0 + (dt*(j - p0))/(p1 - p0 + 1);
+                        //const double ct = tokens[p0].t0 + (dt*(j - p0))/(p1 - p0 + 1);
+
+                        tokens[j - 1].t1 = ct;
+                        tokens[j    ].t0 = ct;
+                    }
+                }
+
+                p1++;
+                p0 = p1;
+                if (p1 >= n) {
+                    break;
+                }
+            }
+
+            for (int j = 0; j < n - 1; j++) {
+                if (tokens[j].t1 < 0) {
+                    tokens[j + 1].t0 = tokens[j].t1;
+                }
+
+                if (j > 0) {
+                    if (tokens[j - 1].t1 > tokens[j].t0) {
+                        tokens[j].t0 = tokens[j - 1].t1;
+                        tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
+                    }
+                }
+
+                tokens[j].tt0 = tokens[j].t0;
+                tokens[j].tt1 = tokens[j].t1;
+            }
+
+            // VAD
+            {
+                const int hw = WHISPER_SAMPLE_RATE/8;
+
+                for (int j = 0; j < n; j++) {
+                    if (tokens[j].id >= whisper_token_eot(ctx)) {
+                        continue;
+                    }
+
+                    const int64_t t0 = tokens[j].t0;
+                    const int64_t t1 = tokens[j].t1;
+
+                    int s0 = std::max(0,                        (int) (t0*WHISPER_SAMPLE_RATE/100));
+                    int s1 = std::min((int) pcmf32.size() - 1,  (int) (t1*WHISPER_SAMPLE_RATE/100));
+
+                    const int ss0 = std::max(0,                       (int) (t0*WHISPER_SAMPLE_RATE/100) - hw);
+                    const int ss1 = std::min((int) pcmf32.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100) + hw);
+
+                    const int n = ss1 - ss0;
+
+                    float sum = 0.0f;
+
+                    for (int k = ss0; k < ss1; k++) {
+                        sum += pcm_avg[k];
+                    }
+
+                    const float thold = 0.5*sum/n;
+
+                    {
+                        int k = s0;
+                        if (pcm_avg[k] > thold && j > 0) {
+                            while (k > 0 && pcm_avg[k] > thold) {
+                                k--;
+                            }
+                            tokens[j].t0 = (int64_t) (100*k/WHISPER_SAMPLE_RATE);
+                            if (tokens[j].t0 < tokens[j - 1].t1) {
+                                tokens[j].t0 = tokens[j - 1].t1;
+                            } else {
+                                s0 = k;
+                            }
+                        } else {
+                            while (pcm_avg[k] < thold && k < s1) {
+                                k++;
+                            }
+                            s0 = k;
+                            tokens[j].t0 = 100*k/WHISPER_SAMPLE_RATE;
+                        }
+                    }
+
+                    {
+                        int k = s1;
+                        if (pcm_avg[k] > thold) {
+                            while (k < (int) pcmf32.size() - 1 && pcm_avg[k] > thold) {
+                                k++;
+                            }
+                            tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE;
+                            if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
+                                tokens[j].t1 = tokens[j + 1].t0;
+                            } else {
+                                s1 = k;
+                            }
+                        } else {
+                            while (pcm_avg[k] < thold && k > s0) {
+                                k--;
+                            }
+                            s1 = k;
+                            tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE;
+                        }
+                    }
+                }
+            }
+
+            const int t_expand = 0;
+
+            for (int j = 0; j < n; j++) {
+                if (j > 0) {
+                    tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
+                }
+                if (j < n - 1) {
+                    tokens[j].t1 = tokens[j].t1 + t_expand;
+                }
+            }
+
+            for (int j = 0; j < n; ++j) {
+                const auto & token = tokens[j];
+                const auto tt = token.pt > params.word_thold && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
+                printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
+                        tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, token.text.c_str());
+
+                if (tokens[j].id >= whisper_token_eot(ctx)) {
+                    continue;
+                }
+
+                //printf("[%s --> %s] %s\n", to_timestamp(token.t0).c_str(), to_timestamp(token.t1).c_str(), whisper_token_to_str(ctx, token.id));
+
+                //fout << "# " << to_timestamp(token.t0) << " --> " << to_timestamp(token.t1) << " " << whisper_token_to_str(ctx, token.id) << "\n";
+            }
+
+            static const int line_wrap = 60;
+            static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
+
+            if (!is_first) {
+                fout << ",";
+            }
+
+            // background text
+            fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
+
+            is_first = false;
+
+            for (int j = 0; j < n; ++j) {
+                const auto & token = tokens[j];
+
+                if (tokens[j].id >= whisper_token_eot(ctx)) {
+                    continue;
+                }
+
+                std::string txt_bg;
+                std::string txt_fg; // highlight token
+                std::string txt_ul; // underline
+
+                txt_bg = "> ";
+                txt_fg = "> ";
+                txt_ul = "\\ \\ ";
+
+                {
+                    int ncnt = 0;
+                    for (int k = 0; k < n; ++k) {
+                        const auto & token2 = tokens[k];
+
+                        if (tokens[k].id >= whisper_token_eot(ctx)) {
+                            continue;
+                        }
+
+                        const std::string txt = whisper_token_to_str(ctx, token2.id);
+
+                        txt_bg += txt;
+
+                        if (k == j) {
+                            for (int l = 0; l < (int) txt.size(); ++l) {
+                                txt_fg += txt[l];
+                                txt_ul += "_";
+                            }
+                            txt_fg += "|";
+                        } else {
+                            for (int l = 0; l < (int) txt.size(); ++l) {
+                                txt_fg += "\\ ";
+                                txt_ul += "\\ ";
+                            }
+                        }
+
+                        ncnt += txt.size();
+
+                        if (ncnt > line_wrap) {
+                            if (k < j) {
+                                txt_bg = "> ";
+                                txt_fg = "> ";
+                                txt_ul = "\\ \\ ";
+                                ncnt = 0;
+                            } else {
+                                break;
+                            }
+                        }
+                    }
+
+                    ::replace_all(txt_bg, "'", "’");
+                    ::replace_all(txt_bg, "\"", "\\\"");
+                    ::replace_all(txt_fg, "'", "’");
+                    ::replace_all(txt_fg, "\"", "\\\"");
+                }
+
+                // background text
+                fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << token.tt0/100.0 << "," << token.tt1/100.0 << ")'";
+
+                // foreground text
+                fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
+
+                // underline
+                fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2+16:text='" << txt_ul << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
+            }
+        }
+
+        fout << "\" -c:v libx264 -pix_fmt yuv420p -y " << fname_inp << ".mp4" << "\n";
+
+        fout << "\n\n";
+        fout << "echo \"Your video has been saved to " << fname_inp << ".mp4\"" << "\n";
+        fout << "\n";
+        fout << "echo \"  ffplay " << fname_inp << ".mp4\"\n";
+        fout << "\n";
+
+        fout.close();
+
+        fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname);
+    }
+
+    return true;
+}
+
 int main(int argc, char ** argv) {
     whisper_params params;
 
@@ -146,6 +712,11 @@ int main(int argc, char ** argv) {
 
     struct whisper_context * ctx = whisper_init(params.model.c_str());
 
+    if (ctx == nullptr) {
+        fprintf(stderr, "error: failed to initialize whisper context\n");
+        return 3;
+    }
+
     for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
         const auto fname_inp = params.fname_inp[f];
 
@@ -156,22 +727,22 @@ int main(int argc, char ** argv) {
             if (!drwav_init_file(&wav, fname_inp.c_str(), NULL)) {
                 fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], fname_inp.c_str());
                 whisper_print_usage(argc, argv, {});
-                return 3;
+                return 4;
             }
 
             if (wav.channels != 1 && wav.channels != 2) {
                 fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
-                return 4;
+                return 5;
             }
 
             if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
                 fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
-                return 5;
+                return 6;
             }
 
             if (wav.bitsPerSample != 16) {
                 fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
-                return 6;
+                return 7;
             }
 
             int n = wav.totalPCMFrameCount;
@@ -194,6 +765,13 @@ int main(int argc, char ** argv) {
             }
         }
 
+        // print system information
+        {
+            fprintf(stderr, "\n");
+            fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
+                    params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
+        }
+
         // print some info about the processing
         {
             fprintf(stderr, "\n");
@@ -204,8 +782,9 @@ int main(int argc, char ** argv) {
                     fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
                 }
             }
-            fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
-                    __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads,
+            fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
+                    __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
+                    params.n_threads, params.n_processors,
                     params.language.c_str(),
                     params.translate ? "translate" : "transcribe",
                     params.no_timestamps ? 0 : 1);
@@ -218,108 +797,54 @@ int main(int argc, char ** argv) {
         {
             whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 
-            wparams.print_realtime       = true;
+            wparams.print_realtime       = false;
             wparams.print_progress       = false;
             wparams.print_timestamps     = !params.no_timestamps;
             wparams.print_special_tokens = params.print_special_tokens;
             wparams.translate            = params.translate;
             wparams.language             = params.language.c_str();
             wparams.n_threads            = params.n_threads;
-            wparams.offset_ms            = params.offset_ms;
-
-            if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
-                fprintf(stderr, "%s: failed to process audio\n", argv[0]);
-                return 7;
-            }
+            wparams.n_max_text_ctx       = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
+            wparams.offset_ms            = params.offset_t_ms;
 
-            // print result
+            // this callback is called on each new segment
             if (!wparams.print_realtime) {
-                printf("\n");
-
-                const int n_segments = whisper_full_n_segments(ctx);
-                for (int i = 0; i < n_segments; ++i) {
-                    const char * text = whisper_full_get_segment_text(ctx, i);
-
-                    if (params.no_timestamps) {
-                        printf("%s", text);
-                        fflush(stdout);
-                    } else {
-                        const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
-                        const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+                wparams.new_segment_callback           = whisper_print_segment_callback;
+                wparams.new_segment_callback_user_data = &params;
+            }
 
-                        printf("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
-                    }
-                }
+            if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
+                fprintf(stderr, "%s: failed to process audio\n", argv[0]);
+                return 8;
             }
+        }
 
+        // output stuff
+        {
             printf("\n");
 
             // output to text file
             if (params.output_txt) {
-
                 const auto fname_txt = fname_inp + ".txt";
-                std::ofstream fout_txt(fname_txt);
-                if (!fout_txt.is_open()) {
-                    fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_txt.c_str());
-                    return 8;
-                }
-
-                fprintf(stderr, "%s: saving output to '%s.txt'\n", __func__, fname_inp.c_str());
-
-                const int n_segments = whisper_full_n_segments(ctx);
-                for (int i = 0; i < n_segments; ++i) {
-                    const char * text = whisper_full_get_segment_text(ctx, i);
-                    fout_txt << text;
-                }
+                output_txt(ctx, fname_txt.c_str());
             }
 
             // output to VTT file
             if (params.output_vtt) {
-
                 const auto fname_vtt = fname_inp + ".vtt";
-                std::ofstream fout_vtt(fname_vtt);
-                if (!fout_vtt.is_open()) {
-                    fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_vtt.c_str());
-                    return 9;
-                }
-
-                fprintf(stderr, "%s: saving output to '%s.vtt'\n", __func__, fname_inp.c_str());
-
-                fout_vtt << "WEBVTT\n\n";
-
-                const int n_segments = whisper_full_n_segments(ctx);
-                for (int i = 0; i < n_segments; ++i) {
-                    const char * text = whisper_full_get_segment_text(ctx, i);
-                    const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
-                    const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
-
-                    fout_vtt << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
-                    fout_vtt << text << "\n\n";
-                }
+                output_vtt(ctx, fname_vtt.c_str());
             }
 
             // output to SRT file
             if (params.output_srt) {
-
                 const auto fname_srt = fname_inp + ".srt";
-                std::ofstream fout_srt(fname_srt);
-                if (!fout_srt.is_open()) {
-                    fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_srt.c_str());
-                    return 10;
-                }
-
-                fprintf(stderr, "%s: saving output to '%s.srt'\n", __func__, fname_inp.c_str());
-
-                const int n_segments = whisper_full_n_segments(ctx);
-                for (int i = 0; i < n_segments; ++i) {
-                    const char * text = whisper_full_get_segment_text(ctx, i);
-                    const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
-                    const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+                output_srt(ctx, fname_srt.c_str(), params);
+            }
 
-                    fout_srt << i + 1 << "\n";
-                    fout_srt << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
-                    fout_srt << text << "\n\n";
-                }
+            // output to WTS file
+            if (params.output_wts) {
+                const auto fname_wts = fname_inp + ".wts";
+                output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, pcmf32);
             }
         }
     }
index 2d2b8cedcb173644b4153dba698911dfac216a75..7f2b49b893ca7dad24c6f12e08724da912c05517 100644 (file)
@@ -1,3 +1,4 @@
+#define WHISPER_BUILD
 #include "whisper.h"
 
 #include "ggml.h"
@@ -210,16 +211,13 @@ struct whisper_vocab {
     }
 };
 
-struct whisper_result {
-    int64_t t;
-    whisper_token id;
-};
-
 struct whisper_segment {
     int64_t t0;
     int64_t t1;
 
     std::string text;
+
+    std::vector<whisper_token_data> tokens;
 };
 
 // medium
@@ -379,8 +377,12 @@ struct whisper_model {
     struct ggml_tensor * memory_cross_k;
     struct ggml_tensor * memory_cross_v;
 
-    //
+    // context
     struct ggml_context * ctx;
+    struct ggml_context * ctx_mem;
+
+    // tensors
+    int n_loaded;
     std::map<std::string, struct ggml_tensor *> tensors;
 };
 
@@ -392,9 +394,10 @@ struct whisper_context {
     int64_t t_decode_us = 0;
     int64_t t_start_us  = 0;
 
-    std::vector<uint8_t> buf_model;
-    std::vector<uint8_t> buf_compute;
-    std::vector<uint8_t> buf_compute_layer;
+    std::vector<uint8_t> * buf_model; // the model buffer is read-only and can be shared between processors
+    std::vector<uint8_t>   buf_memory;
+    std::vector<uint8_t>   buf_compute;
+    std::vector<uint8_t>   buf_compute_layer;
 
     whisper_model model;
     whisper_vocab vocab;
@@ -404,7 +407,6 @@ struct whisper_context {
     std::vector<float> probs;
     std::vector<float> logits;
 
-    std::vector<whisper_result>  result_cur;
     std::vector<whisper_segment> result_all;
 
     std::vector<whisper_token> prompt_past;
@@ -494,13 +496,16 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
         fprintf(stderr, "%s: f16           = %d\n", __func__, hparams.f16);
         fprintf(stderr, "%s: type          = %d\n", __func__, model.type);
 
-        wctx.buf_model.resize(MEM_REQ_MODEL.at(model.type));
+        wctx.buf_model = new std::vector<uint8_t>();
+        wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
+        wctx.buf_memory.resize(std::max(MEM_REQ_MODEL.at(model.type), MEM_REQ_MODEL.at(model.type))); // TODO: TMP !!!
         wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
         wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
 
         // this is the total memory required to run the inference
         const size_t mem_required =
-                   wctx.buf_model.size() +
+                   wctx.buf_model->size() +
+                   wctx.buf_memory.size() +
                    wctx.buf_compute.size() +
                    wctx.buf_compute_layer.size();
 
@@ -583,6 +588,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
 
 
     size_t ctx_size = 0;
+    size_t ctx_mem_size = 0;
 
     {
         const auto & hparams = model.hparams;
@@ -691,11 +697,11 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
             ctx_size += n_text_layer*(             n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
         }
 
-        ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
-        ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v
+        ctx_mem_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
+        ctx_mem_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v
 
-        ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
-        ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v
+        ctx_mem_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
+        ctx_mem_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v
 
         ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
 
@@ -705,8 +711,8 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
     // create the ggml context
     {
         struct ggml_init_params params = {
-            .mem_size   = wctx.buf_model.size(),
-            .mem_buffer = wctx.buf_model.data(),
+            .mem_size   = wctx.buf_model->size(),
+            .mem_buffer = wctx.buf_model->data(),
         };
 
         model.ctx = ggml_init(params);
@@ -716,6 +722,20 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
         }
     }
 
+    // create the ggml memory context
+    {
+        struct ggml_init_params params = {
+            .mem_size   = wctx.buf_memory.size(),
+            .mem_buffer = wctx.buf_memory.data(),
+        };
+
+        model.ctx_mem = ggml_init(params);
+        if (!model.ctx_mem) {
+            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+            return false;
+        }
+    }
+
     // prepare memory for the weights
     {
         auto & ctx = model.ctx;
@@ -914,7 +934,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
 
     // key + value memory
     {
-        auto & ctx = model.ctx;
+        auto & ctx = model.ctx_mem;
 
         const auto & hparams = model.hparams;
 
@@ -946,14 +966,15 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
             ggml_nbytes(model.memory_k)       + ggml_nbytes(model.memory_v) +
             ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
 
-        fprintf(stderr, "%s: memory size = %8.2f MB \n", __func__, memory_size/1024.0/1024.0);
+        fprintf(stderr, "%s: memory size = %8.2f MB\n", __func__, memory_size/1024.0/1024.0);
     }
 
     // load weights
     {
-        int n_loaded = 0;
         size_t total_size = 0;
 
+        model.n_loaded = 0;
+
         while (true) {
             int32_t n_dims;
             int32_t length;
@@ -1006,15 +1027,15 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
 
             //printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
             total_size += ggml_nbytes(tensor);
-            n_loaded++;
+            model.n_loaded++;
         }
 
         fprintf(stderr, "%s: model size  = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
 
-        if (n_loaded == 0) {
+        if (model.n_loaded == 0) {
             fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
-        } else if (n_loaded != (int) model.tensors.size()) {
-            fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), n_loaded);
+        } else if (model.n_loaded != (int) model.tensors.size()) {
+            fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
             return false;
         }
     }
@@ -1782,9 +1803,11 @@ bool whisper_decode(
 }
 
 // the most basic sampling scheme - select the top token
-whisper_vocab::id whisper_sample_best(
+whisper_token_data whisper_sample_best(
         const whisper_vocab & vocab,
         const float * probs) {
+    whisper_token_data result;
+
     int n_logits = vocab.id_to_token.size();
 
     std::vector<std::pair<double, whisper_vocab::id>> probs_id;
@@ -1794,24 +1817,34 @@ whisper_vocab::id whisper_sample_best(
         probs_id.push_back(std::make_pair(probs[i], i));
     }
 
-    double sum_ts = 0.0;
-    double max_tx = 0.0;
+    {
+        double sum_ts =  0.0;
+        double max_ts = -1.0;
+        double max_tx = -1.0;
 
-    for (int i = 0; i < vocab.token_beg; i++) {
-        max_tx = std::max(max_tx, probs_id[i].first);
-    }
+        for (int i = 0; i < vocab.token_beg; i++) {
+            max_tx = std::max(max_tx, probs_id[i].first);
+        }
 
-    for (int i = vocab.token_beg; i < n_logits; i++) {
-        sum_ts += probs_id[i].first;
-    }
+        for (int i = vocab.token_beg; i < n_logits; i++) {
+            sum_ts += probs_id[i].first;
+            if  (probs_id[i].first > max_ts) {
+                max_ts = probs_id[i].first;
+                result.tid = probs_id[i].second;
+            }
+        }
 
-    // if the probability sum of all timestamp tokesn is higher than the max probability of the text tokens - sample a
-    // timestamp token
-    if (sum_ts > max_tx) {
-        // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
-        for (int i = 0; i < vocab.token_beg; i++) {
-            probs_id[i].first = -INFINITY;
+        // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
+        // timestamp token
+        if (sum_ts > max_tx) {
+            // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
+            for (int i = 0; i < vocab.token_beg; i++) {
+                probs_id[i].first = -INFINITY;
+            }
         }
+
+        result.pt = max_ts/(sum_ts + 1e-10);
+        result.ptsum = sum_ts;
     }
 
     // find the top K tokens
@@ -1839,7 +1872,10 @@ whisper_vocab::id whisper_sample_best(
         res++;
     }
 
-    return probs_id[res].second;
+    result.id = probs_id[res].second;
+    result.p  = probs_id[res].first;
+
+    return result;
 }
 
 // samples only from the timestamps tokens
@@ -1875,14 +1911,19 @@ whisper_vocab::id whisper_sample_timestamp(
     return probs_id[0].second;
 }
 
-static std::string to_timestamp(int64_t t) {
-    int64_t sec = t/100;
-    int64_t msec = t - sec*100;
-    int64_t min = sec/60;
-    sec = sec - min*60;
+//  500 -> 00:05.000
+// 6000 -> 01:00.000
+static std::string to_timestamp(int64_t t, bool comma = false) {
+    int64_t msec = t * 10;
+    int64_t hr = msec / (1000 * 60 * 60);
+    msec = msec - hr * (1000 * 60 * 60);
+    int64_t min = msec / (1000 * 60);
+    msec = msec - min * (1000 * 60);
+    int64_t sec = msec / 1000;
+    msec = msec - sec * 1000;
 
     char buf[32];
-    snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) sec, (int) msec);
+    snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
 
     return std::string(buf);
 }
@@ -2104,6 +2145,9 @@ struct whisper_context * whisper_init(const char * path_model) {
 
 void whisper_free(struct whisper_context * ctx) {
     if (ctx) {
+        if (ctx->buf_model) {
+            delete ctx->buf_model;
+        }
         delete ctx;
     }
 }
@@ -2166,7 +2210,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
     return 0;
 }
 
-whisper_token whisper_sample_best(struct whisper_context * ctx) {
+struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
     const int64_t t_start_sample_us = ggml_time_us();
 
     // TODO: simplify
@@ -2277,6 +2321,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                     /*.strategy             =*/ WHISPER_SAMPLING_GREEDY,
 
                     /*.n_threads            =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
+                    /*.n_max_text_ctx       =*/ 16384,
                     /*.offset_ms            =*/ 0,
 
                     /*.translate            =*/ false,
@@ -2297,6 +2342,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                         /*.beam_width =*/ -1,
                         /*.n_best     =*/ -1,
                     },
+
+                    /*.new_segment_callback =*/ nullptr,
+                    /*.new_segment_callback_user_data =*/ nullptr,
                 };
             } break;
         case WHISPER_SAMPLING_BEAM_SEARCH:
@@ -2305,6 +2353,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                     /*.strategy             =*/ WHISPER_SAMPLING_BEAM_SEARCH,
 
                     /*.n_threads            =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
+                    /*.n_max_text_ctx       =*/ 16384,
                     /*.offset_ms            =*/ 0,
 
                     /*.translate            =*/ false,
@@ -2325,6 +2374,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                         /*.beam_width =*/ 10,
                         /*.n_best     =*/ 5,
                     },
+
+                    /*.new_segment_callback =*/ nullptr,
+                    /*.new_segment_callback_user_data =*/ nullptr,
                 };
             } break;
     }
@@ -2339,7 +2391,6 @@ int whisper_full(
         int n_samples) {
     // clear old results
     auto & result_all = ctx->result_all;
-    auto & result_cur = ctx->result_cur;
 
     result_all.clear();
 
@@ -2349,10 +2400,12 @@ int whisper_full(
         return -1;
     }
 
+    const int seek_start = params.offset_ms/10;
+
     // if length of spectrogram is less than 1s (100 samples), then return
     // basically don't process anything that is less than 1s
     // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
-    if (whisper_n_len(ctx) < 100) {
+    if (whisper_n_len(ctx) < 100 + seek_start) {
         return 0;
     }
 
@@ -2376,8 +2429,14 @@ int whisper_full(
     int progress_prev = 0;
     int progress_step = 5;
 
+    std::vector<whisper_token_data> tokens_cur;
+    tokens_cur.reserve(whisper_n_text_ctx(ctx));
+
+    std::vector<whisper_token> prompt;
+    prompt.reserve(whisper_n_text_ctx(ctx));
+
     // main loop
-    int seek = params.offset_ms/10;
+    int seek = seek_start;
     while (true) {
         int progress_cur = (100*seek)/whisper_n_len(ctx);
         while (progress_cur >= progress_prev + progress_step) {
@@ -2397,13 +2456,12 @@ int whisper_full(
             return 7;
         }
 
-        std::vector<whisper_token> prompt;
-
         int n_past = 0;
+        prompt.clear();
 
         // if we have already generated some text, use it as a prompt to condition the next generation
         if (prompt_past.size() > 0) {
-            int n_take = std::min(whisper_n_text_ctx(ctx)/2, int(prompt_past.size()));
+            int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
 
             prompt = { whisper_token_prev(ctx) };
             prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
@@ -2426,7 +2484,7 @@ int whisper_full(
 
         // the accumulated transcription in the current interation
         int result_len = 0;
-        result_cur.clear();
+        tokens_cur.clear();
 
         for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) {
             if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
@@ -2445,28 +2503,29 @@ int whisper_full(
             // feel free to experiment!
             //
             {
-                whisper_token id  = 0;
-                whisper_token tid = whisper_token_beg(ctx);
+                auto token = whisper_sample_best(ctx);
 
-                id = whisper_sample_best(ctx);
-                if (i > 0) {
-                    tid = whisper_sample_timestamp(ctx);
+                if (i == 0) {
+                    token.tid = whisper_token_beg(ctx);
                 }
 
-                // update sliding window
-                if (id > whisper_token_beg(ctx)) {
-                    seek_delta = 2*(id - whisper_token_beg(ctx));
+                // timestamp token - update sliding window
+                if (token.id > whisper_token_beg(ctx)) {
+                    seek_delta = 2*(token.id - whisper_token_beg(ctx));
                     result_len = i + 1;
                 }
 
                 // add it to the context
-                prompt.push_back(id);
-                result_cur.push_back({ seek + 2*(tid - whisper_token_beg(ctx)), id });
+                prompt.push_back(token.id);
+                tokens_cur.push_back(token);
 
-                //printf("%s: %s\n", __func__, ctx->vocab.id_to_token[id].c_str());
+                //{
+                //    const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
+                //    printf("%s: %10s %6.3f '%s'\n", __func__, tt.c_str(), token.pt, ctx->vocab.id_to_token[token.id].c_str());
+                //}
 
                 // end of text token
-                if (id == whisper_token_eot(ctx)) {
+                if (token.id == whisper_token_eot(ctx)) {
                     if (result_len == 0) {
                         if (seek + seek_delta + 100 >= whisper_n_len(ctx)) {
                             result_len = i + 1;
@@ -2477,6 +2536,12 @@ int whisper_full(
                     }
                     break;
                 }
+
+                // TESTS: if no tensors are loaded, it means we are running tests
+                if (ctx->model.n_loaded == 0) {
+                    seek_delta = 100*WHISPER_CHUNK_SIZE;
+                    break;
+                }
             }
 
             if (done) {
@@ -2484,25 +2549,30 @@ int whisper_full(
             }
         }
 
-        result_cur.resize(result_len);
+        tokens_cur.resize(result_len);
 
-        for (const auto & r : result_cur) {
+        for (const auto & r : tokens_cur) {
             prompt_past.push_back(r.id);
         }
 
         // store the text from this iteration
-        if (result_cur.size() > 0) {
-            auto t0 = result_cur.front().t;
+        if (tokens_cur.size() > 0) {
+            int  i0 = 0;
+            auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
 
             std::string text = "";
 
-            for (int i = 0; i < (int) result_cur.size(); i++) {
-                if (params.print_special_tokens == false && result_cur[i].id >= whisper_token_eot(ctx)) {
+            for (int i = 0; i < (int) tokens_cur.size(); i++) {
+                //printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
+                //        ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
+                //        ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
+
+                if (params.print_special_tokens == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
                 } else {
-                    text += whisper_token_to_str(ctx, result_cur[i].id);
+                    text += whisper_token_to_str(ctx, tokens_cur[i].id);
                 }
-                if (result_cur[i].id > whisper_token_beg(ctx)) {
-                    const auto t1 = result_cur[i].t;
+                if (tokens_cur[i].id > whisper_token_beg(ctx)) {
+                    const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
                     if (!text.empty()) {
                         if (params.print_realtime) {
                             if (params.print_timestamps) {
@@ -2513,14 +2583,21 @@ int whisper_full(
                             }
                         }
 
-                        result_all.push_back({ t0, t1, text });
+                        result_all.push_back({ t0, t1, text, {} });
+                        for (int j = i0; j <= i; j++) {
+                            result_all.back().tokens.push_back(tokens_cur[j]);
+                        }
+                        if (params.new_segment_callback) {
+                            params.new_segment_callback(ctx, params.new_segment_callback_user_data);
+                        }
                     }
                     text = "";
-                    while (i < (int) result_cur.size() && result_cur[i].id > whisper_token_beg(ctx)) {
+                    while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
                         i++;
                     }
                     i--;
-                    t0 = result_cur[i].t;
+                    t0 = t1;
+                    i0 = i + 1;
                 }
             }
 
@@ -2536,7 +2613,13 @@ int whisper_full(
                     }
                 }
 
-                result_all.push_back({ t0, t1, text });
+                result_all.push_back({ t0, t1, text, {} });
+                for (int j = i0; j < (int) tokens_cur.size(); j++) {
+                    result_all.back().tokens.push_back(tokens_cur[j]);
+                }
+                if (params.new_segment_callback) {
+                    params.new_segment_callback(ctx, params.new_segment_callback_user_data);
+                }
             }
         }
 
@@ -2546,6 +2629,156 @@ int whisper_full(
     return 0;
 }
 
+int whisper_full_parallel(
+        struct whisper_context * ctx,
+        struct whisper_full_params params,
+        const float * samples,
+        int n_samples,
+        const int n_processors) {
+    if (n_processors == 1) {
+        return whisper_full(ctx, params, samples, n_samples);
+    }
+
+    int ret = 0;
+
+    // prepare separate contexts for each thread
+    std::vector<struct whisper_context> ctxs(n_processors - 1);
+
+    for (int i = 0; i < n_processors - 1; ++i) {
+        ctxs[i] = *ctx;
+
+        auto & model = ctxs[i].model;
+
+        // create the ggml memory context
+        {
+            struct ggml_init_params params = {
+                .mem_size   = ctxs[i].buf_memory.size(),
+                .mem_buffer = ctxs[i].buf_memory.data(),
+            };
+
+            model.ctx_mem = ggml_init(params);
+            if (!model.ctx_mem) {
+                fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+                return false;
+            }
+        }
+
+        // separate key + value memory for each processor
+        {
+            auto & ctx = model.ctx_mem;
+
+            const auto & hparams = model.hparams;
+
+            const int n_text_state = hparams.n_text_state;
+            const int n_text_layer = hparams.n_text_layer;
+            const int n_text_ctx   = hparams.n_text_ctx;
+
+            // key/value memory for the self-attention layer
+            {
+                const int n_mem      = n_text_layer*n_text_ctx;
+                const int n_elements = n_text_state*n_mem;
+
+                model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+                model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+            }
+
+            // key/value memory for the cross-attention layer
+            {
+                const int n_audio_ctx   = hparams.n_audio_ctx;
+
+                const int n_mem      = n_text_layer*n_audio_ctx;
+                const int n_elements = n_text_state*n_mem;
+
+                model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+                model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+            }
+
+            const size_t memory_size =
+                ggml_nbytes(model.memory_k)       + ggml_nbytes(model.memory_v) +
+                ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
+        }
+    }
+
+    const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
+    const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
+
+    // the calling thread will process the first chunk
+    // while the other threads will process the remaining chunks
+
+    std::vector<std::thread> workers(n_processors - 1);
+    for (int i = 0; i < n_processors - 1; ++i) {
+        const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
+        const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
+
+        auto params_cur = params;
+
+        params_cur.offset_ms = 0;
+        params_cur.print_progress = false;
+        params_cur.print_realtime = false;
+
+        params_cur.new_segment_callback = nullptr;
+        params_cur.new_segment_callback_user_data = nullptr;
+
+        workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur);
+    }
+
+    {
+        auto params_cur = params;
+
+        ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
+    }
+
+    for (int i = 0; i < n_processors - 1; ++i) {
+        workers[i].join();
+    }
+
+    const int64_t offset_t = (int64_t) params.offset_ms/10.0;
+
+    // combine results into ctx->result_all
+    for (int i = 0; i < n_processors - 1; ++i) {
+        auto & results_i = ctxs[i].result_all;
+
+        for (int j = 0; j < (int) results_i.size(); ++j) {
+            // correct the segment timestamp taking into account the offset
+            results_i[j].t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
+            results_i[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
+
+            // make sure that segments are not overlapping
+            if (ctx->result_all.size() > 0) {
+                results_i[j].t0 = std::max(results_i[j].t0, ctx->result_all.back().t1);
+            }
+
+            ctx->result_all.push_back(std::move(results_i[j]));
+
+            // call the new_segment_callback for each segment
+            if (params.new_segment_callback) {
+                params.new_segment_callback(ctx, params.new_segment_callback_user_data);
+            }
+        }
+
+        ctx->t_mel_us    += ctxs[i].t_mel_us;
+        ctx->t_sample_us += ctxs[i].t_sample_us;
+        ctx->t_encode_us += ctxs[i].t_encode_us;
+        ctx->t_decode_us += ctxs[i].t_decode_us;
+    }
+
+    // average the timings
+    ctx->t_mel_us    /= n_processors;
+    ctx->t_sample_us /= n_processors;
+    ctx->t_encode_us /= n_processors;
+    ctx->t_decode_us /= n_processors;
+
+    // print information about the audio boundaries
+    fprintf(stderr, "\n");
+    fprintf(stderr, "%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
+    for (int i = 0; i < n_processors - 1; ++i) {
+        fprintf(stderr, "%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
+    }
+    fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__);
+
+    return ret;
+}
+
 int whisper_full_n_segments(struct whisper_context * ctx) {
     return ctx->result_all.size();
 }
@@ -2561,3 +2794,37 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)
 const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
     return ctx->result_all[i_segment].text.c_str();
 }
+
+int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
+    return ctx->result_all[i_segment].tokens.size();
+}
+
+const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
+    return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
+}
+
+whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
+    return ctx->result_all[i_segment].tokens[i_token].id;
+}
+
+struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
+    return ctx->result_all[i_segment].tokens[i_token];
+}
+
+float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
+    return ctx->result_all[i_segment].tokens[i_token].p;
+}
+
+const char * whisper_print_system_info() {
+    static std::string s;
+
+    s  = "";
+    s += "AVX2 = "      + std::to_string(ggml_cpu_has_avx2())      + " | ";
+    s += "AVX512 = "    + std::to_string(ggml_cpu_has_avx512())    + " | ";
+    s += "NEON = "      + std::to_string(ggml_cpu_has_neon())      + " | ";
+    s += "FP16_VA = "   + std::to_string(ggml_cpu_has_fp16_va())   + " | ";
+    s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
+    s += "BLAS = "      + std::to_string(ggml_cpu_has_blas())      + " | ";
+
+    return s.c_str();
+}
index 4423674d1d28695b77ef8dfebbba44a8c28889c0..5d7c40d00ee452a6b6633cd3875f9d1d9699e41d 100644 (file)
@@ -68,6 +68,15 @@ extern "C" {
 
     typedef int whisper_token;
 
+    struct whisper_token_data {
+        whisper_token id;  // token id
+        whisper_token tid; // forced timestamp token id
+
+        float p;     // probability of the token
+        float pt;    // probability of the timestamp token
+        float ptsum; // sum of probabilities of all timestamp tokens
+    };
+
     // Allocates all memory needed for the model and loads the model from the given file.
     // Returns NULL on failure.
     WHISPER_API struct whisper_context * whisper_init(const char * path_model);
@@ -120,7 +129,7 @@ extern "C" {
     // You can also implement your own sampling method using the whisper_get_probs() function.
     // whisper_sample_best() returns the token with the highest probability
     // whisper_sample_timestamp() returns the most probable timestamp token
-    WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx);
+    WHISPER_API struct whisper_token_data whisper_sample_best(struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
 
     // Return the id of the specified language, returns -1 if not found
@@ -160,10 +169,16 @@ extern "C" {
         WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
     };
 
+    // Text segment callback
+    // Called on every newly generated text segment
+    // Use the whisper_full_...() functions to obtain the text segments
+    typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data);
+
     struct whisper_full_params {
         enum whisper_sampling_strategy strategy;
 
         int n_threads;
+        int n_max_text_ctx;
         int offset_ms;
 
         bool translate;
@@ -184,6 +199,9 @@ extern "C" {
             int beam_width;
             int n_best;
         } beam_search;
+
+        whisper_new_segment_callback new_segment_callback;
+        void * new_segment_callback_user_data;
     };
 
     WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
@@ -196,6 +214,16 @@ extern "C" {
             const float * samples,
             int n_samples);
 
+    // Split the input audio in chunks and process each chunk separately using whisper_full()
+    // It seems this approach can offer some speedup in some cases.
+    // However, the transcription accuracy can be worse at the beginning and end of each chunk.
+    WHISPER_API int whisper_full_parallel(
+            struct whisper_context * ctx,
+            struct whisper_full_params params,
+            const float * samples,
+            int n_samples,
+            const int n_processors);
+
     // Number of generated text segments.
     // A segment can be a few words, a sentence, or even a paragraph.
     WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
@@ -207,6 +235,23 @@ extern "C" {
     // Get the text of the specified segment.
     WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment);
 
+    // Get number of tokens in the specified segment.
+    WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment);
+
+    // Get the token text of the specified token in the specified segment.
+    WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
+
+    // Get token data for the specified token in the specified segment.
+    // This contains probabilities, timestamps, etc.
+    WHISPER_API struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
+
+    // Get the probability of the specified token in the specified segment.
+    WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
+
+    // Print system information
+    WHISPER_API const char * whisper_print_system_info();
+
 #ifdef __cplusplus
 }
 #endif
index 34f104b70c26daffd9b7f6988d6d1f5929d27712..f92ae73c3ad302e8bf02dac27cb54acfee39354d 100644 (file)
@@ -11,7 +11,7 @@ extern "C" {
 #define GGML_MAX_DIMS     4
 #define GGML_MAX_NODES    4096
 #define GGML_MAX_PARAMS   16
-#define GGML_MAX_CONTEXTS 16
+#define GGML_MAX_CONTEXTS 64
 #define GGML_MAX_OPT      4
 
 #ifdef __ARM_NEON
@@ -548,6 +548,17 @@ enum ggml_opt_result ggml_opt(
         struct ggml_opt_params params,
         struct ggml_tensor * f);
 
+//
+// system info
+//
+
+int ggml_cpu_has_avx2(void);
+int ggml_cpu_has_avx512(void);
+int ggml_cpu_has_neon(void);
+int ggml_cpu_has_fp16_va(void);
+int ggml_cpu_has_wasm_simd(void);
+int ggml_cpu_has_blas(void);
+
 #ifdef  __cplusplus
 }
 #endif
index 115e619b081f48927a90e0c37322508316fba308..1000a5b6b711e41fd092776dd9571d362ed8319a 100644 (file)
 #include <stdio.h>
 
 #if defined _MSC_VER
-#include "msvc_thread_atomic.h"
+#include <Windows.h>
+
+typedef volatile LONG atomic_int;
+typedef atomic_int atomic_bool;
+
+static void atomic_store(atomic_int* ptr, LONG val) {
+    InterlockedExchange(ptr, val);
+}
+static LONG atomic_load(atomic_int* ptr) {
+    return InterlockedCompareExchange(ptr, 0, 0);
+}
+static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) {
+    return InterlockedExchangeAdd(ptr, inc);
+}
+static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) {
+    return atomic_fetch_add(ptr, -(dec));
+}
+
+typedef HANDLE pthread_t;
+
+typedef DWORD thread_ret_t;
+static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
+    out = CreateThread(NULL, 0, func, arg, 0, NULL);
+    return out != NULL;
+}
+
+static int pthread_join(pthread_t thread, void* unused) {
+    return (int) WaitForSingleObject(thread, INFINITE);
+}
+
+static int sched_yield (void) {
+    Sleep (0);
+    return 0;
+}
 #else
 #include <pthread.h>
 #include <stdatomic.h>
+
 typedef void* thread_ret_t;
 #endif
 
@@ -47,6 +81,8 @@ typedef void* thread_ret_t;
 
 #ifdef GGML_USE_ACCELERATE
 #include <Accelerate/Accelerate.h>
+#elif GGML_USE_OPENBLAS
+#include <cblas.h>
 #endif
 
 // floating point type used to accumulate sums
@@ -73,7 +109,11 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
 
 #else
 
+#ifdef __wasm_simd128__
+#include <wasm_simd128.h>
+#else
 #include <immintrin.h>
+#endif
 
 // FP16 <-> FP32
 // ref: https://github.com/Maratyszcza/FP16
@@ -288,7 +328,7 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
         sumf += x[i]*y[i];
     }
 #elif defined(__AVX2__)
-    // AVX 256-bit (unroll 4)
+    // AVX 256-bit
     const int n32 = (n & ~31);
 
     __m256 sum0 = _mm256_setzero_ps();
@@ -330,6 +370,45 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
     for (int i = n32; i < n; ++i) {
         sumf += x[i]*y[i];
     }
+#elif defined(__wasm_simd128__)
+    // WASM 128-bit
+    const int n16 = (n & ~15);
+
+    v128_t sum0 = wasm_f32x4_splat(0);
+    v128_t sum1 = wasm_f32x4_splat(0);
+    v128_t sum2 = wasm_f32x4_splat(0);
+    v128_t sum3 = wasm_f32x4_splat(0);
+
+    v128_t x0, x1, x2, x3;
+    v128_t y0, y1, y2, y3;
+
+    for (int i = 0; i < n16; i += 16) {
+        x0 = wasm_v128_load(x + i + 0);
+        x1 = wasm_v128_load(x + i + 4);
+        x2 = wasm_v128_load(x + i + 8);
+        x3 = wasm_v128_load(x + i + 12);
+
+        y0 = wasm_v128_load(y + i + 0);
+        y1 = wasm_v128_load(y + i + 4);
+        y2 = wasm_v128_load(y + i + 8);
+        y3 = wasm_v128_load(y + i + 12);
+
+        sum0 = wasm_f32x4_add(sum0, wasm_f32x4_mul(x0, y0));
+        sum1 = wasm_f32x4_add(sum1, wasm_f32x4_mul(x1, y1));
+        sum2 = wasm_f32x4_add(sum2, wasm_f32x4_mul(x2, y2));
+        sum3 = wasm_f32x4_add(sum3, wasm_f32x4_mul(x3, y3));
+    }
+
+    sum0 = wasm_f32x4_add(sum0, sum1);
+    sum2 = wasm_f32x4_add(sum2, sum3);
+    sum0 = wasm_f32x4_add(sum0, sum2);
+
+    sumf = wasm_f32x4_extract_lane(sum0, 0) + wasm_f32x4_extract_lane(sum0, 1) + wasm_f32x4_extract_lane(sum0, 2) + wasm_f32x4_extract_lane(sum0, 3);
+
+    // leftovers
+    for (int i = n16; i < n; ++i) {
+        sumf += x[i]*y[i];
+    }
 #else
     // scalar
     for (int i = 0; i < n; ++i) {
@@ -446,7 +525,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
         sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
     }
 #elif defined(__AVX2__)
-    // AVX 256-bit (unroll 4)
+    // AVX 256-bit
     const int n32 = (n & ~31);
 
     __m256 sum0 = _mm256_setzero_ps();
@@ -489,6 +568,54 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
         //GGML_ASSERT(false);
         sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
     }
+#elif defined(__wasm_simd128__)
+    // WASM 128-bit
+    const int n16 = (n & ~15);
+
+    v128_t sum0 = wasm_f32x4_splat(0.0f);
+    v128_t sum1 = wasm_f32x4_splat(0.0f);
+    v128_t sum2 = wasm_f32x4_splat(0.0f);
+    v128_t sum3 = wasm_f32x4_splat(0.0f);
+
+    v128_t x0, x1, x2, x3;
+    v128_t y0, y1, y2, y3;
+
+    float tx[16];
+    float ty[16];
+
+    for (int i = 0; i < n16; i += 16) {
+        for (int k = 0; k < 16; ++k) {
+            tx[k] = ggml_fp16_to_fp32(x[i + k]);
+            ty[k] = ggml_fp16_to_fp32(y[i + k]);
+        }
+
+        x0 = wasm_v128_load(tx + 0);
+        x1 = wasm_v128_load(tx + 4);
+        x2 = wasm_v128_load(tx + 8);
+        x3 = wasm_v128_load(tx + 12);
+
+        y0 = wasm_v128_load(ty + 0);
+        y1 = wasm_v128_load(ty + 4);
+        y2 = wasm_v128_load(ty + 8);
+        y3 = wasm_v128_load(ty + 12);
+
+        sum0 = wasm_f32x4_add(sum0, wasm_f32x4_mul(x0, y0));
+        sum1 = wasm_f32x4_add(sum1, wasm_f32x4_mul(x1, y1));
+        sum2 = wasm_f32x4_add(sum2, wasm_f32x4_mul(x2, y2));
+        sum3 = wasm_f32x4_add(sum3, wasm_f32x4_mul(x3, y3));
+    }
+
+    sum0 = wasm_f32x4_add(sum0, sum1);
+    sum2 = wasm_f32x4_add(sum2, sum3);
+    sum0 = wasm_f32x4_add(sum0, sum2);
+
+    sumf = wasm_f32x4_extract_lane(sum0, 0) + wasm_f32x4_extract_lane(sum0, 1) + wasm_f32x4_extract_lane(sum0, 2) + wasm_f32x4_extract_lane(sum0, 3);
+
+    // leftovers
+    for (int i = n16; i < n; ++i) {
+        //GGML_ASSERT(false);
+        sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
+    }
 #else
     for (int i = 0; i < n; ++i) {
         sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
@@ -535,7 +662,7 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
         y[i] += x[i]*v;
     }
 #elif defined(__AVX2__)
-    // AVX 256-bit (unroll 4)
+    // AVX 256-bit
     const int n32 = (n & ~31);
 
     const __m256 v4 = _mm256_set1_ps(v);
@@ -569,6 +696,41 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
     for (int i = n32; i < n; ++i) {
         y[i] += x[i]*v;
     }
+#elif defined(__wasm_simd128__)
+    // WASM SIMD 128-bit
+    const int n16 = (n & ~15);
+
+    const v128_t v4 = wasm_f32x4_splat(v);
+
+    v128_t x0, x1, x2, x3;
+    v128_t y0, y1, y2, y3;
+
+    for (int i = 0; i < n16; i += 16) {
+        x0 = wasm_v128_load(x + i + 0);
+        x1 = wasm_v128_load(x + i + 4);
+        x2 = wasm_v128_load(x + i + 8);
+        x3 = wasm_v128_load(x + i + 12);
+
+        y0 = wasm_v128_load(y + i + 0);
+        y1 = wasm_v128_load(y + i + 4);
+        y2 = wasm_v128_load(y + i + 8);
+        y3 = wasm_v128_load(y + i + 12);
+
+        y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4));
+        y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4));
+        y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4));
+        y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4));
+
+        wasm_v128_store(y + i + 0, y0);
+        wasm_v128_store(y + i + 4, y1);
+        wasm_v128_store(y + i + 8, y2);
+        wasm_v128_store(y + i + 12, y3);
+    }
+
+    // leftovers
+    for (int i = n16; i < n; ++i) {
+        y[i] += x[i]*v;
+    }
 #else
     // scalar
     for (int i = 0; i < n; ++i) {
@@ -696,6 +858,54 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
         GGML_ASSERT(false);
         y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
     }
+#elif defined(__wasm_simd128__)
+    // WASM SIMD 128-bit
+    const int n16 = (n & ~15);
+
+    const v128_t v4 = wasm_f32x4_splat(v);
+
+    v128_t x0, x1, x2, x3;
+    v128_t y0, y1, y2, y3;
+
+    float tx[16];
+    float ty[16];
+
+    for (int i = 0; i < n16; i += 16) {
+        for (int k = 0; k < 16; ++k) {
+            tx[k] = ggml_fp16_to_fp32(x[i + k]);
+            ty[k] = ggml_fp16_to_fp32(y[i + k]);
+        }
+
+        x0 = wasm_v128_load(tx + 0);
+        x1 = wasm_v128_load(tx + 4);
+        x2 = wasm_v128_load(tx + 8);
+        x3 = wasm_v128_load(tx + 12);
+
+        y0 = wasm_v128_load(ty + 0);
+        y1 = wasm_v128_load(ty + 4);
+        y2 = wasm_v128_load(ty + 8);
+        y3 = wasm_v128_load(ty + 12);
+
+        y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4));
+        y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4));
+        y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4));
+        y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4));
+
+        wasm_v128_store(ty + 0, y0);
+        wasm_v128_store(ty + 4, y1);
+        wasm_v128_store(ty + 8, y2);
+        wasm_v128_store(ty + 12, y3);
+
+        for (int k = 0; k < 16; ++k) {
+            y[i + k] = ggml_fp32_to_fp16(ty[k]);
+        }
+    }
+
+    // leftovers
+    for (int i = n16; i < n; ++i) {
+        GGML_ASSERT(false);
+        y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
+    }
 #else
     for (int i = 0; i < n; ++i) {
         y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
@@ -931,6 +1141,7 @@ struct ggml_state {
 
 // global state
 struct ggml_state g_state;
+atomic_int g_state_barrier = 0;
 
 ////////////////////////////////////////////////////////////////////////////////
 
@@ -1060,6 +1271,17 @@ int ggml_up64(int n) {
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_context * ggml_init(struct ggml_init_params params) {
+    // make this function thread safe
+    {
+        int processing = atomic_fetch_add(&g_state_barrier, 1);
+        while (processing > 0) {
+            // wait for other threads to finish
+            atomic_fetch_sub(&g_state_barrier, 1);
+            sched_yield();
+            processing = atomic_fetch_add(&g_state_barrier, 1);
+        }
+    }
+
     static bool is_first_call = true;
     if (is_first_call) {
         const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
@@ -1103,6 +1325,9 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
 
     if (ctx == NULL) {
         GGML_PRINT_DEBUG("%s: no unused context found\n", __func__);
+
+        atomic_fetch_sub(&g_state_barrier, 1);
+
         return NULL;
     }
 
@@ -1117,10 +1342,25 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
 
     ggml_assert_aligned(ctx->mem_buffer);
 
+    GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
+
+    atomic_fetch_sub(&g_state_barrier, 1);
+
     return ctx;
 }
 
 void ggml_free(struct ggml_context * ctx) {
+    // make this function thread safe
+    {
+        int processing = atomic_fetch_add(&g_state_barrier, 1);
+        while (processing > 0) {
+            // wait for other threads to finish
+            atomic_fetch_sub(&g_state_barrier, 1);
+            sched_yield();
+            processing = atomic_fetch_add(&g_state_barrier, 1);
+        }
+    }
+
     for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
         if (&g_state.contexts[i].context == ctx) {
             g_state.contexts[i].used = false;
@@ -1132,11 +1372,15 @@ void ggml_free(struct ggml_context * ctx) {
                 free(ctx->mem_buffer);
             }
 
+            atomic_fetch_sub(&g_state_barrier, 1);
+
             return;
         }
     }
 
     GGML_PRINT_DEBUG("%s: context not found\n", __func__);
+
+    atomic_fetch_sub(&g_state_barrier, 1);
 }
 
 size_t ggml_used_mem(const struct ggml_context * ctx) {
@@ -3852,46 +4096,44 @@ void ggml_compute_forward_mul_mat_f32(
     // nb00 <  nb01 - src0 is transposed
     //   compute by src0 columns
 
-//#ifdef GGML_USE_ACCELERATE
-//    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
-//        GGML_ASSERT(ggml_is_contiguous(src0));
-//        GGML_ASSERT(nb10 == sizeof(float));
-//
-//        if (params->ith != 0) return;
-//
-//        if (params->type == GGML_TASK_INIT) {
-//            return;
-//        }
-//
-//        if (params->type == GGML_TASK_FINALIZE) {
-//            return;
-//        }
-//
-//        float * const wdata = params->wdata;
-//
-//        for (int i03 = 0; i03 < ne03; i03++) {
-//            for (int i02 = 0; i02 < ne02; i02++) {
-//                const float * x = (float *) (src0->data);
-//                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
-//
-//                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-//
-//                // zT = y * xT
-//                {
-//                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-//                            ne11, ne01, ne10,
-//                            1.0f,    y, ne10,
-//                                     x, ne10,
-//                            0.0f,    d, ne01);
-//                }
-//            }
-//        }
-//
-//        //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
-//
-//        return;
-//    }
-//#endif
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+        GGML_ASSERT(ggml_is_contiguous(src0));
+        GGML_ASSERT(nb10 == sizeof(float));
+
+        if (params->ith != 0) return;
+
+        if (params->type == GGML_TASK_INIT) {
+            return;
+        }
+
+        if (params->type == GGML_TASK_FINALIZE) {
+            return;
+        }
+
+        for (int i03 = 0; i03 < ne03; i03++) {
+            for (int i02 = 0; i02 < ne02; i02++) {
+                const float * x = (float *) (src0->data);
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // zT = y * xT
+                {
+                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                            ne11, ne01, ne10,
+                            1.0f,    y, ne10,
+                                     x, ne10,
+                            0.0f,    d, ne01);
+                }
+            }
+        }
+
+        //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
+
+        return;
+    }
+#endif
 
     if (params->type == GGML_TASK_INIT) {
         if (nb01 >= nb00) {
@@ -4098,7 +4340,7 @@ void ggml_compute_forward_mul_mat_f16_f32(
     // nb00 <  nb01 - src0 is transposed
     //   compute by src0 columns
 
-#ifdef GGML_USE_ACCELERATE
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         GGML_ASSERT(nb10 == sizeof(float));
 
@@ -6654,7 +6896,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         } else {
                             if (node->src0->type == GGML_TYPE_F16 &&
                                 node->src1->type == GGML_TYPE_F32) {
-#ifdef GGML_USE_ACCELERATE
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
                                 if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                     cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]);
                                 } else {
@@ -7358,7 +7600,7 @@ enum ggml_opt_result ggml_opt_adam(
 
         {
             const int64_t t_end_cpu = ggml_cycles();
-            GGML_PRINT_DEBUG("time iter:      %5.3f s\n", (t_end_cpu - t_start_cpu)/CLOCKS_PER_SEC);
+            GGML_PRINT_DEBUG("time iter:      %5.3f s\n", ((float)(t_end_cpu - t_start_cpu))/CLOCKS_PER_SEC);
             UNUSED(t_end_cpu);
 
             const int64_t t_end_wall = ggml_time_us();
@@ -7829,3 +8071,53 @@ enum ggml_opt_result ggml_opt(
 }
 
 ////////////////////////////////////////////////////////////////////////////////
+
+int ggml_cpu_has_avx2(void) {
+#if defined(__AVX2__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512(void) {
+#if defined(__AVX512F__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_neon(void) {
+#if defined(__ARM_NEON__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_fp16_va(void) {
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_wasm_simd(void) {
+#if defined(__wasm_simd128__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_blas(void) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////
diff --git a/src/msvc_thread_atomic.h b/src/msvc_thread_atomic.h
deleted file mode 100644 (file)
index 52cd419..0000000
+++ /dev/null
@@ -1,31 +0,0 @@
-#pragma once
-#include <Windows.h>
-
-typedef volatile LONG atomic_int;
-typedef atomic_int atomic_bool;
-
-static void atomic_store(atomic_int* ptr, LONG val) {
-    InterlockedExchange(ptr, val);
-}
-static LONG atomic_load(atomic_int* ptr) {
-    return InterlockedCompareExchange(ptr, 0, 0);
-}
-static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) {
-    return InterlockedExchangeAdd(ptr, inc);
-}
-static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) {
-    return atomic_fetch_add(ptr, -(dec));
-}
-
-typedef HANDLE pthread_t;
-
-typedef DWORD thread_ret_t;
-static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
-    out = CreateThread(NULL, 0, func, arg, 0, NULL);
-    return out != NULL;
-}
-
-static int pthread_join(pthread_t thread, void* unused) {
-    return (int) WaitForSingleObject(thread, INFINITE);
-}
-