]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sync : latest changes from whisper.cpp
authorGeorgi Gerganov <redacted>
Sat, 31 Dec 2022 10:32:04 +0000 (12:32 +0200)
committerGeorgi Gerganov <redacted>
Sat, 31 Dec 2022 10:32:04 +0000 (12:32 +0200)
examples/whisper/main.cpp
examples/whisper/whisper.cpp
examples/whisper/whisper.h
include/ggml/ggml.h
src/ggml.c

index 465d43fb0796455bdfce509ab1a5b661500bb58d..ce8b484df30f041f447571dcea5afea3777f5ea3 100644 (file)
@@ -62,19 +62,22 @@ struct whisper_params {
 
     float word_thold = 0.01f;
 
-    bool speed_up      = false;
-    bool translate     = false;
-    bool diarize       = false;
-    bool output_txt    = false;
-    bool output_vtt    = false;
-    bool output_srt    = false;
-    bool output_wts    = false;
-    bool print_special = false;
-    bool print_colors  = false;
-    bool no_timestamps = false;
-
-    std::string language  = "en";
-    std::string model     = "models/ggml-base.en.bin";
+    bool speed_up       = false;
+    bool translate      = false;
+    bool diarize        = false;
+    bool output_txt     = false;
+    bool output_vtt     = false;
+    bool output_srt     = false;
+    bool output_wts     = false;
+    bool output_csv     = false;
+    bool print_special  = false;
+    bool print_colors   = false;
+    bool print_progress = false;
+    bool no_timestamps  = false;
+
+    std::string language = "en";
+    std::string prompt;
+    std::string model    = "models/ggml-base.en.bin";
 
     std::vector<std::string> fname_inp = {};
 };
@@ -94,27 +97,30 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             whisper_print_usage(argc, argv, params);
             exit(0);
         }
-        else if (arg == "-t"    || arg == "--threads")       { params.n_threads     = std::stoi(argv[++i]); }
-        else if (arg == "-p"    || arg == "--processors")    { params.n_processors  = std::stoi(argv[++i]); }
-        else if (arg == "-ot"   || arg == "--offset-t")      { params.offset_t_ms   = std::stoi(argv[++i]); }
-        else if (arg == "-on"   || arg == "--offset-n")      { params.offset_n      = std::stoi(argv[++i]); }
-        else if (arg == "-d"    || arg == "--duration")      { params.duration_ms   = std::stoi(argv[++i]); }
-        else if (arg == "-mc"   || arg == "--max-context")   { params.max_context   = std::stoi(argv[++i]); }
-        else if (arg == "-ml"   || arg == "--max-len")       { params.max_len       = std::stoi(argv[++i]); }
-        else if (arg == "-wt"   || arg == "--word-thold")    { params.word_thold    = std::stof(argv[++i]); }
-        else if (arg == "-su"   || arg == "--speed-up")      { params.speed_up      = true; }
-        else if (arg == "-tr"   || arg == "--translate")     { params.translate     = true; }
-        else if (arg == "-di"   || arg == "--diarize")       { params.diarize       = true; }
-        else if (arg == "-otxt" || arg == "--output-txt")    { params.output_txt    = true; }
-        else if (arg == "-ovtt" || arg == "--output-vtt")    { params.output_vtt    = true; }
-        else if (arg == "-osrt" || arg == "--output-srt")    { params.output_srt    = true; }
-        else if (arg == "-owts" || arg == "--output-words")  { params.output_wts    = true; }
-        else if (arg == "-ps"   || arg == "--print-special") { params.print_special = true; }
-        else if (arg == "-pc"   || arg == "--print-colors")  { params.print_colors  = true; }
-        else if (arg == "-nt"   || arg == "--no-timestamps") { params.no_timestamps = true; }
-        else if (arg == "-l"    || arg == "--language")      { params.language      = argv[++i]; }
-        else if (arg == "-m"    || arg == "--model")         { params.model         = argv[++i]; }
-        else if (arg == "-f"    || arg == "--file")          { params.fname_inp.push_back(argv[++i]); }
+        else if (arg == "-t"    || arg == "--threads")        { params.n_threads      = std::stoi(argv[++i]); }
+        else if (arg == "-p"    || arg == "--processors")     { params.n_processors   = std::stoi(argv[++i]); }
+        else if (arg == "-ot"   || arg == "--offset-t")       { params.offset_t_ms    = std::stoi(argv[++i]); }
+        else if (arg == "-on"   || arg == "--offset-n")       { params.offset_n       = std::stoi(argv[++i]); }
+        else if (arg == "-d"    || arg == "--duration")       { params.duration_ms    = std::stoi(argv[++i]); }
+        else if (arg == "-mc"   || arg == "--max-context")    { params.max_context    = std::stoi(argv[++i]); }
+        else if (arg == "-ml"   || arg == "--max-len")        { params.max_len        = std::stoi(argv[++i]); }
+        else if (arg == "-wt"   || arg == "--word-thold")     { params.word_thold     = std::stof(argv[++i]); }
+        else if (arg == "-su"   || arg == "--speed-up")       { params.speed_up       = true; }
+        else if (arg == "-tr"   || arg == "--translate")      { params.translate      = true; }
+        else if (arg == "-di"   || arg == "--diarize")        { params.diarize        = true; }
+        else if (arg == "-otxt" || arg == "--output-txt")     { params.output_txt     = true; }
+        else if (arg == "-ovtt" || arg == "--output-vtt")     { params.output_vtt     = true; }
+        else if (arg == "-osrt" || arg == "--output-srt")     { params.output_srt     = true; }
+        else if (arg == "-owts" || arg == "--output-words")   { params.output_wts     = true; }
+        else if (arg == "-ocsv" || arg == "--output-csv")     { params.output_csv     = true; }
+        else if (arg == "-ps"   || arg == "--print-special")  { params.print_special  = true; }
+        else if (arg == "-pc"   || arg == "--print-colors")   { params.print_colors   = true; }
+        else if (arg == "-pp"   || arg == "--print-progress") { params.print_progress = true; }
+        else if (arg == "-nt"   || arg == "--no-timestamps")  { params.no_timestamps  = true; }
+        else if (arg == "-l"    || arg == "--language")       { params.language       = argv[++i]; }
+        else if (                  arg == "--prompt")         { params.prompt         = argv[++i]; }
+        else if (arg == "-m"    || arg == "--model")          { params.model          = argv[++i]; }
+        else if (arg == "-f"    || arg == "--file")           { params.fname_inp.emplace_back(argv[++i]); }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             whisper_print_usage(argc, argv, params);
@@ -125,33 +131,36 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
     return true;
 }
 
-void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
+void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
     fprintf(stderr, "\n");
     fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
-    fprintf(stderr, "  -h,       --help          [default] show this help message and exit\n");
-    fprintf(stderr, "  -t N,     --threads N     [%-7d] number of threads to use during computation\n",    params.n_threads);
-    fprintf(stderr, "  -p N,     --processors N  [%-7d] number of processors to use during computation\n", params.n_processors);
-    fprintf(stderr, "  -ot N,    --offset-t N    [%-7d] time offset in milliseconds\n",                    params.offset_t_ms);
-    fprintf(stderr, "  -on N,    --offset-n N    [%-7d] segment index offset\n",                           params.offset_n);
-    fprintf(stderr, "  -d  N,    --duration N    [%-7d] duration of audio to process in milliseconds\n",   params.duration_ms);
-    fprintf(stderr, "  -mc N,    --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
-    fprintf(stderr, "  -ml N,    --max-len N     [%-7d] maximum segment length in characters\n",           params.max_len);
-    fprintf(stderr, "  -wt N,    --word-thold N  [%-7.2f] word timestamp probability threshold\n",         params.word_thold);
-    fprintf(stderr, "  -su,      --speed-up      [%-7s] speed up audio by x2 (reduced accuracy)\n",        params.speed_up ? "true" : "false");
-    fprintf(stderr, "  -tr,      --translate     [%-7s] translate from source language to english\n",      params.translate ? "true" : "false");
-    fprintf(stderr, "  -di,      --diarize       [%-7s] stereo audio diarization\n",                       params.diarize ? "true" : "false");
-    fprintf(stderr, "  -otxt,    --output-txt    [%-7s] output result in a text file\n",                   params.output_txt ? "true" : "false");
-    fprintf(stderr, "  -ovtt,    --output-vtt    [%-7s] output result in a vtt file\n",                    params.output_vtt ? "true" : "false");
-    fprintf(stderr, "  -osrt,    --output-srt    [%-7s] output result in a srt file\n",                    params.output_srt ? "true" : "false");
-    fprintf(stderr, "  -owts,    --output-words  [%-7s] output script for generating karaoke video\n",     params.output_wts ? "true" : "false");
-    fprintf(stderr, "  -ps,      --print-special [%-7s] print special tokens\n",                           params.print_special ? "true" : "false");
-    fprintf(stderr, "  -pc,      --print-colors  [%-7s] print colors\n",                                   params.print_colors ? "true" : "false");
-    fprintf(stderr, "  -nt,      --no-timestamps [%-7s] do not print timestamps\n",                        params.no_timestamps ? "false" : "true");
-    fprintf(stderr, "  -l LANG,  --language LANG [%-7s] spoken language\n",                                params.language.c_str());
-    fprintf(stderr, "  -m FNAME, --model FNAME   [%-7s] model path\n",                                     params.model.c_str());
-    fprintf(stderr, "  -f FNAME, --file FNAME    [%-7s] input WAV file path\n",                            "");
+    fprintf(stderr, "  -h,       --help           [default] show this help message and exit\n");
+    fprintf(stderr, "  -t N,     --threads N      [%-7d] number of threads to use during computation\n",    params.n_threads);
+    fprintf(stderr, "  -p N,     --processors N   [%-7d] number of processors to use during computation\n", params.n_processors);
+    fprintf(stderr, "  -ot N,    --offset-t N     [%-7d] time offset in milliseconds\n",                    params.offset_t_ms);
+    fprintf(stderr, "  -on N,    --offset-n N     [%-7d] segment index offset\n",                           params.offset_n);
+    fprintf(stderr, "  -d  N,    --duration N     [%-7d] duration of audio to process in milliseconds\n",   params.duration_ms);
+    fprintf(stderr, "  -mc N,    --max-context N  [%-7d] maximum number of text context tokens to store\n", params.max_context);
+    fprintf(stderr, "  -ml N,    --max-len N      [%-7d] maximum segment length in characters\n",           params.max_len);
+    fprintf(stderr, "  -wt N,    --word-thold N   [%-7.2f] word timestamp probability threshold\n",         params.word_thold);
+    fprintf(stderr, "  -su,      --speed-up       [%-7s] speed up audio by x2 (reduced accuracy)\n",        params.speed_up ? "true" : "false");
+    fprintf(stderr, "  -tr,      --translate      [%-7s] translate from source language to english\n",      params.translate ? "true" : "false");
+    fprintf(stderr, "  -di,      --diarize        [%-7s] stereo audio diarization\n",                       params.diarize ? "true" : "false");
+    fprintf(stderr, "  -otxt,    --output-txt     [%-7s] output result in a text file\n",                   params.output_txt ? "true" : "false");
+    fprintf(stderr, "  -ovtt,    --output-vtt     [%-7s] output result in a vtt file\n",                    params.output_vtt ? "true" : "false");
+    fprintf(stderr, "  -osrt,    --output-srt     [%-7s] output result in a srt file\n",                    params.output_srt ? "true" : "false");
+    fprintf(stderr, "  -owts,    --output-words   [%-7s] output script for generating karaoke video\n",     params.output_wts ? "true" : "false");
+    fprintf(stderr, "  -ocsv,    --output-csv     [%-7s] output result in a CSV file\n",                    params.output_csv ? "true" : "false");
+    fprintf(stderr, "  -ps,      --print-special  [%-7s] print special tokens\n",                           params.print_special ? "true" : "false");
+    fprintf(stderr, "  -pc,      --print-colors   [%-7s] print colors\n",                                   params.print_colors ? "true" : "false");
+    fprintf(stderr, "  -pp,      --print-progress [%-7s] print progress\n",                                 params.print_progress ? "true" : "false");
+    fprintf(stderr, "  -nt,      --no-timestamps  [%-7s] do not print timestamps\n",                        params.no_timestamps ? "false" : "true");
+    fprintf(stderr, "  -l LANG,  --language LANG  [%-7s] spoken language ('auto' for auto-detect)\n",       params.language.c_str());
+    fprintf(stderr, "            --prompt PROMPT  [%-7s] initial prompt\n",                                 params.prompt.c_str());
+    fprintf(stderr, "  -m FNAME, --model FNAME    [%-7s] model path\n",                                     params.model.c_str());
+    fprintf(stderr, "  -f FNAME, --file FNAME     [%-7s] input WAV file path\n",                            "");
     fprintf(stderr, "\n");
 }
 
@@ -200,7 +209,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
             const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
             const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
-            std::string speaker = "";
+            std::string speaker;
 
             if (params.diarize && pcmf32s.size() == 2) {
                 const int64_t n_samples = pcmf32s[0].size();
@@ -266,7 +275,7 @@ bool output_txt(struct whisper_context * ctx, const char * fname) {
     const int n_segments = whisper_full_n_segments(ctx);
     for (int i = 0; i < n_segments; ++i) {
         const char * text = whisper_full_get_segment_text(ctx, i);
-        fout << text;
+        fout << text << "\n";
     }
 
     return true;
@@ -319,10 +328,36 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
     return true;
 }
 
+bool output_csv(struct whisper_context * ctx, const char * fname) {
+    std::ofstream fout(fname);
+    if (!fout.is_open()) {
+        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
+        return false;
+    }
+
+    fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
+
+    const int n_segments = whisper_full_n_segments(ctx);
+    for (int i = 0; i < n_segments; ++i) {
+        const char * text = whisper_full_get_segment_text(ctx, i);
+       if (text[0] == ' ')
+         text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
+        const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+        const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+       //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
+        fout << 10 * t0 << ", " 
+            << 10 * t1 << ", \"" 
+            << text    << "\"\n";
+    }
+
+    return true;
+}
+
+
 // karaoke video generation
 // outputs a bash script that uses ffmpeg to generate a video with the subtitles
 // TODO: font parameter adjustments
-bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
+bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) {
     std::ofstream fout(fname);
 
     fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
@@ -371,7 +406,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
             txt_ul = "\\ \\ ";
 
             {
-                int ncnt = 0;
                 for (int k = 0; k < n; ++k) {
                     const auto & token2 = tokens[k];
 
@@ -395,8 +429,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
                             txt_ul += "\\ ";
                         }
                     }
-
-                    ncnt += txt.size();
                 }
 
                 ::replace_all(txt_bg, "'", "\u2019");
@@ -447,7 +479,7 @@ int main(int argc, char ** argv) {
         return 2;
     }
 
-    if (whisper_lang_id(params.language.c_str()) == -1) {
+    if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
         fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
         whisper_print_usage(argc, argv, params);
         exit(0);
@@ -462,6 +494,22 @@ int main(int argc, char ** argv) {
         return 3;
     }
 
+    // initial prompt
+    std::vector<whisper_token> prompt_tokens;
+
+    if (!params.prompt.empty()) {
+        prompt_tokens.resize(1024);
+        prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
+
+        fprintf(stderr, "\n");
+        fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
+        fprintf(stderr, "initial tokens: [ ");
+        for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
+            fprintf(stderr, "%d ", prompt_tokens[i]);
+        }
+        fprintf(stderr, "]\n");
+    }
+
     for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
         const auto fname_inp = params.fname_inp[f];
 
@@ -486,14 +534,14 @@ int main(int argc, char ** argv) {
                     }
                 }
 
-                if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false) {
+                if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
                     fprintf(stderr, "error: failed to open WAV file from stdin\n");
                     return 4;
                 }
 
                 fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
             }
-            else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
+            else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
                 fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
                 return 5;
             }
@@ -528,11 +576,11 @@ int main(int argc, char ** argv) {
             // convert to mono, float
             pcmf32.resize(n);
             if (wav.channels == 1) {
-                for (int i = 0; i < n; i++) {
+                for (uint64_t i = 0; i < n; i++) {
                     pcmf32[i] = float(pcm16[i])/32768.0f;
                 }
             } else {
-                for (int i = 0; i < n; i++) {
+                for (uint64_t i = 0; i < n; i++) {
                     pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
                 }
             }
@@ -543,7 +591,7 @@ int main(int argc, char ** argv) {
 
                 pcmf32s[0].resize(n);
                 pcmf32s[1].resize(n);
-                for (int i = 0; i < n; i++) {
+                for (uint64_t i = 0; i < n; i++) {
                     pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
                     pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
                 }
@@ -577,13 +625,12 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "\n");
         }
 
-
         // run the inference
         {
             whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 
             wparams.print_realtime   = false;
-            wparams.print_progress   = false;
+            wparams.print_progress   = params.print_progress;
             wparams.print_timestamps = !params.no_timestamps;
             wparams.print_special    = params.print_special;
             wparams.translate        = params.translate;
@@ -599,6 +646,9 @@ int main(int argc, char ** argv) {
 
             wparams.speed_up         = params.speed_up;
 
+            wparams.prompt_tokens    = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
+            wparams.prompt_n_tokens  = prompt_tokens.empty() ? 0       : prompt_tokens.size();
+
             whisper_print_user_data user_data = { &params, &pcmf32s };
 
             // this callback is called on each new segment
@@ -613,7 +663,7 @@ int main(int argc, char ** argv) {
             {
                 static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
 
-                wparams.encoder_begin_callback = [](struct whisper_context * ctx, void * user_data) {
+                wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
                     bool is_aborted = *(bool*)user_data;
                     return !is_aborted;
                 };
@@ -653,6 +703,13 @@ int main(int argc, char ** argv) {
                 const auto fname_wts = fname_inp + ".wts";
                 output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
             }
+
+           // output to CSV file
+            if (params.output_csv) {
+                const auto fname_csv = fname_inp + ".csv";
+                output_csv(ctx, fname_csv.c_str());
+            }
+
         }
     }
 
index 42467efe693fa62e2c420e438ceba8a6eed56966..84c24900791a6427a581b86bd1741991ae230394 100644 (file)
@@ -14,6 +14,7 @@
 #include <string>
 #include <thread>
 #include <vector>
+#include <regex>
 
 #define USE_FLASH_ATTN
 //#define USE_FLASH_FF
@@ -203,6 +204,10 @@ struct whisper_vocab {
     std::map<token, id> token_to_id;
     std::map<id, token> id_to_token;
 
+    // used to avoid memory allocations during sampling
+    // TODO: move to whisper_context in the future
+    std::vector<std::pair<double, whisper_vocab::id>> probs_id;
+
     id token_eot  = 50256;
     id token_sot  = 50257;
     id token_prev = 50360;
@@ -429,6 +434,12 @@ struct whisper_context {
     int32_t exp_n_audio_ctx; // 0 - use default
 };
 
+template<typename T>
+static void read_safe(std::ifstream& fin, T& dest)
+{
+  fin.read((char*)& dest, sizeof(T));
+}
+
 // load the model from a ggml file
 //
 // file format:
@@ -455,7 +466,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
     // verify magic
     {
         uint32_t magic;
-        fin.read((char *) &magic, sizeof(magic));
+        read_safe(fin, magic);
         if (magic != 0x67676d6c) {
             fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
             return false;
@@ -466,17 +477,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
     {
         auto & hparams = model.hparams;
 
-        fin.read((char *) &hparams.n_vocab,       sizeof(hparams.n_vocab));
-        fin.read((char *) &hparams.n_audio_ctx,   sizeof(hparams.n_audio_ctx));
-        fin.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
-        fin.read((char *) &hparams.n_audio_head,  sizeof(hparams.n_audio_head));
-        fin.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
-        fin.read((char *) &hparams.n_text_ctx,    sizeof(hparams.n_text_ctx));
-        fin.read((char *) &hparams.n_text_state,  sizeof(hparams.n_text_state));
-        fin.read((char *) &hparams.n_text_head,   sizeof(hparams.n_text_head));
-        fin.read((char *) &hparams.n_text_layer,  sizeof(hparams.n_text_layer));
-        fin.read((char *) &hparams.n_mels,        sizeof(hparams.n_mels));
-        fin.read((char *) &hparams.f16,           sizeof(hparams.f16));
+        read_safe(fin, hparams.n_vocab);
+        read_safe(fin, hparams.n_audio_ctx);
+        read_safe(fin, hparams.n_audio_state);
+        read_safe(fin, hparams.n_audio_head);
+        read_safe(fin, hparams.n_audio_layer);
+        read_safe(fin, hparams.n_text_ctx);
+        read_safe(fin, hparams.n_text_state);
+        read_safe(fin, hparams.n_text_head);
+        read_safe(fin, hparams.n_text_layer);
+        read_safe(fin, hparams.n_mels);
+        read_safe(fin, hparams.f16);
 
         assert(hparams.n_text_state == hparams.n_audio_state);
 
@@ -524,8 +535,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
     {
         auto & filters = wctx.model.filters;
 
-        fin.read((char *) &filters.n_mel, sizeof(filters.n_mel));
-        fin.read((char *) &filters.n_fft, sizeof(filters.n_fft));
+        read_safe(fin, filters.n_mel);
+        read_safe(fin, filters.n_fft);
 
         filters.data.resize(filters.n_mel * filters.n_fft);
         fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
@@ -534,7 +545,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
     // load vocab
     {
         int32_t n_vocab = 0;
-        fin.read((char *) &n_vocab, sizeof(n_vocab));
+        read_safe(fin, n_vocab);
 
         //if (n_vocab != model.hparams.n_vocab) {
         //    fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
@@ -543,12 +554,23 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
         //}
 
         std::string word;
+        std::vector<char> tmp;
+
+        tmp.reserve(128);
+
         for (int i = 0; i < n_vocab; i++) {
             uint32_t len;
-            fin.read((char *) &len, sizeof(len));
+            read_safe(fin, len);
 
-            word.resize(len);
-            fin.read((char *) word.data(), len);
+            if (len > 0) {
+                tmp.resize(len);
+                fin.read(&tmp[0], tmp.size()); // read to buffer
+                word.assign(&tmp[0], tmp.size());
+            } else {
+                // seems like we have an empty-string token in multi-language models (i = 50256)
+                //fprintf(stderr, "%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
+                word = "";
+            }
 
             vocab.token_to_id[word] = i;
             vocab.id_to_token[i] = word;
@@ -588,6 +610,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
                 vocab.id_to_token[i] = word;
             }
         }
+
+        wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
+        wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
+
+        vocab.probs_id.reserve(n_vocab);
     }
 
     {
@@ -606,7 +633,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
     const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
     size_t ctx_size = 0;
-    size_t ctx_mem_size = 0;
 
     {
         const auto & hparams = model.hparams;
@@ -715,12 +741,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             ctx_size += n_text_layer*(             n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
         }
 
-        ctx_mem_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
-        ctx_mem_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v
-
-        ctx_mem_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
-        ctx_mem_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v
-
         ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
 
         fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
@@ -728,10 +748,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
     // create the ggml context
     {
-        struct ggml_init_params params = {
-            .mem_size   = wctx.buf_model->size(),
-            .mem_buffer = wctx.buf_model->data(),
-        };
+        struct ggml_init_params params;
+        params.mem_size   = wctx.buf_model->size();
+        params.mem_buffer = wctx.buf_model->data();
 
         model.ctx = ggml_init(params);
         if (!model.ctx) {
@@ -938,10 +957,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
     // create the ggml memory context
     {
-        struct ggml_init_params params = {
-            .mem_size   = wctx.buf_memory.size(),
-            .mem_buffer = wctx.buf_memory.data(),
-        };
+        struct ggml_init_params params;
+        params.mem_size   = wctx.buf_memory.size();
+        params.mem_buffer = wctx.buf_memory.data();
 
         model.ctx_mem = ggml_init(params);
         if (!model.ctx_mem) {
@@ -998,9 +1016,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             int32_t length;
             int32_t ftype;
 
-            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
-            fin.read(reinterpret_cast<char *>(&length), sizeof(length));
-            fin.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+            read_safe(fin, n_dims);
+            read_safe(fin, length);
+            read_safe(fin, ftype);
 
             if (fin.eof()) {
                 break;
@@ -1009,14 +1027,16 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             int32_t nelements = 1;
             int32_t ne[3] = { 1, 1, 1 };
             for (int i = 0; i < n_dims; ++i) {
-                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                read_safe(fin, ne[i]);
                 nelements *= ne[i];
             }
 
-            std::string name(length, 0);
-            fin.read(&name[0], length);
+            std::string name;
+            std::vector<char> tmp(length); // create a buffer
+            fin.read(&tmp[0], tmp.size()); // read to buffer
+            name.assign(&tmp[0], tmp.size());
 
-            if (model.tensors.find(name.data()) == model.tensors.end()) {
+            if (model.tensors.find(name) == model.tensors.end()) {
                 fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
                 return false;
             }
@@ -1088,10 +1108,9 @@ static bool whisper_encode(
     const int n_mels = hparams.n_mels;
     assert(mel_inp.n_mel == n_mels);
 
-    struct ggml_init_params params = {
-        .mem_size   = wctx.buf_compute.size(),
-        .mem_buffer = wctx.buf_compute.data(),
-    };
+    struct ggml_init_params params;
+    params.mem_size   = wctx.buf_compute.size();
+    params.mem_buffer = wctx.buf_compute.data();
 
     struct ggml_context * ctx0 = ggml_init(params);
 
@@ -1166,10 +1185,9 @@ static bool whisper_encode(
 
         // create separate context for each layer to reduce memory usage
 
-        struct ggml_init_params paramsL = {
-            .mem_size   = wctx.buf_compute_layer.size(),
-            .mem_buffer = wctx.buf_compute_layer.data(),
-        };
+        struct ggml_init_params paramsL;
+        paramsL.mem_size   = wctx.buf_compute_layer.size();
+        paramsL.mem_buffer = wctx.buf_compute_layer.data();
 
         struct ggml_context * ctxL = ggml_init(paramsL);
 
@@ -1374,8 +1392,8 @@ static bool whisper_encode(
         // input for next layer (inpO -> inpL)
         memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
         inpL->op = GGML_OP_NONE;
-        inpL->src0 = NULL;
-        inpL->src1 = NULL;
+        inpL->src0 = nullptr;
+        inpL->src1 = nullptr;
 
         //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
 
@@ -1428,8 +1446,8 @@ static bool whisper_encode(
 
         // TODO: hack to disconnect the encoded features from the previous graph
         cur->op = GGML_OP_NONE;
-        cur->src0 = NULL;
-        cur->src1 = NULL;
+        cur->src0 = nullptr;
+        cur->src1 = nullptr;
 
         for (int il = 0; il < model.hparams.n_text_layer; ++il) {
             auto & layer = model.layers_decoder[il];
@@ -1503,10 +1521,9 @@ static bool whisper_decode(
     const int N = n_tokens;
     const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
 
-    struct ggml_init_params params = {
-            .mem_size   = wctx.buf_compute.size(),
-            .mem_buffer = wctx.buf_compute.data(),
-        };
+    struct ggml_init_params params;
+    params.mem_size   = wctx.buf_compute.size();
+    params.mem_buffer = wctx.buf_compute.data();
 
     struct ggml_context * ctx0 = ggml_init(params);
 
@@ -1529,10 +1546,9 @@ static bool whisper_decode(
     for (int il = 0; il < n_layer; ++il) {
         const auto & layer = model.layers_decoder[il];
 
-        struct ggml_init_params paramsL = {
-            .mem_size   = wctx.buf_compute_layer.size(),
-            .mem_buffer = wctx.buf_compute_layer.data(),
-        };
+        struct ggml_init_params paramsL;
+        paramsL.mem_size   = wctx.buf_compute_layer.size();
+        paramsL.mem_buffer = wctx.buf_compute_layer.data();
 
         struct ggml_context * ctxL = ggml_init(paramsL);
         struct ggml_cgraph gf = {};
@@ -1788,8 +1804,8 @@ static bool whisper_decode(
         // input for next layer (inpO -> inpL)
         memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
         inpL->op = GGML_OP_NONE;
-        inpL->src0 = NULL;
-        inpL->src1 = NULL;
+        inpL->src0 = nullptr;
+        inpL->src1 = nullptr;
 
         if (N > 1) {
             //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
@@ -1845,7 +1861,7 @@ static bool whisper_decode(
 
 // the most basic sampling scheme - select the top token
 static whisper_token_data whisper_sample_best(
-        const whisper_vocab & vocab,
+              whisper_vocab & vocab,
         const float * probs,
               bool force_timestamp,
               bool is_initial) {
@@ -1853,13 +1869,13 @@ static whisper_token_data whisper_sample_best(
         0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
     };
 
-    int n_logits = vocab.id_to_token.size();
+    const int n_logits = vocab.n_vocab;
 
-    std::vector<std::pair<double, whisper_vocab::id>> probs_id;
-    probs_id.reserve(n_logits);
+    auto & probs_id = vocab.probs_id;
 
+    probs_id.clear();
     for (int i = 0; i < n_logits; i++) {
-        probs_id.push_back(std::make_pair(probs[i], i));
+        probs_id.emplace_back(probs[i], i);
     }
 
     {
@@ -1997,6 +2013,9 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
     std::vector<float> even;
     std::vector<float> odd;
 
+    even.reserve(N/2);
+    odd.reserve(N/2);
+
     for (int i = 0; i < N; i++) {
         if (i % 2 == 0) {
             even.push_back(in[i]);
@@ -2032,7 +2051,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
 static bool log_mel_spectrogram(
     const float * samples,
     const int n_samples,
-    const int sample_rate,
+    const int /*sample_rate*/,
     const int fft_size,
     const int fft_step,
     const int n_mel,
@@ -2151,6 +2170,71 @@ static bool log_mel_spectrogram(
     return true;
 }
 
+// split text into tokens
+//
+// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
+//
+// Regex (Python):
+// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+//
+// Regex (C++):
+// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
+//
+static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
+    std::vector<std::string> words;
+
+    // first split the text into words
+    {
+        std::string str = text;
+        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
+
+        std::regex re(pat);
+        std::smatch m;
+
+        while (std::regex_search(str, m, re)) {
+            for (auto x : m) {
+                words.push_back(x);
+            }
+            str = m.suffix();
+        }
+    }
+
+    // find the longest tokens that form the words:
+    std::vector<whisper_vocab::id> tokens;
+    for (const auto & word : words) {
+        if (word.empty()) continue;
+
+        int i = 0;
+        int n = word.size();
+        while (i < n) {
+            int j = n;
+            while (j > i) {
+                auto it = vocab.token_to_id.find(word.substr(i, j-i));
+                if (it != vocab.token_to_id.end()) {
+                    tokens.push_back(it->second);
+                    i = j;
+                    break;
+                }
+                --j;
+            }
+            if (i == n) {
+                break;
+            }
+            if (j == i) {
+                auto sub = word.substr(i, 1);
+                if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
+                    tokens.push_back(vocab.token_to_id.at(sub));
+                } else {
+                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
+                }
+                ++i;
+            }
+        }
+    }
+
+    return tokens;
+}
+
 //
 // interface implementation
 //
@@ -2166,7 +2250,8 @@ struct whisper_context * whisper_init(const char * path_model) {
 
     if (!whisper_model_load(path_model, *ctx)) {
         fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
-        return NULL;
+        delete ctx;
+        return nullptr;
     }
 
     ctx->t_load_us = ggml_time_us() - t_start_us;
@@ -2281,8 +2366,38 @@ struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx,
     return res;
 }
 
+int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
+    const auto res = tokenize(ctx->vocab, text);
+
+    if (n_max_tokens < (int) res.size()) {
+        fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
+        return -1;
+    }
+
+    for (int i = 0; i < (int) res.size(); i++) {
+        tokens[i] = res[i];
+    }
+
+    return res.size();
+}
+
+int whisper_lang_max_id() {
+    auto max_id = 0;
+    for (const auto & kv : g_lang) {
+        max_id = std::max(max_id, kv.second.first);
+    }
+
+    return max_id;
+}
+
 int whisper_lang_id(const char * lang) {
     if (!g_lang.count(lang)) {
+        for (const auto & kv : g_lang) {
+            if (kv.second.second == lang) {
+                return kv.second.first;
+            }
+        }
+
         fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
         return -1;
     }
@@ -2290,6 +2405,86 @@ int whisper_lang_id(const char * lang) {
     return g_lang.at(lang).first;
 }
 
+const char * whisper_lang_str(int id) {
+    for (const auto & kv : g_lang) {
+        if (kv.second.first == id) {
+            return kv.first.c_str();
+        }
+    }
+
+    fprintf(stderr, "%s: unknown language id %d\n", __func__, id);
+    return nullptr;
+}
+
+int whisper_lang_auto_detect(
+        struct whisper_context * ctx,
+        int offset_ms,
+        int n_threads,
+        float * lang_probs) {
+    const int seek = offset_ms/10;
+
+    if (seek < 0) {
+        fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
+        return -1;
+    }
+
+    if (seek >= ctx->mel.n_len) {
+        fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
+        return -2;
+    }
+
+    // run the encoder
+    if (whisper_encode(ctx, seek, n_threads) != 0) {
+        fprintf(stderr, "%s: failed to encode\n", __func__);
+        return -6;
+    }
+
+    const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
+
+    if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
+        fprintf(stderr, "%s: failed to decode\n", __func__);
+        return -7;
+    }
+
+    std::vector<std::pair<float, int>> probs_id;
+    for (const auto & kv : g_lang) {
+        const auto token_lang = whisper_token_lang(ctx, kv.second.first);
+        probs_id.emplace_back(ctx->probs[token_lang], kv.second.first);
+    }
+
+    // sort descending
+    {
+        using pair_type = decltype(probs_id)::value_type;
+        std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
+            return a.first > b.first;
+        });
+    }
+
+    // softmax
+    {
+        float sum = 0;
+        for (const auto & kv : probs_id) {
+            sum += exp(kv.first);
+        }
+
+        for (auto & kv : probs_id) {
+            kv.first = exp(kv.first) / sum;
+        }
+    }
+
+    {
+        for (int i = 0; i < (int) probs_id.size(); i++) {
+            if (lang_probs) {
+                lang_probs[probs_id[i].second] = probs_id[i].first;
+            }
+
+            //printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first);
+        }
+    }
+
+    return probs_id[0].second;
+}
+
 int whisper_n_len(struct whisper_context * ctx) {
     return ctx->mel.n_len;
 }
@@ -2302,6 +2497,10 @@ int whisper_n_text_ctx(struct whisper_context * ctx) {
     return ctx->model.hparams.n_text_ctx;
 }
 
+int whisper_n_audio_ctx(struct whisper_context * ctx) {
+    return ctx->model.hparams.n_audio_ctx;
+}
+
 int whisper_is_multilingual(struct whisper_context * ctx) {
     return ctx->vocab.is_multilingual() ? 1 : 0;
 }
@@ -2338,6 +2537,10 @@ whisper_token whisper_token_beg(struct whisper_context * ctx) {
     return ctx->vocab.token_beg;
 }
 
+whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
+    return whisper_token_sot(ctx) + 1 + lang_id;
+}
+
 whisper_token whisper_token_translate(void) {
     return whisper_vocab::token_translate;
 }
@@ -2371,7 +2574,10 @@ const char * whisper_print_system_info(void) {
     s += "AVX = "       + std::to_string(ggml_cpu_has_avx())       + " | ";
     s += "AVX2 = "      + std::to_string(ggml_cpu_has_avx2())      + " | ";
     s += "AVX512 = "    + std::to_string(ggml_cpu_has_avx512())    + " | ";
+    s += "FMA = "       + std::to_string(ggml_cpu_has_fma())       + " | ";
     s += "NEON = "      + std::to_string(ggml_cpu_has_neon())      + " | ";
+    s += "ARM_FMA = "   + std::to_string(ggml_cpu_has_arm_fma())   + " | ";
+    s += "F16C = "      + std::to_string(ggml_cpu_has_f16c())      + " | ";
     s += "FP16_VA = "   + std::to_string(ggml_cpu_has_fp16_va())   + " | ";
     s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
     s += "BLAS = "      + std::to_string(ggml_cpu_has_blas())      + " | ";
@@ -2569,10 +2775,25 @@ int whisper_full(
     } else {
         if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
             fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
-            return -1;
+            return -2;
         }
     }
 
+    // auto-detect language if not specified
+    if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
+        std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
+
+        const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
+        if (lang_id < 0) {
+            fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
+            return -3;
+        }
+
+        params.language = whisper_lang_str(lang_id);
+
+        fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
+    }
+
     if (params.token_timestamps) {
         ctx->t_beg = 0;
         ctx->t_last = 0;
@@ -2605,13 +2826,18 @@ int whisper_full(
         std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
     }
 
-    // overwrite audio_ctx
+    // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
+    if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
+        fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
+        return -4;
+    }
     ctx->exp_n_audio_ctx = params.audio_ctx;
 
     // these tokens determine the task that will be performed
     std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
     if (whisper_is_multilingual(ctx)) {
-        prompt_init.push_back(whisper_token_sot(ctx) + 1 + whisper_lang_id(params.language));
+        const int lang_id = whisper_lang_id(params.language);
+        prompt_init.push_back(whisper_token_lang(ctx, lang_id));
         if (params.translate) {
             prompt_init.push_back(whisper_token_translate());
         } else {
@@ -2639,10 +2865,17 @@ int whisper_full(
             }
         }
 
+        // of only 1 second left, then stop
         if (seek + 100 >= seek_end) {
             break;
         }
 
+        // if there is a very short audio segment left to process, we remove any past prompt since it tends
+        // to confuse the decoder and often make it repeat or hallucinate stuff
+        if (seek > seek_start && seek + 500 >= seek_end) {
+            prompt_past.clear();
+        }
+
         if (params.encoder_begin_callback) {
             if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
                 fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
@@ -2653,14 +2886,14 @@ int whisper_full(
         // encode audio features starting at offset seek
         if (whisper_encode(ctx, seek, params.n_threads) != 0) {
             fprintf(stderr, "%s: failed to encode\n", __func__);
-            return 7;
+            return -4;
         }
 
         int n_past = 0;
         prompt.clear();
 
         // if we have already generated some text, use it as a prompt to condition the next generation
-        if (prompt_past.size() > 0) {
+        if (!prompt_past.empty()) {
             int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
 
             prompt = { whisper_token_prev(ctx) };
@@ -2686,11 +2919,12 @@ int whisper_full(
         tokens_cur.clear();
 
         bool failed = false;
+        bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
 
         for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
             if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
                 fprintf(stderr, "%s: failed to decode\n", __func__);
-                return 8;
+                return -5;
             }
 
             n_past += prompt.size();
@@ -2711,13 +2945,13 @@ int whisper_full(
                     const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
 
                     // do not allow to go back in time
-                    if (seek_delta != 100*WHISPER_CHUNK_SIZE &&
-                        seek_delta > seek_delta_new && result_len < i) {
+                    if (has_ts && seek_delta > seek_delta_new && result_len < i) {
                         break;
                     }
 
                     seek_delta = seek_delta_new;
                     result_len = i + 1;
+                    has_ts = true;
                 }
 
                 // add it to the context
@@ -2726,11 +2960,14 @@ int whisper_full(
 
                 //{
                 //    const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
-                //    printf("%s: %10s %6d %6.3f '%s'\n", __func__, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
+                //    printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
                 //}
 
-                // end of text token
-                if (token.id == whisper_token_eot(ctx) || (params.max_tokens > 0 && i > params.max_tokens)) {
+                // end of segment
+                if (token.id == whisper_token_eot(ctx) ||                // end of text token
+                    (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
+                    (has_ts && seek + seek_delta + 100 >= seek_end)      // end of audio reached
+                    ) {
                     if (result_len == 0) {
                         if (seek + seek_delta + 100 >= seek_end) {
                             result_len = i + 1;
@@ -2765,8 +3002,14 @@ int whisper_full(
         }
 
         if (failed) {
-            fprintf(stderr, "\n%s: failed to generate timestamp token - using fallback strategy\n\n", __func__);
-            seek += 100;
+            // when we fail to sample timestamp token, retry by clearing the past prompt
+            // if it fails again, then we advance the window by 1 second
+            if (!prompt_past.empty()) {
+                prompt_past.clear();
+            } else {
+                fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__);
+                seek += 100;
+            }
             continue;
         }
 
@@ -2778,11 +3021,11 @@ int whisper_full(
         }
 
         // store the text from this iteration
-        if (tokens_cur.size() > 0) {
+        if (!tokens_cur.empty()) {
             int  i0 = 0;
             auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
 
-            std::string text = "";
+            std::string text;
 
             for (int i = 0; i < (int) tokens_cur.size(); i++) {
                 //printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
@@ -2901,10 +3144,9 @@ int whisper_full_parallel(
 
         // create the ggml memory context
         {
-            struct ggml_init_params params = {
-                .mem_size   = ctxs[i].buf_memory.size(),
-                .mem_buffer = ctxs[i].buf_memory.data(),
-            };
+            struct ggml_init_params params;
+            params.mem_size   = ctxs[i].buf_memory.size();
+            params.mem_buffer = ctxs[i].buf_memory.data();
 
             model.ctx_mem = ggml_init(params);
             if (!model.ctx_mem) {
@@ -2990,7 +3232,7 @@ int whisper_full_parallel(
             results_i[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
 
             // make sure that segments are not overlapping
-            if (ctx->result_all.size() > 0) {
+            if (!ctx->result_all.empty()) {
                 results_i[j].t0 = std::max(results_i[j].t0, ctx->result_all.back().t1);
             }
 
index def77d4c3c2a76751c48996d4fc5f0f49c44bc52..e36b761ff6f9df3cfe30ab3ea30baf0362bb3ee8 100644 (file)
@@ -139,12 +139,45 @@ extern "C" {
     WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
     WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
 
+    // Convert the provided text into tokens.
+    // The tokens pointer must be large enough to hold the resulting tokens.
+    // Returns the number of tokens on success, no more than n_max_tokens
+    // Returns -1 on failure
+    // TODO: not sure if correct
+    WHISPER_API int whisper_tokenize(
+            struct whisper_context * ctx,
+                        const char * text,
+                     whisper_token * tokens,
+                                  int   n_max_tokens);
+
+    // Largest language id (i.e. number of available languages - 1)
+    WHISPER_API int whisper_lang_max_id();
+
     // Return the id of the specified language, returns -1 if not found
+    // Examples:
+    //   "de" -> 2
+    //   "german" -> 2
     WHISPER_API int whisper_lang_id(const char * lang);
 
+    // Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
+    WHISPER_API const char * whisper_lang_str(int id);
+
+    // Use mel data at offset_ms to try and auto-detect the spoken language
+    // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
+    // Returns the top language id or negative on failure
+    // If not null, fills the lang_probs array with the probabilities of all languages
+    // The array must be whispe_lang_max_id() + 1 in size
+    // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
+    WHISPER_API int whisper_lang_auto_detect(
+            struct whisper_context * ctx,
+                               int   offset_ms,
+                               int   n_threads,
+                             float * lang_probs);
+
     WHISPER_API int whisper_n_len          (struct whisper_context * ctx); // mel length
     WHISPER_API int whisper_n_vocab        (struct whisper_context * ctx);
     WHISPER_API int whisper_n_text_ctx     (struct whisper_context * ctx);
+    WHISPER_API int whisper_n_audio_ctx    (struct whisper_context * ctx);
     WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
 
     // The probabilities for the next token
@@ -160,6 +193,7 @@ extern "C" {
     WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
 
     // Task tokens
     WHISPER_API whisper_token whisper_token_translate (void);
@@ -225,6 +259,7 @@ extern "C" {
         const whisper_token * prompt_tokens;
         int prompt_n_tokens;
 
+        // for auto-detection, set to nullptr, "" or "auto"
         const char * language;
 
         struct {
index 3e4e962a69ec166538a817876f2cbe0c391a89f3..a217d2d5f94335ea64612ff2e6f254317cd1d8d6 100644 (file)
@@ -681,34 +681,32 @@ struct ggml_opt_params {
     bool print_forward_graph;
     bool print_backward_graph;
 
-    union {
-        // ADAM parameters
-        struct {
-            int n_iter;
-
-            float alpha; // learning rate
-            float beta1;
-            float beta2;
-            float eps;   // epsilon for numerical stability
-            float eps_f; // epsilon for convergence test
-            float eps_g; // epsilon for convergence test
-        } adam;
-
-        // LBFGS parameters
-        struct {
-            int m; // number of corrections to approximate the inv. Hessian
-            int n_iter;
-            int max_linesearch;
-
-            float eps;      // convergence tolerance
-            float ftol;     // line search tolerance
-            float wolfe;
-            float min_step;
-            float max_step;
-
-            enum ggml_linesearch linesearch;
-        } lbfgs;
-    };
+    // ADAM parameters
+    struct {
+        int n_iter;
+
+        float alpha; // learning rate
+        float beta1;
+        float beta2;
+        float eps;   // epsilon for numerical stability
+        float eps_f; // epsilon for convergence test
+        float eps_g; // epsilon for convergence test
+    } adam;
+
+    // LBFGS parameters
+    struct {
+        int m; // number of corrections to approximate the inv. Hessian
+        int n_iter;
+        int max_linesearch;
+
+        float eps;      // convergence tolerance
+        float ftol;     // line search tolerance
+        float wolfe;
+        float min_step;
+        float max_step;
+
+        enum ggml_linesearch linesearch;
+    } lbfgs;
 };
 
 struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
@@ -726,7 +724,10 @@ enum ggml_opt_result ggml_opt(
 int ggml_cpu_has_avx(void);
 int ggml_cpu_has_avx2(void);
 int ggml_cpu_has_avx512(void);
+int ggml_cpu_has_fma(void);
 int ggml_cpu_has_neon(void);
+int ggml_cpu_has_arm_fma(void);
+int ggml_cpu_has_f16c(void);
 int ggml_cpu_has_fp16_va(void);
 int ggml_cpu_has_wasm_simd(void);
 int ggml_cpu_has_blas(void);
index b6d528d9fd6dd414351808cb648c5a774b95c70e..7d2f465258d4baf1f1193a6e771e98f47c7f5c27 100644 (file)
 #include <stdint.h>
 #include <stdio.h>
 
+// if C99 - static_assert is noop
+// ref: https://stackoverflow.com/a/53923785/4039976
+#ifndef static_assert
+#define static_assert(cond, msg) struct global_scope_noop_trick
+#endif
+
 #if defined _MSC_VER || defined(__MINGW32__)
 
 #if !defined(__MINGW32__)
@@ -69,6 +75,10 @@ static int sched_yield (void) {
 typedef void* thread_ret_t;
 #endif
 
+#ifdef __HAIKU__
+#define static_assert(cond, msg) _Static_assert(cond, msg)
+#endif
+
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
 
@@ -120,13 +130,35 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
     return x;
 }
 
+#define GGML_FP16_TO_FP32(x) (x)
+#define GGML_FP32_TO_FP16(x) (x)
+
 #else
 
 #ifdef __wasm_simd128__
 #include <wasm_simd128.h>
 #else
+#ifdef __POWER9_VECTOR__
+#include <altivec.h>
+#undef bool
+#define bool _Bool
+#else
 #include <immintrin.h>
 #endif
+#endif
+
+#ifdef __F16C__
+float ggml_fp16_to_fp32(ggml_fp16_t h) {
+    return _cvtsh_ss(h);
+}
+ggml_fp16_t ggml_fp32_to_fp16(float f) {
+    return _cvtss_sh(f, 0);
+}
+
+#define GGML_FP16_TO_FP32(x) _cvtsh_ss(x)
+#define GGML_FP32_TO_FP16(x) _cvtss_sh(x, 0)
+
+#else
 
 // FP16 <-> FP32
 // ref: https://github.com/Maratyszcza/FP16
@@ -135,7 +167,8 @@ static inline float fp32_from_bits(uint32_t w) {
     union {
         uint32_t as_bits;
         float as_value;
-    } fp32 = { w };
+    } fp32;
+    fp32.as_bits = w;
     return fp32.as_value;
 }
 
@@ -143,7 +176,8 @@ static inline uint32_t fp32_to_bits(float f) {
        union {
                float as_value;
                uint32_t as_bits;
-       } fp32 = { f };
+       } fp32;
+       fp32.as_value = f;
        return fp32.as_bits;
 }
 
@@ -195,7 +229,13 @@ ggml_fp16_t ggml_fp32_to_fp16(float f) {
     const uint32_t nonsign = exp_bits + mantissa_bits;
     return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
 }
-#endif
+
+#define GGML_FP16_TO_FP32(x) ggml_fp16_to_fp32(x)
+#define GGML_FP32_TO_FP16(x) ggml_fp32_to_fp16(x)
+
+#endif // __F16C__
+
+#endif // __ARM_NEON
 
 //
 // global data
@@ -273,196 +313,429 @@ int64_t ggml_cycles_per_ms(void) {
 #define CACHE_LINE_SIZE 64
 #endif
 
-const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
+static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 
 //
-// fundamental operations
+// simd mappings
 //
 
-inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-
-inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-
-inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-
-inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-
-inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] + y[i]; }
-inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i] += x[i];        }
-inline static void ggml_vec_acc1_f32(const int n, float * y, const float   v)                  { for (int i = 0; i < n; ++i) y[i] += v;           }
-inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] - y[i]; }
-inline static void ggml_vec_set_f32 (const int n, float * x, const float   v)                  { for (int i = 0; i < n; ++i) x[i]  = v;           }
-inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = x[i];        }
-inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = -x[i];       }
-inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }
-inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
-
-inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
-    ggml_float sumf = 0.0;
-#ifdef __ARM_NEON
-    // NEON 128-bit
-    const int n16 = (n & ~15);
-
-    float32x4_t sum0 = vdupq_n_f32(0);
-    float32x4_t sum1 = vdupq_n_f32(0);
-    float32x4_t sum2 = vdupq_n_f32(0);
-    float32x4_t sum3 = vdupq_n_f32(0);
-
-    float32x4_t x0, x1, x2, x3;
-    float32x4_t y0, y1, y2, y3;
-
-    for (int i = 0; i < n16; i += 16) {
-        x0 = vld1q_f32(x + i + 0);
-        x1 = vld1q_f32(x + i + 4);
-        x2 = vld1q_f32(x + i + 8);
-        x3 = vld1q_f32(x + i + 12);
-
-        y0 = vld1q_f32(y + i + 0);
-        y1 = vld1q_f32(y + i + 4);
-        y2 = vld1q_f32(y + i + 8);
-        y3 = vld1q_f32(y + i + 12);
+// we define a common set of C macros which map to specific intrinsics based on the current architecture
+// we then implement the fundamental computation operations below using only these macros
+// adding support for new architectures requires to define the corresponding SIMD macros
+//
+// GGML_F32_STEP / GGML_F16_STEP
+//   number of elements to process in a single step
+//
+// GGML_F32_EPR / GGML_F16_EPR
+//   number of elements to fit in a single register
+//
 
-        sum0 = vfmaq_f32(sum0, x0, y0);
-        sum1 = vfmaq_f32(sum1, x1, y1);
-        sum2 = vfmaq_f32(sum2, x2, y2);
-        sum3 = vfmaq_f32(sum3, x3, y3);
-    }
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
 
-    // reduce sum0..sum3 to sum0
-    sum0 = vaddq_f32(sum0, sum1);
-    sum2 = vaddq_f32(sum2, sum3);
-    sum0 = vaddq_f32(sum0, sum2);
+#define GGML_SIMD
 
-    float32x2_t sumf32 = vadd_f32(vget_low_f32(sum0), vget_high_f32(sum0));
-    sumf = vget_lane_f32(sumf32, 0) + vget_lane_f32(sumf32, 1);
+// F32 NEON
 
-    // leftovers
-    for (int i = n16; i < n; ++i) {
-        sumf += x[i]*y[i];
-    }
-#elif defined(__AVX2__)
-    // AVX 256-bit
-    const int n32 = (n & ~31);
+#define GGML_F32_STEP 16
+#define GGML_F32_EPR  4
 
-    __m256 sum0 = _mm256_setzero_ps();
-    __m256 sum1 = _mm256_setzero_ps();
-    __m256 sum2 = _mm256_setzero_ps();
-    __m256 sum3 = _mm256_setzero_ps();
+#define GGML_F32x4              float32x4_t
+#define GGML_F32x4_ZERO         vdupq_n_f32(0.0f)
+#define GGML_F32x4_SET1(x)      vdupq_n_f32(x)
+#define GGML_F32x4_LOAD         vld1q_f32
+#define GGML_F32x4_STORE        vst1q_f32
+#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
+#define GGML_F32x4_ADD          vaddq_f32
+#define GGML_F32x4_MUL          vmulq_f32
+#if defined(__ARM_FEATURE_QRDMX)
+    #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
+#else
+    #define GGML_F32x4_REDUCE_ONE(x) \
+    (vgetq_lane_f32(x, 0) +          \
+     vgetq_lane_f32(x, 1) +          \
+     vgetq_lane_f32(x, 2) +          \
+     vgetq_lane_f32(x, 3))
+#endif
+#define GGML_F32x4_REDUCE(res, x)              \
+{                                              \
+    for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
+        x[2*i] = vaddq_f32(x[2*i], x[2*i+1]);  \
+    }                                          \
+    for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
+        x[4*i] = vaddq_f32(x[4*i], x[4*i+2]);  \
+    }                                          \
+    for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
+        x[8*i] = vaddq_f32(x[8*i], x[8*i+4]);  \
+    }                                          \
+    res = GGML_F32x4_REDUCE_ONE(x[0]);         \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 NEON
 
-    __m256 x0, x1, x2, x3;
-    __m256 y0, y1, y2, y3;
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+    #define GGML_F16_STEP 32
+    #define GGML_F16_EPR  8
+
+    #define GGML_F16x8              float16x8_t
+    #define GGML_F16x8_ZERO         vdupq_n_f16(0.0f)
+    #define GGML_F16x8_SET1(x)      vdupq_n_f16(x)
+    #define GGML_F16x8_LOAD         vld1q_f16
+    #define GGML_F16x8_STORE        vst1q_f16
+    #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
+    #define GGML_F16x8_ADD          vaddq_f16
+    #define GGML_F16x8_MUL          vmulq_f16
+    #define GGML_F16x8_REDUCE(res, x)                             \
+    {                                                             \
+        for (int i = 0; i < GGML_F16_ARR/2; ++i) {                \
+            x[2*i] = vaddq_f16(x[2*i], x[2*i+1]);                 \
+        }                                                         \
+        for (int i = 0; i < GGML_F16_ARR/4; ++i) {                \
+            x[4*i] = vaddq_f16(x[4*i], x[4*i+2]);                 \
+        }                                                         \
+        for (int i = 0; i < GGML_F16_ARR/8; ++i) {                \
+            x[8*i] = vaddq_f16(x[8*i], x[8*i+4]);                 \
+        }                                                         \
+        const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
+        const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
+        res = vaddvq_f32(vaddq_f32(t0, t1));                      \
+    }
+
+    #define GGML_F16_VEC        GGML_F16x8
+    #define GGML_F16_VEC_ZERO   GGML_F16x8_ZERO
+    #define GGML_F16_VEC_SET1   GGML_F16x8_SET1
+    #define GGML_F16_VEC_LOAD   GGML_F16x8_LOAD
+    #define GGML_F16_VEC_STORE  GGML_F16x8_STORE
+    #define GGML_F16_VEC_FMA    GGML_F16x8_FMA
+    #define GGML_F16_VEC_ADD    GGML_F16x8_ADD
+    #define GGML_F16_VEC_MUL    GGML_F16x8_MUL
+    #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE
+#else
+    // if FP16 vector arithmetic is not supported, we use FP32 instead
+    // and take advantage of the vcvt_ functions to convert to/from FP16
+
+    #define GGML_F16_STEP 16
+    #define GGML_F16_EPR  4
+
+    #define GGML_F32Cx4              float32x4_t
+    #define GGML_F32Cx4_ZERO         vdupq_n_f32(0.0f)
+    #define GGML_F32Cx4_SET1(x)      vdupq_n_f32(x)
+    #define GGML_F32Cx4_LOAD(x)      vcvt_f32_f16(vld1_f16(x))
+    #define GGML_F32Cx4_STORE(x, y)  vst1_f16(x, vcvt_f16_f32(y))
+    #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
+    #define GGML_F32Cx4_ADD          vaddq_f32
+    #define GGML_F32Cx4_MUL          vmulq_f32
+    #define GGML_F32Cx4_REDUCE       GGML_F32x4_REDUCE
+
+    #define GGML_F16_VEC        GGML_F32Cx4
+    #define GGML_F16_VEC_ZERO   GGML_F32Cx4_ZERO
+    #define GGML_F16_VEC_SET1   GGML_F32Cx4_SET1
+    #define GGML_F16_VEC_LOAD   GGML_F32Cx4_LOAD
+    #define GGML_F16_VEC_STORE  GGML_F32Cx4_STORE
+    #define GGML_F16_VEC_FMA    GGML_F32Cx4_FMA
+    #define GGML_F16_VEC_ADD    GGML_F32Cx4_ADD
+    #define GGML_F16_VEC_MUL    GGML_F32Cx4_MUL
+    #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
+#endif
 
-    for (int i = 0; i < n32; i += 32) {
-        x0 = _mm256_loadu_ps(x + i + 0);
-        x1 = _mm256_loadu_ps(x + i + 8);
-        x2 = _mm256_loadu_ps(x + i + 16);
-        x3 = _mm256_loadu_ps(x + i + 24);
+#elif defined(__AVX__)
 
-        y0 = _mm256_loadu_ps(y + i + 0);
-        y1 = _mm256_loadu_ps(y + i + 8);
-        y2 = _mm256_loadu_ps(y + i + 16);
-        y3 = _mm256_loadu_ps(y + i + 24);
+#define GGML_SIMD
 
-        sum0 = _mm256_fmadd_ps(x0, y0, sum0);
-        sum1 = _mm256_fmadd_ps(x1, y1, sum1);
-        sum2 = _mm256_fmadd_ps(x2, y2, sum2);
-        sum3 = _mm256_fmadd_ps(x3, y3, sum3);
-    }
+// F32 AVX
 
-    sum0 = _mm256_add_ps(sum0, sum1);
-    sum2 = _mm256_add_ps(sum2, sum3);
-    sum0 = _mm256_add_ps(sum0, sum2);
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  8
 
-    const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0), _mm256_extractf128_ps(sum0, 1));
-    const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
-    const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
+#define GGML_F32x8         __m256
+#define GGML_F32x8_ZERO    _mm256_setzero_ps()
+#define GGML_F32x8_SET1(x) _mm256_set1_ps(x)
+#define GGML_F32x8_LOAD    _mm256_loadu_ps
+#define GGML_F32x8_STORE   _mm256_storeu_ps
+#if defined(__FMA__)
+    #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)
+#else
+    #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)
+#endif
+#define GGML_F32x8_ADD     _mm256_add_ps
+#define GGML_F32x8_MUL     _mm256_mul_ps
+#define GGML_F32x8_REDUCE(res, x)                                 \
+{                                                                 \
+    for (int i = 0; i < GGML_F32_ARR/2; ++i) {                    \
+        x[2*i] = _mm256_add_ps(x[2*i], x[2*i+1]);                 \
+    }                                                             \
+    for (int i = 0; i < GGML_F32_ARR/4; ++i) {                    \
+        x[4*i] = _mm256_add_ps(x[4*i], x[4*i+2]);                 \
+    }                                                             \
+    for (int i = 0; i < GGML_F32_ARR/8; ++i) {                    \
+        x[8*i] = _mm256_add_ps(x[8*i], x[8*i+4]);                 \
+    }                                                             \
+    const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]),    \
+                                 _mm256_extractf128_ps(x[0], 1)); \
+    const __m128 t1 = _mm_hadd_ps(t0, t0);                        \
+    res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));                     \
+}
+// TODO: is this optimal ?
+
+#define GGML_F32_VEC        GGML_F32x8
+#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x8_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x8_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x8_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x8_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x8_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
+
+// F16 AVX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  8
+
+// F16 arithmetic is not supported by AVX, so we use F32 instead
+// we take advantage of the _mm256_cvt intrinsics to convert F16 <-> F32
+
+#define GGML_F32Cx8             __m256
+#define GGML_F32Cx8_ZERO        _mm256_setzero_ps()
+#define GGML_F32Cx8_SET1(x)     _mm256_set1_ps(x)
+#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
+#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
+#define GGML_F32Cx8_FMA         GGML_F32x8_FMA
+#define GGML_F32Cx8_ADD         _mm256_add_ps
+#define GGML_F32Cx8_MUL         _mm256_mul_ps
+#define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE
+
+#define GGML_F16_VEC        GGML_F32Cx8
+#define GGML_F16_VEC_ZERO   GGML_F32Cx8_ZERO
+#define GGML_F16_VEC_SET1   GGML_F32Cx8_SET1
+#define GGML_F16_VEC_LOAD   GGML_F32Cx8_LOAD
+#define GGML_F16_VEC_STORE  GGML_F32Cx8_STORE
+#define GGML_F16_VEC_FMA    GGML_F32Cx8_FMA
+#define GGML_F16_VEC_ADD    GGML_F32Cx8_ADD
+#define GGML_F16_VEC_MUL    GGML_F32Cx8_MUL
+#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
+
+#elif defined(__POWER9_VECTOR__)
+
+// TODO: uncomment this when it works
+//#define GGML_SIMD
+
+// F32 POWER9
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  8
+
+// TODO: not tested !!
+#define GGML_F32x4         __vector float
+#define GGML_F32x4_ZERO    (__vector float){0.0f, 0.0f, 0.0f, 0.0f}
+#define GGML_F32x4_SET1(x) (__vector float){x, x, x, x}
+#define GGML_F32x4_LOAD    vec_vsx_ld
+#define GGML_F32x4_STORE   vec_vsx_st
+#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
+#define GGML_F32x4_ADD     vec_add
+#define GGML_F32x4_MUL     vec_mul
+#define GGML_F32x4_REDUCE(res, x)              \
+{                                              \
+    for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
+        x[2*i] = vec_add(x[2*i], x[2*i+1]);    \
+    }                                          \
+    for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
+        x[4*i] = vec_add(x[4*i], x[4*i+2]);    \
+    }                                          \
+    for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
+        x[8*i] = vec_add(x[8*i], x[8*i+4]);    \
+    }                                          \
+    res = vec_extract(x[0], 0) +               \
+          vec_extract(x[0], 1) +               \
+          vec_extract(x[0], 2) +               \
+          vec_extract(x[0], 3);                \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 POWER9
+// TODO: implement here
+// ...
 
-    sumf = _mm_cvtss_f32(r1);
+#elif defined(__wasm_simd128__)
 
-    // leftovers
-    for (int i = n32; i < n; ++i) {
-        sumf += x[i]*y[i];
-    }
-#elif defined(__AVX__)
-    // AVX 256-bit
-    const int n32 = (n & ~31);
+#define GGML_SIMD
+
+// F32 WASM
+
+#define GGML_F32_STEP 16
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4              v128_t
+#define GGML_F32x4_ZERO         wasm_f32x4_splat(0.0f)
+#define GGML_F32x4_SET1(x)      wasm_f32x4_splat(x)
+#define GGML_F32x4_LOAD         wasm_v128_load
+#define GGML_F32x4_STORE        wasm_v128_store
+#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)
+#define GGML_F32x4_ADD          wasm_f32x4_add
+#define GGML_F32x4_MUL          wasm_f32x4_mul
+#define GGML_F32x4_REDUCE(res, x)                  \
+{                                                  \
+    for (int i = 0; i < GGML_F32_ARR/2; ++i) {     \
+        x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \
+    }                                              \
+    for (int i = 0; i < GGML_F32_ARR/4; ++i) {     \
+        x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \
+    }                                              \
+    for (int i = 0; i < GGML_F32_ARR/8; ++i) {     \
+        x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \
+    }                                              \
+    res = wasm_f32x4_extract_lane(x[0], 0) +       \
+          wasm_f32x4_extract_lane(x[0], 1) +       \
+          wasm_f32x4_extract_lane(x[0], 2) +       \
+          wasm_f32x4_extract_lane(x[0], 3);        \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 WASM
+
+#define GGML_F16_STEP 16
+#define GGML_F16_EPR  4
+
+inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) {
+    float tmp[4];
+
+    tmp[0] = GGML_FP16_TO_FP32(p[0]);
+    tmp[1] = GGML_FP16_TO_FP32(p[1]);
+    tmp[2] = GGML_FP16_TO_FP32(p[2]);
+    tmp[3] = GGML_FP16_TO_FP32(p[3]);
+
+    return wasm_v128_load(tmp);
+}
+
+inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
+    float tmp[4];
+
+    wasm_v128_store(tmp, x);
+
+    p[0] = GGML_FP32_TO_FP16(tmp[0]);
+    p[1] = GGML_FP32_TO_FP16(tmp[1]);
+    p[2] = GGML_FP32_TO_FP16(tmp[2]);
+    p[3] = GGML_FP32_TO_FP16(tmp[3]);
+}
+
+#define GGML_F16x4             v128_t
+#define GGML_F16x4_ZERO        wasm_f32x4_splat(0.0f)
+#define GGML_F16x4_SET1(x)     wasm_f32x4_splat(x)
+#define GGML_F16x4_LOAD(x)     __wasm_f16x4_load(x)
+#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)
+#define GGML_F16x4_FMA         GGML_F32x4_FMA
+#define GGML_F16x4_ADD         wasm_f32x4_add
+#define GGML_F16x4_MUL         wasm_f32x4_mul
+#define GGML_F16x4_REDUCE(res, x)                  \
+{                                                  \
+    for (int i = 0; i < GGML_F16_ARR/2; ++i) {     \
+        x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \
+    }                                              \
+    for (int i = 0; i < GGML_F16_ARR/4; ++i) {     \
+        x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \
+    }                                              \
+    for (int i = 0; i < GGML_F16_ARR/8; ++i) {     \
+        x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \
+    }                                              \
+    res = wasm_f32x4_extract_lane(x[0], 0) +       \
+          wasm_f32x4_extract_lane(x[0], 1) +       \
+          wasm_f32x4_extract_lane(x[0], 2) +       \
+          wasm_f32x4_extract_lane(x[0], 3);        \
+}
+
+#define GGML_F16_VEC        GGML_F16x4
+#define GGML_F16_VEC_ZERO   GGML_F16x4_ZERO
+#define GGML_F16_VEC_SET1   GGML_F16x4_SET1
+#define GGML_F16_VEC_LOAD   GGML_F16x4_LOAD
+#define GGML_F16_VEC_STORE  GGML_F16x4_STORE
+#define GGML_F16_VEC_FMA    GGML_F16x4_FMA
+#define GGML_F16_VEC_ADD    GGML_F16x4_ADD
+#define GGML_F16_VEC_MUL    GGML_F16x4_MUL
+#define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE
 
-    __m256 sum0 = _mm256_setzero_ps();
-    __m256 sum1 = _mm256_setzero_ps();
-    __m256 sum2 = _mm256_setzero_ps();
-    __m256 sum3 = _mm256_setzero_ps();
+#endif
 
-    __m256 x0, x1, x2, x3;
-    __m256 y0, y1, y2, y3;
+// GGML_F32_ARR / GGML_F16_ARR
+//   number of registers to use per step
+#ifdef GGML_SIMD
+#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR)
+#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
+#endif
 
-    for (int i = 0; i < n32; i += 32) {
-        x0 = _mm256_loadu_ps(x + i + 0);
-        x1 = _mm256_loadu_ps(x + i + 8);
-        x2 = _mm256_loadu_ps(x + i + 16);
-        x3 = _mm256_loadu_ps(x + i + 24);
+//
+// fundamental operations
+//
 
-        y0 = _mm256_loadu_ps(y + i + 0);
-        y1 = _mm256_loadu_ps(y + i + 8);
-        y2 = _mm256_loadu_ps(y + i + 16);
-        y3 = _mm256_loadu_ps(y + i + 24);
+inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 
-       sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
-       sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
-       sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
-       sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
-    }
+inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 
-    sum0 = _mm256_add_ps(sum0, sum1);
-    sum2 = _mm256_add_ps(sum2, sum3);
-    sum0 = _mm256_add_ps(sum0, sum2);
+inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 
-    const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0), _mm256_extractf128_ps(sum0, 1));
-    const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
-    const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
+inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 
-    sumf = _mm_cvtss_f32(r1);
+inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] + y[i]; }
+inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i] += x[i];        }
+inline static void ggml_vec_acc1_f32(const int n, float * y, const float   v)                  { for (int i = 0; i < n; ++i) y[i] += v;           }
+inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] - y[i]; }
+inline static void ggml_vec_set_f32 (const int n, float * x, const float   v)                  { for (int i = 0; i < n; ++i) x[i]  = v;           }
+inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = x[i];        }
+inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = -x[i];       }
+inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }
+inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
 
-    // leftovers
-    for (int i = n32; i < n; ++i) {
-        sumf += x[i]*y[i];
-    }
-#elif defined(__wasm_simd128__)
-    // WASM 128-bit
-    const int n16 = (n & ~15);
+inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
+    ggml_float sumf = 0.0;
 
-    v128_t sum0 = wasm_f32x4_splat(0);
-    v128_t sum1 = wasm_f32x4_splat(0);
-    v128_t sum2 = wasm_f32x4_splat(0);
-    v128_t sum3 = wasm_f32x4_splat(0);
+#ifdef GGML_SIMD
+    const int np = (n & ~(GGML_F32_STEP - 1));
 
-    v128_t x0, x1, x2, x3;
-    v128_t y0, y1, y2, y3;
+    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
 
-    for (int i = 0; i < n16; i += 16) {
-        x0 = wasm_v128_load(x + i + 0);
-        x1 = wasm_v128_load(x + i + 4);
-        x2 = wasm_v128_load(x + i + 8);
-        x3 = wasm_v128_load(x + i + 12);
+    GGML_F32_VEC ax[GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
 
-        y0 = wasm_v128_load(y + i + 0);
-        y1 = wasm_v128_load(y + i + 4);
-        y2 = wasm_v128_load(y + i + 8);
-        y3 = wasm_v128_load(y + i + 12);
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
 
-        sum0 = wasm_f32x4_add(sum0, wasm_f32x4_mul(x0, y0));
-        sum1 = wasm_f32x4_add(sum1, wasm_f32x4_mul(x1, y1));
-        sum2 = wasm_f32x4_add(sum2, wasm_f32x4_mul(x2, y2));
-        sum3 = wasm_f32x4_add(sum3, wasm_f32x4_mul(x3, y3));
+            sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
+        }
     }
 
-    sum0 = wasm_f32x4_add(sum0, sum1);
-    sum2 = wasm_f32x4_add(sum2, sum3);
-    sum0 = wasm_f32x4_add(sum0, sum2);
-
-    sumf = wasm_f32x4_extract_lane(sum0, 0) + wasm_f32x4_extract_lane(sum0, 1) + wasm_f32x4_extract_lane(sum0, 2) + wasm_f32x4_extract_lane(sum0, 3);
+    // reduce sum0..sum3 to sum0
+    GGML_F32_VEC_REDUCE(sumf, sum);
 
     // leftovers
-    for (int i = n16; i < n; ++i) {
+    for (int i = np; i < n; ++i) {
         sumf += x[i]*y[i];
     }
 #else
@@ -477,248 +750,87 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
 
 inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
     ggml_float sumf = 0.0;
-#ifdef __ARM_NEON
-    const int n32 = (n & ~31);
 
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
-    float16x8_t sum0 = vdupq_n_f16(0);
-    float16x8_t sum1 = vdupq_n_f16(0);
-    float16x8_t sum2 = vdupq_n_f16(0);
-    float16x8_t sum3 = vdupq_n_f16(0);
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
 
-    float16x8_t x0, x1, x2, x3;
-    float16x8_t y0, y1, y2, y3;
+    GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
 
-    for (int i = 0; i < n32; i += 32) {
-        x0 = vld1q_f16(x + i + 0 );
-        x1 = vld1q_f16(x + i + 8 );
-        x2 = vld1q_f16(x + i + 16);
-        x3 = vld1q_f16(x + i + 24);
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
 
-        y0 = vld1q_f16(y + i + 0 );
-        y1 = vld1q_f16(y + i + 8 );
-        y2 = vld1q_f16(y + i + 16);
-        y3 = vld1q_f16(y + i + 24);
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR);
 
-        sum0 = vfmaq_f16(sum0, x0, y0);
-        sum1 = vfmaq_f16(sum1, x1, y1);
-        sum2 = vfmaq_f16(sum2, x2, y2);
-        sum3 = vfmaq_f16(sum3, x3, y3);
+            sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
+        }
     }
 
     // reduce sum0..sum3 to sum0
-    sum0 = vaddq_f16(sum0, sum1);
-    sum2 = vaddq_f16(sum2, sum3);
-    sum0 = vaddq_f16(sum0, sum2);
-
-    // load sum0 into 2 float32x4_t
-    float32x4_t sum0f32 = vcvt_f32_f16(vget_low_f16(sum0));
-    float32x4_t sum1f32 = vcvt_f32_f16(vget_high_f16(sum0));
-
-    // reduce sum0f32 and sum1f32 to sumf
-    sum0f32 = vaddq_f32(sum0f32, sum1f32);
-
-    float32x2_t sumf32 = vadd_f32(vget_low_f32(sum0f32), vget_high_f32(sum0f32));
-    sumf = vget_lane_f32(sumf32, 0) + vget_lane_f32(sumf32, 1);
-#else
-    float32x4_t sum0 = vdupq_n_f32(0);
-    float32x4_t sum1 = vdupq_n_f32(0);
-    float32x4_t sum2 = vdupq_n_f32(0);
-    float32x4_t sum3 = vdupq_n_f32(0);
-    float32x4_t sum4 = vdupq_n_f32(0);
-    float32x4_t sum5 = vdupq_n_f32(0);
-    float32x4_t sum6 = vdupq_n_f32(0);
-    float32x4_t sum7 = vdupq_n_f32(0);
-
-    float32x4_t x0, x1, x2, x3, x4, x5, x6, x7;
-    float32x4_t y0, y1, y2, y3, y4, y5, y6, y7;
-
-    for (int i = 0; i < n32; i += 32) {
-        x0 = vcvt_f32_f16(vld1_f16(x + i + 0 ));
-        x1 = vcvt_f32_f16(vld1_f16(x + i + 4 ));
-        x2 = vcvt_f32_f16(vld1_f16(x + i + 8 ));
-        x3 = vcvt_f32_f16(vld1_f16(x + i + 12));
-        x4 = vcvt_f32_f16(vld1_f16(x + i + 16));
-        x5 = vcvt_f32_f16(vld1_f16(x + i + 20));
-        x6 = vcvt_f32_f16(vld1_f16(x + i + 24));
-        x7 = vcvt_f32_f16(vld1_f16(x + i + 28));
-
-        y0 = vcvt_f32_f16(vld1_f16(y + i + 0 ));
-        y1 = vcvt_f32_f16(vld1_f16(y + i + 4 ));
-        y2 = vcvt_f32_f16(vld1_f16(y + i + 8 ));
-        y3 = vcvt_f32_f16(vld1_f16(y + i + 12));
-        y4 = vcvt_f32_f16(vld1_f16(y + i + 16));
-        y5 = vcvt_f32_f16(vld1_f16(y + i + 20));
-        y6 = vcvt_f32_f16(vld1_f16(y + i + 24));
-        y7 = vcvt_f32_f16(vld1_f16(y + i + 28));
-
-        sum0 = vfmaq_f32(sum0, x0, y0);
-        sum1 = vfmaq_f32(sum1, x1, y1);
-        sum2 = vfmaq_f32(sum2, x2, y2);
-        sum3 = vfmaq_f32(sum3, x3, y3);
-        sum4 = vfmaq_f32(sum4, x4, y4);
-        sum5 = vfmaq_f32(sum5, x5, y5);
-        sum6 = vfmaq_f32(sum6, x6, y6);
-        sum7 = vfmaq_f32(sum7, x7, y7);
-    }
-
-    // reduce sum0..sum7 to sum0
-    sum0 = vaddq_f32(sum0, sum1);
-    sum2 = vaddq_f32(sum2, sum3);
-    sum4 = vaddq_f32(sum4, sum5);
-    sum6 = vaddq_f32(sum6, sum7);
-    sum0 = vaddq_f32(sum0, sum2);
-    sum4 = vaddq_f32(sum4, sum6);
-    sum0 = vaddq_f32(sum0, sum4);
-
-    // reduce sum0 to sumf
-    float32x2_t sumf32 = vadd_f32(vget_low_f32(sum0), vget_high_f32(sum0));
-    sumf = vget_lane_f32(sumf32, 0) + vget_lane_f32(sumf32, 1);
-#endif
-
-    // leftovers
-    for (int i = n32; i < n; ++i) {
-        sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
-    }
-#elif defined(__AVX2__)
-    // AVX 256-bit
-    const int n32 = (n & ~31);
-
-    __m256 sum0 = _mm256_setzero_ps();
-    __m256 sum1 = _mm256_setzero_ps();
-    __m256 sum2 = _mm256_setzero_ps();
-    __m256 sum3 = _mm256_setzero_ps();
-
-    __m256 x0, x1, x2, x3;
-    __m256 y0, y1, y2, y3;
-
-    for (int i = 0; i < n32; i += 32) {
-        x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
-        x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
-        x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
-        x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
-
-        y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
-        y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
-        y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
-        y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
-
-        sum0 = _mm256_fmadd_ps(x0, y0, sum0);
-        sum1 = _mm256_fmadd_ps(x1, y1, sum1);
-        sum2 = _mm256_fmadd_ps(x2, y2, sum2);
-        sum3 = _mm256_fmadd_ps(x3, y3, sum3);
-    }
-
-    const __m256 sum01 = _mm256_add_ps(sum0, sum1);
-    const __m256 sum23 = _mm256_add_ps(sum2, sum3);
-    const __m256 sum0123 = _mm256_add_ps(sum01, sum23);
-
-    const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0123), _mm256_extractf128_ps(sum0123, 1));
-    const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
-    const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
-
-    sumf = _mm_cvtss_f32(r1);
+    GGML_F16_VEC_REDUCE(sumf, sum);
 
     // leftovers
-    for (int i = n32; i < n; ++i) {
-        //GGML_ASSERT(false);
-        sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
+    for (int i = np; i < n; ++i) {
+        sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
     }
-#elif defined(__AVX__)
-    // AVX 256-bit
+#elif defined(__POWER9_VECTOR__)
+    // TODO: this is temporary because I cannot fit it in the GGML_SIMD pattern like all other architectures without
+    //       being able to test it. hoping someone with access to a POWER9 machine can help out here.
     const int n32 = (n & ~31);
 
-    __m256 sum0 = _mm256_setzero_ps();
-    __m256 sum1 = _mm256_setzero_ps();
-    __m256 sum2 = _mm256_setzero_ps();
-    __m256 sum3 = _mm256_setzero_ps();
-
-    __m256 x0, x1, x2, x3;
-    __m256 y0, y1, y2, y3;
+    vector float sum0 = vec_splats (0.0f);
 
     for (int i = 0; i < n32; i += 32) {
-        x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
-        x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
-        x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
-        x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
-
-        y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
-        y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
-        y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
-        y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
+        // Use vec_xl, not vec_ld, because x is sometimes unaligned.
+        vector unsigned short x0 = vec_xl(i * 2 +  0, x);
+        vector unsigned short x1 = vec_xl(i * 2 + 16, x);
+        vector unsigned short x2 = vec_xl(i * 2 + 32, x);
+        vector unsigned short x3 = vec_xl(i * 2 + 48, x);
+
+        vector unsigned short y0 = vec_xl(i * 2 +  0, y);
+        vector unsigned short y1 = vec_xl(i * 2 + 16, y);
+        vector unsigned short y2 = vec_xl(i * 2 + 32, y);
+        vector unsigned short y3 = vec_xl(i * 2 + 48, y);
+
+        vector float fx0l = vec_extract_fp32_from_shortl(x0);
+        vector float fx0h = vec_extract_fp32_from_shorth(x0);
+        vector float fx1l = vec_extract_fp32_from_shortl(x1);
+        vector float fx1h = vec_extract_fp32_from_shorth(x1);
+        vector float fx2l = vec_extract_fp32_from_shortl(x2);
+        vector float fx2h = vec_extract_fp32_from_shorth(x2);
+        vector float fx3l = vec_extract_fp32_from_shortl(x3);
+        vector float fx3h = vec_extract_fp32_from_shorth(x3);
+
+        vector float fy0l = vec_extract_fp32_from_shortl(y0);
+        vector float fy0h = vec_extract_fp32_from_shorth(y0);
+        vector float fy1l = vec_extract_fp32_from_shortl(y1);
+        vector float fy1h = vec_extract_fp32_from_shorth(y1);
+        vector float fy2l = vec_extract_fp32_from_shortl(y2);
+        vector float fy2h = vec_extract_fp32_from_shorth(y2);
+        vector float fy3l = vec_extract_fp32_from_shortl(y3);
+        vector float fy3h = vec_extract_fp32_from_shorth(y3);
+
+        sum0 = vec_add(sum0, vec_mul(fx0l, fy0l));
+        sum0 = vec_add(sum0, vec_mul(fx0h, fy0h));
+        sum0 = vec_add(sum0, vec_mul(fx1l, fy1l));
+        sum0 = vec_add(sum0, vec_mul(fx1h, fy1h));
+        sum0 = vec_add(sum0, vec_mul(fx2l, fy2l));
+        sum0 = vec_add(sum0, vec_mul(fx2h, fy2h));
+        sum0 = vec_add(sum0, vec_mul(fx3l, fy3l));
+        sum0 = vec_add(sum0, vec_mul(fx3h, fy3h));
+    }
+
+    sumf = vec_extract(sum0, 0) + vec_extract(sum0, 1)
+         + vec_extract(sum0, 2) + vec_extract(sum0, 3);
 
-       sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
-       sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
-       sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
-       sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
-    }
-
-    const __m256 sum01 = _mm256_add_ps(sum0, sum1);
-    const __m256 sum23 = _mm256_add_ps(sum2, sum3);
-    const __m256 sum0123 = _mm256_add_ps(sum01, sum23);
-
-    const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0123), _mm256_extractf128_ps(sum0123, 1));
-    const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
-    const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
-
-    sumf = _mm_cvtss_f32(r1);
-
-    // leftovers
     for (int i = n32; i < n; ++i) {
-        //GGML_ASSERT(false);
-        sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
-    }
-#elif defined(__wasm_simd128__)
-    // WASM 128-bit
-    const int n16 = (n & ~15);
-
-    v128_t sum0 = wasm_f32x4_splat(0.0f);
-    v128_t sum1 = wasm_f32x4_splat(0.0f);
-    v128_t sum2 = wasm_f32x4_splat(0.0f);
-    v128_t sum3 = wasm_f32x4_splat(0.0f);
-
-    v128_t x0, x1, x2, x3;
-    v128_t y0, y1, y2, y3;
-
-    float tx[16];
-    float ty[16];
-
-    for (int i = 0; i < n16; i += 16) {
-        for (int k = 0; k < 16; ++k) {
-            tx[k] = ggml_fp16_to_fp32(x[i + k]);
-            ty[k] = ggml_fp16_to_fp32(y[i + k]);
-        }
-
-        x0 = wasm_v128_load(tx + 0);
-        x1 = wasm_v128_load(tx + 4);
-        x2 = wasm_v128_load(tx + 8);
-        x3 = wasm_v128_load(tx + 12);
-
-        y0 = wasm_v128_load(ty + 0);
-        y1 = wasm_v128_load(ty + 4);
-        y2 = wasm_v128_load(ty + 8);
-        y3 = wasm_v128_load(ty + 12);
-
-        sum0 = wasm_f32x4_add(sum0, wasm_f32x4_mul(x0, y0));
-        sum1 = wasm_f32x4_add(sum1, wasm_f32x4_mul(x1, y1));
-        sum2 = wasm_f32x4_add(sum2, wasm_f32x4_mul(x2, y2));
-        sum3 = wasm_f32x4_add(sum3, wasm_f32x4_mul(x3, y3));
-    }
-
-    sum0 = wasm_f32x4_add(sum0, sum1);
-    sum2 = wasm_f32x4_add(sum2, sum3);
-    sum0 = wasm_f32x4_add(sum0, sum2);
-
-    sumf = wasm_f32x4_extract_lane(sum0, 0) + wasm_f32x4_extract_lane(sum0, 1) + wasm_f32x4_extract_lane(sum0, 2) + wasm_f32x4_extract_lane(sum0, 3);
-
-    // leftovers
-    for (int i = n16; i < n; ++i) {
-        //GGML_ASSERT(false);
-        sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
+        sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
     }
 #else
     for (int i = 0; i < n; ++i) {
-        sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
+        sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
     }
 #endif
 
@@ -726,144 +838,26 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 }
 
 inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
-#ifdef __ARM_NEON
-    // NEON 128-bit
-    const int n16 = (n & ~15);
-
-    const float32x4_t v4 = vdupq_n_f32(v);
-
-    float32x4_t x0, x1, x2, x3;
-    float32x4_t y0, y1, y2, y3;
-
-    for (int i = 0; i < n16; i += 16) {
-        x0 = vld1q_f32(x + i + 0);
-        x1 = vld1q_f32(x + i + 4);
-        x2 = vld1q_f32(x + i + 8);
-        x3 = vld1q_f32(x + i + 12);
-
-        y0 = vld1q_f32(y + i + 0);
-        y1 = vld1q_f32(y + i + 4);
-        y2 = vld1q_f32(y + i + 8);
-        y3 = vld1q_f32(y + i + 12);
-
-        y0 = vfmaq_f32(y0, x0, v4);
-        y1 = vfmaq_f32(y1, x1, v4);
-        y2 = vfmaq_f32(y2, x2, v4);
-        y3 = vfmaq_f32(y3, x3, v4);
-
-        vst1q_f32(y + i + 0, y0);
-        vst1q_f32(y + i + 4, y1);
-        vst1q_f32(y + i + 8, y2);
-        vst1q_f32(y + i + 12, y3);
-    }
-
-    // leftovers
-    for (int i = n16; i < n; ++i) {
-        y[i] += x[i]*v;
-    }
-#elif defined(__AVX2__)
-    // AVX 256-bit
-    const int n32 = (n & ~31);
-
-    const __m256 v4 = _mm256_set1_ps(v);
-
-    __m256 x0, x1, x2, x3;
-    __m256 y0, y1, y2, y3;
-
-    for (int i = 0; i < n32; i += 32) {
-        x0 = _mm256_loadu_ps(x + i + 0);
-        x1 = _mm256_loadu_ps(x + i + 8);
-        x2 = _mm256_loadu_ps(x + i + 16);
-        x3 = _mm256_loadu_ps(x + i + 24);
-
-        y0 = _mm256_loadu_ps(y + i + 0);
-        y1 = _mm256_loadu_ps(y + i + 8);
-        y2 = _mm256_loadu_ps(y + i + 16);
-        y3 = _mm256_loadu_ps(y + i + 24);
-
-        y0 = _mm256_fmadd_ps(x0, v4, y0);
-        y1 = _mm256_fmadd_ps(x1, v4, y1);
-        y2 = _mm256_fmadd_ps(x2, v4, y2);
-        y3 = _mm256_fmadd_ps(x3, v4, y3);
-
-        _mm256_storeu_ps(y + i + 0, y0);
-        _mm256_storeu_ps(y + i + 8, y1);
-        _mm256_storeu_ps(y + i + 16, y2);
-        _mm256_storeu_ps(y + i + 24, y3);
-    }
-
-    // leftovers
-    for (int i = n32; i < n; ++i) {
-        y[i] += x[i]*v;
-    }
-#elif defined(__AVX__)
-    // AVX 256-bit
-    const int n32 = (n & ~31);
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
 
-    const __m256 v4 = _mm256_set1_ps(v);
+    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
 
-    __m256 x0, x1, x2, x3;
-    __m256 y0, y1, y2, y3;
+    GGML_F32_VEC ax[GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
 
-    for (int i = 0; i < n32; i += 32) {
-        x0 = _mm256_loadu_ps(x + i + 0);
-        x1 = _mm256_loadu_ps(x + i + 8);
-        x2 = _mm256_loadu_ps(x + i + 16);
-        x3 = _mm256_loadu_ps(x + i + 24);
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
 
-        y0 = _mm256_loadu_ps(y + i + 0);
-        y1 = _mm256_loadu_ps(y + i + 8);
-        y2 = _mm256_loadu_ps(y + i + 16);
-        y3 = _mm256_loadu_ps(y + i + 24);
-
-       y0 = _mm256_add_ps(_mm256_mul_ps(x0, v4), y0);
-       y1 = _mm256_add_ps(_mm256_mul_ps(x1, v4), y1);
-       y2 = _mm256_add_ps(_mm256_mul_ps(x2, v4), y2);
-       y3 = _mm256_add_ps(_mm256_mul_ps(x3, v4), y3);
-
-        _mm256_storeu_ps(y + i + 0, y0);
-        _mm256_storeu_ps(y + i + 8, y1);
-        _mm256_storeu_ps(y + i + 16, y2);
-        _mm256_storeu_ps(y + i + 24, y3);
-    }
-
-    // leftovers
-    for (int i = n32; i < n; ++i) {
-        y[i] += x[i]*v;
-    }
-#elif defined(__wasm_simd128__)
-    // WASM SIMD 128-bit
-    const int n16 = (n & ~15);
-
-    const v128_t v4 = wasm_f32x4_splat(v);
-
-    v128_t x0, x1, x2, x3;
-    v128_t y0, y1, y2, y3;
-
-    for (int i = 0; i < n16; i += 16) {
-        x0 = wasm_v128_load(x + i + 0);
-        x1 = wasm_v128_load(x + i + 4);
-        x2 = wasm_v128_load(x + i + 8);
-        x3 = wasm_v128_load(x + i + 12);
-
-        y0 = wasm_v128_load(y + i + 0);
-        y1 = wasm_v128_load(y + i + 4);
-        y2 = wasm_v128_load(y + i + 8);
-        y3 = wasm_v128_load(y + i + 12);
-
-        y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4));
-        y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4));
-        y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4));
-        y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4));
-
-        wasm_v128_store(y + i + 0, y0);
-        wasm_v128_store(y + i + 4, y1);
-        wasm_v128_store(y + i + 8, y2);
-        wasm_v128_store(y + i + 12, y3);
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+        }
     }
 
     // leftovers
-    for (int i = n16; i < n; ++i) {
+    for (int i = np; i < n; ++i) {
         y[i] += x[i]*v;
     }
 #else
@@ -875,216 +869,125 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
 }
 
 inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_fp16_t * restrict x, const float v) {
-#ifdef __ARM_NEON
-    // NEON 128-bit
-    const int n32 = (n & ~31);
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
 
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
-    const float16x8_t v8 = vdupq_n_f16(v);
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
 
-    float16x8_t x0, x1, x2, x3;
-    float16x8_t y0, y1, y2, y3;
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
 
-    for (int i = 0; i < n32; i += 32) {
-        y0 = vld1q_f16(y + i + 0 );
-        y1 = vld1q_f16(y + i + 8 );
-        y2 = vld1q_f16(y + i + 16);
-        y3 = vld1q_f16(y + i + 24);
-
-        x0 = vld1q_f16(x + i + 0 );
-        x1 = vld1q_f16(x + i + 8 );
-        x2 = vld1q_f16(x + i + 16);
-        x3 = vld1q_f16(x + i + 24);
-
-        y0 = vfmaq_f16(y0, x0, v8);
-        y1 = vfmaq_f16(y1, x1, v8);
-        y2 = vfmaq_f16(y2, x2, v8);
-        y3 = vfmaq_f16(y3, x3, v8);
-
-        vst1q_f16(y + i + 0 , y0);
-        vst1q_f16(y + i + 8 , y1);
-        vst1q_f16(y + i + 16, y2);
-        vst1q_f16(y + i + 24, y3);
-    }
-#else
-    const float32x4_t v40 = vdupq_n_f32(v);
-    const float32x4_t v41 = vdupq_n_f32(v);
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR);
+            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
 
-    float32x4_t x0, x1, x2, x3, x4, x5, x6, x7;
-    float32x4_t y0, y1, y2, y3, y4, y5, y6, y7;
-
-    for (int i = 0; i < n32; i += 32) {
-        y0 = vcvt_f32_f16(vld1_f16(y + i + 0 ));
-        y1 = vcvt_f32_f16(vld1_f16(y + i + 4 ));
-        y2 = vcvt_f32_f16(vld1_f16(y + i + 8 ));
-        y3 = vcvt_f32_f16(vld1_f16(y + i + 12));
-        y4 = vcvt_f32_f16(vld1_f16(y + i + 16));
-        y5 = vcvt_f32_f16(vld1_f16(y + i + 20));
-        y6 = vcvt_f32_f16(vld1_f16(y + i + 24));
-        y7 = vcvt_f32_f16(vld1_f16(y + i + 28));
-
-        x0 = vcvt_f32_f16(vld1_f16(x + i + 0 ));
-        x1 = vcvt_f32_f16(vld1_f16(x + i + 4 ));
-        x2 = vcvt_f32_f16(vld1_f16(x + i + 8 ));
-        x3 = vcvt_f32_f16(vld1_f16(x + i + 12));
-        x4 = vcvt_f32_f16(vld1_f16(x + i + 16));
-        x5 = vcvt_f32_f16(vld1_f16(x + i + 20));
-        x6 = vcvt_f32_f16(vld1_f16(x + i + 24));
-        x7 = vcvt_f32_f16(vld1_f16(x + i + 28));
-
-        y0 = vfmaq_f32(y0, x0, v40);
-        y1 = vfmaq_f32(y1, x1, v40);
-        y2 = vfmaq_f32(y2, x2, v40);
-        y3 = vfmaq_f32(y3, x3, v40);
-        y4 = vfmaq_f32(y4, x4, v41);
-        y5 = vfmaq_f32(y5, x5, v41);
-        y6 = vfmaq_f32(y6, x6, v41);
-        y7 = vfmaq_f32(y7, x7, v41);
-
-        vst1_f16(y + i + 0 , vcvt_f16_f32(y0));
-        vst1_f16(y + i + 4 , vcvt_f16_f32(y1));
-        vst1_f16(y + i + 8 , vcvt_f16_f32(y2));
-        vst1_f16(y + i + 12, vcvt_f16_f32(y3));
-        vst1_f16(y + i + 16, vcvt_f16_f32(y4));
-        vst1_f16(y + i + 20, vcvt_f16_f32(y5));
-        vst1_f16(y + i + 24, vcvt_f16_f32(y6));
-        vst1_f16(y + i + 28, vcvt_f16_f32(y7));
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay[j]);
+        }
     }
-#endif
 
     // leftovers
-    for (int i = n32; i < n; ++i) {
+    for (int i = np; i < n; ++i) {
         GGML_ASSERT(false);
-        y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
     }
-#elif defined(__AVX2__)
-    // AVX 256-bit
+#elif defined(__POWER9_VECTOR__)
+    // TODO: this is temporary because I cannot fit it in the GGML_SIMD pattern like all other architectures without
+    //       being able to test it. hoping someone with access to a POWER9 machine can help out here.
     const int n32 = (n & ~31);
-
-    const __m256 v8 = _mm256_set1_ps(v);
-
-    __m256 x0, x1, x2, x3;
-    __m256 y0, y1, y2, y3;
-
     for (int i = 0; i < n32; i += 32) {
-        y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
-        y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
-        y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
-        y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
-
-        x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
-        x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
-        x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
-        x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
-
-        y0 = _mm256_fmadd_ps(x0, v8, y0);
-        y1 = _mm256_fmadd_ps(x1, v8, y1);
-        y2 = _mm256_fmadd_ps(x2, v8, y2);
-        y3 = _mm256_fmadd_ps(x3, v8, y3);
-
-        _mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0));
-        _mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0));
-        _mm_storeu_si128((__m128i*)(y + i + 16), _mm256_cvtps_ph(y2, 0));
-        _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0));
+        // Use vec_xl, not vec_ld, because x is sometimes unaligned!
+        vector unsigned short x0 = vec_xl(i * 2 +  0, x);
+        vector unsigned short x1 = vec_xl(i * 2 + 16, x);
+        vector unsigned short x2 = vec_xl(i * 2 + 32, x);
+        vector unsigned short x3 = vec_xl(i * 2 + 48, x);
+
+        vector unsigned short y0 = vec_xl(i * 2 +  0, y);
+        vector unsigned short y1 = vec_xl(i * 2 + 16, y);
+        vector unsigned short y2 = vec_xl(i * 2 + 32, y);
+        vector unsigned short y3 = vec_xl(i * 2 + 48, y);
+
+        vector float v4 = vec_splats(v);
+
+        vector float fx0l = vec_extract_fp32_from_shortl(x0);
+        vector float fx0h = vec_extract_fp32_from_shorth(x0);
+        vector float fx1l = vec_extract_fp32_from_shortl(x1);
+        vector float fx1h = vec_extract_fp32_from_shorth(x1);
+        vector float fx2l = vec_extract_fp32_from_shortl(x2);
+        vector float fx2h = vec_extract_fp32_from_shorth(x2);
+        vector float fx3l = vec_extract_fp32_from_shortl(x3);
+        vector float fx3h = vec_extract_fp32_from_shorth(x3);
+
+        vector float fy0l = vec_extract_fp32_from_shortl(y0);
+        vector float fy0h = vec_extract_fp32_from_shorth(y0);
+        vector float fy1l = vec_extract_fp32_from_shortl(y1);
+        vector float fy1h = vec_extract_fp32_from_shorth(y1);
+        vector float fy2l = vec_extract_fp32_from_shortl(y2);
+        vector float fy2h = vec_extract_fp32_from_shorth(y2);
+        vector float fy3l = vec_extract_fp32_from_shortl(y3);
+        vector float fy3h = vec_extract_fp32_from_shorth(y3);
+
+        fy0l = vec_madd(fx0l, v4, fy0l);
+        fy0h = vec_madd(fx0h, v4, fy0h);
+        fy1l = vec_madd(fx1l, v4, fy1l);
+        fy1h = vec_madd(fx1h, v4, fy1h);
+        fy2l = vec_madd(fx2l, v4, fy2l);
+        fy2h = vec_madd(fx2h, v4, fy2h);
+        fy3l = vec_madd(fx3l, v4, fy3l);
+        fy3h = vec_madd(fx3h, v4, fy3h);
+
+        y0 = vec_pack_to_short_fp32(fy0h, fy0l);
+        y1 = vec_pack_to_short_fp32(fy1h, fy1l);
+        y2 = vec_pack_to_short_fp32(fy2h, fy2l);
+        y3 = vec_pack_to_short_fp32(fy3h, fy3l);
+
+        vec_xst(y0, i * 2 +  0, y);
+        vec_xst(y1, i * 2 + 16, y);
+        vec_xst(y2, i * 2 + 32, y);
+        vec_xst(y3, i * 2 + 48, y);
     }
 
-    // leftovers
     for (int i = n32; i < n; ++i) {
-        GGML_ASSERT(false);
-        y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
-    }
-#elif defined(__AVX__)
-    // AVX 256-bit
-    const int n32 = (n & ~31);
-
-    const __m256 v8 = _mm256_set1_ps(v);
-
-    __m256 x0, x1, x2, x3;
-    __m256 y0, y1, y2, y3;
-
-    for (int i = 0; i < n32; i += 32) {
-        y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
-        y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
-        y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
-        y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
-
-        x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
-        x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
-        x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
-        x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
-
-       y0 = _mm256_add_ps(_mm256_mul_ps(x0, v8), y0);
-       y1 = _mm256_add_ps(_mm256_mul_ps(x1, v8), y1);
-       y2 = _mm256_add_ps(_mm256_mul_ps(x2, v8), y2);
-       y3 = _mm256_add_ps(_mm256_mul_ps(x3, v8), y3);
-
-        _mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0));
-        _mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0));
-        _mm_storeu_si128((__m128i*)(y + i + 16), _mm256_cvtps_ph(y2, 0));
-        _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0));
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
     }
-
-    // leftovers
-    for (int i = n32; i < n; ++i) {
-        GGML_ASSERT(false);
-        y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
+#else
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
     }
-#elif defined(__wasm_simd128__)
-    // WASM SIMD 128-bit
-    const int n16 = (n & ~15);
-
-    const v128_t v4 = wasm_f32x4_splat(v);
-
-    v128_t x0, x1, x2, x3;
-    v128_t y0, y1, y2, y3;
-
-    float tx[16];
-    float ty[16];
-
-    for (int i = 0; i < n16; i += 16) {
-        for (int k = 0; k < 16; ++k) {
-            tx[k] = ggml_fp16_to_fp32(x[i + k]);
-            ty[k] = ggml_fp16_to_fp32(y[i + k]);
-        }
+#endif
+}
 
-        x0 = wasm_v128_load(tx + 0);
-        x1 = wasm_v128_load(tx + 4);
-        x2 = wasm_v128_load(tx + 8);
-        x3 = wasm_v128_load(tx + 12);
+//inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          }
+inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
 
-        y0 = wasm_v128_load(ty + 0);
-        y1 = wasm_v128_load(ty + 4);
-        y2 = wasm_v128_load(ty + 8);
-        y3 = wasm_v128_load(ty + 12);
+    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
 
-        y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4));
-        y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4));
-        y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4));
-        y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4));
+    GGML_F32_VEC ay[GGML_F32_ARR];
 
-        wasm_v128_store(ty + 0, y0);
-        wasm_v128_store(ty + 4, y1);
-        wasm_v128_store(ty + 8, y2);
-        wasm_v128_store(ty + 12, y3);
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_MUL(ay[j], vx);
 
-        for (int k = 0; k < 16; ++k) {
-            y[i + k] = ggml_fp32_to_fp16(ty[k]);
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
         }
     }
 
     // leftovers
-    for (int i = n16; i < n; ++i) {
-        GGML_ASSERT(false);
-        y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
+    for (int i = np; i < n; ++i) {
+        y[i] *= v;
     }
 #else
+    // scalar
     for (int i = 0; i < n; ++i) {
-        y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
+        y[i] *= v;
     }
 #endif
 }
 
-inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          }
 inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s);   }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
 inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); }
@@ -1093,8 +996,8 @@ inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) {
 inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
 
-const ggml_float GELU_COEF_A    = 0.044715;
-const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876;
+static const ggml_float GELU_COEF_A    = 0.044715;
+static const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876;
 
 inline static float ggml_gelu_f32(float x) {
     return 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x)));
@@ -1111,9 +1014,9 @@ inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp
 inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
     uint16_t t;
     for (int i = 0; i < n; ++i) {
-        ggml_fp16_t fp16 = ggml_fp32_to_fp16(x[i]);
+        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
         memcpy(&t, &fp16, sizeof(uint16_t));
-        y[i] = ggml_fp16_to_fp32(table_gelu_f16[t]);
+        y[i] = GGML_FP16_TO_FP32(table_gelu_f16[t]);
     }
 }
 #else
@@ -1155,7 +1058,7 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
 // data types
 //
 
-const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
+static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     sizeof(int8_t ),
     sizeof(int16_t),
     sizeof(int32_t),
@@ -1163,7 +1066,7 @@ const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     sizeof(float  ),
 };
 
-const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
+static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "NONE",
 
     "DUP",
@@ -1203,7 +1106,7 @@ const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "FLASH_FF",
 };
 
-const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
+static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
 
     "x",
@@ -1256,7 +1159,7 @@ struct ggml_object {
     char padding[8];
 };
 
-const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
+static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -1311,8 +1214,26 @@ struct ggml_state {
 };
 
 // global state
-struct ggml_state g_state;
-atomic_int g_state_barrier = 0;
+static struct ggml_state g_state;
+static atomic_int g_state_barrier = 0;
+
+// barrier via spin lock
+inline static void ggml_critical_section_start() {
+    int processing = atomic_fetch_add(&g_state_barrier, 1);
+
+    while (processing > 0) {
+        // wait for other threads to finish
+        atomic_fetch_sub(&g_state_barrier, 1);
+        sched_yield(); // TODO: reconsider this
+        processing = atomic_fetch_add(&g_state_barrier, 1);
+    }
+}
+
+// TODO: make this somehow automatically executed
+//       some sort of "sentry" mechanism
+inline static void ggml_critical_section_end() {
+    atomic_fetch_sub(&g_state_barrier, 1);
+}
 
 ////////////////////////////////////////////////////////////////////////////////
 
@@ -1403,7 +1324,7 @@ bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     return
         tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
         tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
-        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];;
+        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
 bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
@@ -1443,32 +1364,45 @@ int ggml_up64(int n) {
 
 struct ggml_context * ggml_init(struct ggml_init_params params) {
     // make this function thread safe
-    {
-        int processing = atomic_fetch_add(&g_state_barrier, 1);
-        while (processing > 0) {
-            // wait for other threads to finish
-            atomic_fetch_sub(&g_state_barrier, 1);
-            sched_yield();
-            processing = atomic_fetch_add(&g_state_barrier, 1);
-        }
-    }
+    ggml_critical_section_start();
 
     static bool is_first_call = true;
+
     if (is_first_call) {
-        const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
-
-        ggml_fp16_t ii;
-        for (int i = 0; i < (1 << 16); ++i) {
-            uint16_t ui = i;
-            memcpy(&ii, &ui, sizeof(ii));
-            const float f = ggml_fp16_to_fp32(ii);
-            table_gelu_f16[i] = ggml_fp32_to_fp16(ggml_gelu_f32(f));
-            table_exp_f16[i] = ggml_fp32_to_fp16(exp(f));
+        // initialize GELU and EXP tables
+        {
+            const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
+
+            ggml_fp16_t ii;
+            for (int i = 0; i < (1 << 16); ++i) {
+                uint16_t ui = i;
+                memcpy(&ii, &ui, sizeof(ii));
+                const float f = GGML_FP16_TO_FP32(ii);
+                table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
+                table_exp_f16[i]  = GGML_FP32_TO_FP16(exp(f));
+            }
+
+            const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
+
+            GGML_PRINT_DEBUG("%s: GELU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
         }
 
-        const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
+        // initialize g_state
+        {
+            const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
 
-        GGML_PRINT_DEBUG("%s: GELU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
+            g_state = (struct ggml_state) {
+                /*.contexts =*/ { 0 },
+            };
+
+            for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) {
+                g_state.contexts[i].used = false;
+            }
+
+            const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
+
+            GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
+        }
 
         is_first_call = false;
     }
@@ -1476,14 +1410,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
     // find non-used context in g_state
     struct ggml_context * ctx = NULL;
 
-    static bool first_time = true;
-    if (first_time) {
-        for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
-            g_state.contexts[i].used = false;
-        }
-        first_time = false;
-    }
-
     for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
         if (!g_state.contexts[i].used) {
             g_state.contexts[i].used = true;
@@ -1497,7 +1423,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
     if (ctx == NULL) {
         GGML_PRINT_DEBUG("%s: no unused context found\n", __func__);
 
-        atomic_fetch_sub(&g_state_barrier, 1);
+        ggml_critical_section_end();
 
         return NULL;
     }
@@ -1515,22 +1441,16 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
 
     GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
 
-    atomic_fetch_sub(&g_state_barrier, 1);
+    ggml_critical_section_end();
 
     return ctx;
 }
 
 void ggml_free(struct ggml_context * ctx) {
     // make this function thread safe
-    {
-        int processing = atomic_fetch_add(&g_state_barrier, 1);
-        while (processing > 0) {
-            // wait for other threads to finish
-            atomic_fetch_sub(&g_state_barrier, 1);
-            sched_yield();
-            processing = atomic_fetch_add(&g_state_barrier, 1);
-        }
-    }
+    ggml_critical_section_start();
+
+    bool found = false;
 
     for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
         if (&g_state.contexts[i].context == ctx) {
@@ -1543,15 +1463,16 @@ void ggml_free(struct ggml_context * ctx) {
                 free(ctx->mem_buffer);
             }
 
-            atomic_fetch_sub(&g_state_barrier, 1);
-
-            return;
+            found = true;
+            break;
         }
     }
 
-    GGML_PRINT_DEBUG("%s: context not found\n", __func__);
+    if (!found) {
+        GGML_PRINT_DEBUG("%s: context not found\n", __func__);
+    }
 
-    atomic_fetch_sub(&g_state_barrier, 1);
+    ggml_critical_section_end();
 }
 
 size_t ggml_used_mem(const struct ggml_context * ctx) {
@@ -1846,7 +1767,7 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
         case GGML_TYPE_F16:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
-                return ggml_fp16_to_fp32(((ggml_fp16_t *)(tensor->data))[i]);
+                return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
             } break;
         case GGML_TYPE_F32:
             {
@@ -1882,7 +1803,7 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
         case GGML_TYPE_F16:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
-                ((ggml_fp16_t *)(tensor->data))[i] = ggml_fp32_to_fp16(value);
+                ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
             } break;
         case GGML_TYPE_F32:
             {
@@ -1916,7 +1837,7 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
         case GGML_TYPE_F16:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
-                return ggml_fp16_to_fp32(((ggml_fp16_t *)(tensor->data))[i]);
+                return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
             } break;
         case GGML_TYPE_F32:
             {
@@ -1952,7 +1873,7 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
         case GGML_TYPE_F16:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
-                ((ggml_fp16_t *)(tensor->data))[i] = ggml_fp32_to_fp16(value);
+                ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
             } break;
         case GGML_TYPE_F32:
             {
@@ -3132,7 +3053,7 @@ void ggml_set_param(
 
 // ggml_compute_forward_dup
 
-void ggml_compute_forward_dup_f16(
+static void ggml_compute_forward_dup_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3144,25 +3065,99 @@ void ggml_compute_forward_dup_f16(
         return;
     }
 
-    //const int ne00 = src0->ne[0];
-    //const int ne01 = src0->ne[1];
-    //const int ne02 = src0->ne[2];
-    //const int ne03 = src0->ne[3];
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
 
-    //const size_t nb00 = src0->nb[0];
-    //const size_t nb01 = src0->nb[1];
-    //const size_t nb02 = src0->nb[2];
-    //const size_t nb03 = src0->nb[3];
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
 
     if (ggml_is_contiguous(src0) && src0->type == dst->type) {
         memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
         return;
     }
 
-    GGML_ASSERT(false); // TODO: implement
+    if (src0->nb[0] == sizeof(ggml_fp16_t)) {
+        if (dst->type == GGML_TYPE_F16) {
+            int id = 0;
+            const size_t rs = ne00*nb00;
+
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    for (int i01 = 0; i01 < ne01; i01++) {
+                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                        char * dst_ptr = (char *) dst->data + id*rs;
+
+                        memcpy(dst_ptr, src0_ptr, rs);
+
+                        id++;
+                    }
+                }
+            }
+        } else if (dst->type == GGML_TYPE_F32) {
+            int id = 0;
+            float * dst_ptr = (float *) dst->data;
+
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    for (int i01 = 0; i01 < ne01; i01++) {
+                        for (int i00 = 0; i00 < ne00; i00++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                            id++;
+                        }
+                    }
+                }
+            }
+        } else {
+            GGML_ASSERT(false); // TODO: implement
+        }
+    } else {
+        //printf("%s: this is not optimal - fix me\n", __func__);
+
+        if (dst->type == GGML_TYPE_F32) {
+            int id = 0;
+            float * dst_ptr = (float *) dst->data;
+
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    for (int i01 = 0; i01 < ne01; i01++) {
+                        for (int i00 = 0; i00 < ne00; i00++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                            id++;
+                        }
+                    }
+                }
+            }
+        } else if (dst->type == GGML_TYPE_F16) {
+            int id = 0;
+            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    for (int i01 = 0; i01 < ne01; i01++) {
+                        for (int i00 = 0; i00 < ne00; i00++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            dst_ptr[id] = *src0_ptr;
+                            id++;
+                        }
+                    }
+                }
+            }
+        } else {
+            GGML_ASSERT(false); // TODO: implement
+        }
+    }
 }
 
-void ggml_compute_forward_dup_f32(
+static void ggml_compute_forward_dup_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3216,7 +3211,7 @@ void ggml_compute_forward_dup_f32(
                         for (int i00 = 0; i00 < ne00; i00++) {
                             const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
-                            dst_ptr[id] = ggml_fp32_to_fp16(*src0_ptr);
+                            dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
                             id++;
                         }
                     }
@@ -3254,7 +3249,7 @@ void ggml_compute_forward_dup_f32(
                         for (int i00 = 0; i00 < ne00; i00++) {
                             const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
-                            dst_ptr[id] = ggml_fp32_to_fp16(*src0_ptr);
+                            dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
                             id++;
                         }
                     }
@@ -3266,7 +3261,7 @@ void ggml_compute_forward_dup_f32(
     }
 }
 
-void ggml_compute_forward_dup(
+static void ggml_compute_forward_dup(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3291,7 +3286,7 @@ void ggml_compute_forward_dup(
 
 // ggml_compute_forward_add
 
-void ggml_compute_forward_add_f32(
+static void ggml_compute_forward_add_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -3344,7 +3339,7 @@ void ggml_compute_forward_add_f32(
     }
 }
 
-void ggml_compute_forward_add(
+static void ggml_compute_forward_add(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -3367,7 +3362,7 @@ void ggml_compute_forward_add(
 
 // ggml_compute_forward_sub
 
-void ggml_compute_forward_sub_f32(
+static void ggml_compute_forward_sub_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -3394,7 +3389,7 @@ void ggml_compute_forward_sub_f32(
     }
 }
 
-void ggml_compute_forward_sub(
+static void ggml_compute_forward_sub(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -3417,7 +3412,7 @@ void ggml_compute_forward_sub(
 
 // ggml_compute_forward_mul
 
-void ggml_compute_forward_mul_f32(
+static void ggml_compute_forward_mul_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -3444,7 +3439,7 @@ void ggml_compute_forward_mul_f32(
     }
 }
 
-void ggml_compute_forward_mul(
+static void ggml_compute_forward_mul(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -3467,7 +3462,7 @@ void ggml_compute_forward_mul(
 
 // ggml_compute_forward_div
 
-void ggml_compute_forward_div_f32(
+static void ggml_compute_forward_div_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -3494,7 +3489,7 @@ void ggml_compute_forward_div_f32(
     }
 }
 
-void ggml_compute_forward_div(
+static void ggml_compute_forward_div(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -3517,7 +3512,7 @@ void ggml_compute_forward_div(
 
 // ggml_compute_forward_sqr
 
-void ggml_compute_forward_sqr_f32(
+static void ggml_compute_forward_sqr_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3541,7 +3536,7 @@ void ggml_compute_forward_sqr_f32(
     }
 }
 
-void ggml_compute_forward_sqr(
+static void ggml_compute_forward_sqr(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3563,7 +3558,7 @@ void ggml_compute_forward_sqr(
 
 // ggml_compute_forward_sqrt
 
-void ggml_compute_forward_sqrt_f32(
+static void ggml_compute_forward_sqrt_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3587,7 +3582,7 @@ void ggml_compute_forward_sqrt_f32(
     }
 }
 
-void ggml_compute_forward_sqrt(
+static void ggml_compute_forward_sqrt(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3609,7 +3604,7 @@ void ggml_compute_forward_sqrt(
 
 // ggml_compute_forward_sum
 
-void ggml_compute_forward_sum_f32(
+static void ggml_compute_forward_sum_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3645,7 +3640,7 @@ void ggml_compute_forward_sum_f32(
     }
 }
 
-void ggml_compute_forward_sum(
+static void ggml_compute_forward_sum(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3667,7 +3662,7 @@ void ggml_compute_forward_sum(
 
 // ggml_compute_forward_mean
 
-void ggml_compute_forward_mean_f32(
+static void ggml_compute_forward_mean_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3722,7 +3717,7 @@ void ggml_compute_forward_mean_f32(
     }
 }
 
-void ggml_compute_forward_mean(
+static void ggml_compute_forward_mean(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3744,7 +3739,7 @@ void ggml_compute_forward_mean(
 
 // ggml_compute_forward_repeat
 
-void ggml_compute_forward_repeat_f32(
+static void ggml_compute_forward_repeat_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3784,7 +3779,7 @@ void ggml_compute_forward_repeat_f32(
     }
 }
 
-void ggml_compute_forward_repeat(
+static void ggml_compute_forward_repeat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3806,7 +3801,7 @@ void ggml_compute_forward_repeat(
 
 // ggml_compute_forward_abs
 
-void ggml_compute_forward_abs_f32(
+static void ggml_compute_forward_abs_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3830,7 +3825,7 @@ void ggml_compute_forward_abs_f32(
     }
 }
 
-void ggml_compute_forward_abs(
+static void ggml_compute_forward_abs(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3852,7 +3847,7 @@ void ggml_compute_forward_abs(
 
 // ggml_compute_forward_sgn
 
-void ggml_compute_forward_sgn_f32(
+static void ggml_compute_forward_sgn_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3876,7 +3871,7 @@ void ggml_compute_forward_sgn_f32(
     }
 }
 
-void ggml_compute_forward_sgn(
+static void ggml_compute_forward_sgn(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3898,7 +3893,7 @@ void ggml_compute_forward_sgn(
 
 // ggml_compute_forward_neg
 
-void ggml_compute_forward_neg_f32(
+static void ggml_compute_forward_neg_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3922,7 +3917,7 @@ void ggml_compute_forward_neg_f32(
     }
 }
 
-void ggml_compute_forward_neg(
+static void ggml_compute_forward_neg(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3944,7 +3939,7 @@ void ggml_compute_forward_neg(
 
 // ggml_compute_forward_step
 
-void ggml_compute_forward_step_f32(
+static void ggml_compute_forward_step_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3968,7 +3963,7 @@ void ggml_compute_forward_step_f32(
     }
 }
 
-void ggml_compute_forward_step(
+static void ggml_compute_forward_step(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -3990,7 +3985,7 @@ void ggml_compute_forward_step(
 
 // ggml_compute_forward_relu
 
-void ggml_compute_forward_relu_f32(
+static void ggml_compute_forward_relu_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -4014,7 +4009,7 @@ void ggml_compute_forward_relu_f32(
     }
 }
 
-void ggml_compute_forward_relu(
+static void ggml_compute_forward_relu(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -4036,7 +4031,7 @@ void ggml_compute_forward_relu(
 
 // ggml_compute_forward_gelu
 
-void ggml_compute_forward_gelu_f32(
+static void ggml_compute_forward_gelu_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -4077,7 +4072,7 @@ void ggml_compute_forward_gelu_f32(
     }
 }
 
-void ggml_compute_forward_gelu(
+static void ggml_compute_forward_gelu(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -4099,7 +4094,7 @@ void ggml_compute_forward_gelu(
 
 // ggml_compute_forward_norm
 
-void ggml_compute_forward_norm_f32(
+static void ggml_compute_forward_norm_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -4159,7 +4154,7 @@ void ggml_compute_forward_norm_f32(
     }
 }
 
-void ggml_compute_forward_norm(
+static void ggml_compute_forward_norm(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -4181,9 +4176,10 @@ void ggml_compute_forward_norm(
 
 // ggml_compute_forward_mul_mat
 
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
 // helper function to determine if it is better to use BLAS or not
 // for large matrices, BLAS is faster
-bool ggml_compute_forward_mul_mat_use_blas(
+static bool ggml_compute_forward_mul_mat_use_blas(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
               struct ggml_tensor * dst) {
@@ -4195,15 +4191,16 @@ bool ggml_compute_forward_mul_mat_use_blas(
     const int ne1 = dst->ne[1];
 
     // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src1) && ne0 >= 32 && ne1 >= 32 && ne10 >= 32) {
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ne0 >= 32 && ne1 >= 32 && ne10 >= 32) {
         //printf("BLAS: %d %d %d\n", ne0, ne1, ne10);
         return true;
     }
 
     return false;
 }
+#endif
 
-void ggml_compute_forward_mul_mat_f32(
+static void ggml_compute_forward_mul_mat_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -4272,7 +4269,6 @@ void ggml_compute_forward_mul_mat_f32(
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
-        GGML_ASSERT(ggml_is_contiguous(src0));
         GGML_ASSERT(nb10 == sizeof(float));
 
         if (params->ith != 0) return;
@@ -4447,7 +4443,7 @@ void ggml_compute_forward_mul_mat_f32(
     //}
 }
 
-void ggml_compute_forward_mul_mat_f16_f32(
+static void ggml_compute_forward_mul_mat_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -4536,7 +4532,7 @@ void ggml_compute_forward_mul_mat_f16_f32(
                     int id = 0;
                     for (int i01 = 0; i01 < ne01; ++i01) {
                         for (int i00 = 0; i00 < ne00; ++i00) {
-                            wdata[id++] = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
+                            wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
                         }
                     }
                 }
@@ -4564,13 +4560,22 @@ void ggml_compute_forward_mul_mat_f16_f32(
                 //    }
                 //}
 
-                // zT = y * xT
                 {
+#if 1
+                    // zT = y * xT
                     cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                             ne11, ne01, ne10,
-                            1.0f,    y, ne10,
-                                     x, ne10,
+                            1.0f,    y, ne00,
+                                     x, ne00,
                             0.0f,    d, ne01);
+#else
+                    // zT = (xT * y)T
+                    cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans,
+                            ne01, ne11, ne10,
+                            1.0f,    x, ne00,
+                                     y, ne00,
+                            0.0f,    d, ne01);
+#endif
                 }
             }
         }
@@ -4590,7 +4595,7 @@ void ggml_compute_forward_mul_mat_f16_f32(
                 for (int i12 = 0; i12 < ne12; ++i12) {
                     for (int i11 = 0; i11 < ne11; ++i11) {
                         for (int i10 = 0; i10 < ne10; ++i10) {
-                            wdata[id++] = ggml_fp32_to_fp16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
+                            wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
                         }
                     }
                 }
@@ -4624,12 +4629,12 @@ void ggml_compute_forward_mul_mat_f16_f32(
         const int ic1 = MIN(ic0 + dc, ne);
 
         for (int i = ic0; i < ic1; ++i) {
-            ((float *) dst->data)[i] = ggml_fp16_to_fp32(wdata[i]);
+            ((float *) dst->data)[i] = GGML_FP16_TO_FP32(wdata[i]);
         }
 
         for (int k = 1; k < nth; k++) {
             for (int i = ic0; i < ic1; ++i) {
-                ((float *) dst->data)[i] += ggml_fp16_to_fp32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]);
+                ((float *) dst->data)[i] += GGML_FP16_TO_FP32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]);
             }
         }
 
@@ -4742,7 +4747,7 @@ void ggml_compute_forward_mul_mat_f16_f32(
     //}
 }
 
-void ggml_compute_forward_mul_mat(
+static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -4768,7 +4773,7 @@ void ggml_compute_forward_mul_mat(
 
 // ggml_compute_forward_scale
 
-void ggml_compute_forward_scale_f32(
+static void ggml_compute_forward_scale_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -4803,7 +4808,7 @@ void ggml_compute_forward_scale_f32(
     }
 }
 
-void ggml_compute_forward_scale(
+static void ggml_compute_forward_scale(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -4826,7 +4831,7 @@ void ggml_compute_forward_scale(
 
 // ggml_compute_forward_cpy
 
-void ggml_compute_forward_cpy(
+static void ggml_compute_forward_cpy(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -4835,7 +4840,7 @@ void ggml_compute_forward_cpy(
 
 // ggml_compute_forward_reshape
 
-void ggml_compute_forward_reshape(
+static void ggml_compute_forward_reshape(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -4847,7 +4852,7 @@ void ggml_compute_forward_reshape(
 
 // ggml_compute_forward_view
 
-void ggml_compute_forward_view(
+static void ggml_compute_forward_view(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0) {
     // NOP
@@ -4857,7 +4862,7 @@ void ggml_compute_forward_view(
 
 // ggml_compute_forward_permute
 
-void ggml_compute_forward_permute(
+static void ggml_compute_forward_permute(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0) {
     // NOP
@@ -4867,7 +4872,7 @@ void ggml_compute_forward_permute(
 
 // ggml_compute_forward_transpose
 
-void ggml_compute_forward_transpose(
+static void ggml_compute_forward_transpose(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0) {
     // NOP
@@ -4877,7 +4882,7 @@ void ggml_compute_forward_transpose(
 
 // ggml_compute_forward_get_rows
 
-void ggml_compute_forward_get_rows_f16(
+static void ggml_compute_forward_get_rows_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -4900,12 +4905,12 @@ void ggml_compute_forward_get_rows_f16(
 
         for (int j = 0; j < nc; ++j) {
             ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
-            ((float *) ((char *)  dst->data + i*dst->nb[1]))[j] = ggml_fp16_to_fp32(v);
+            ((float *) ((char *)  dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
         }
     }
 }
 
-void ggml_compute_forward_get_rows_f32(
+static void ggml_compute_forward_get_rows_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -4932,7 +4937,7 @@ void ggml_compute_forward_get_rows_f32(
     }
 }
 
-void ggml_compute_forward_get_rows(
+static void ggml_compute_forward_get_rows(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -4958,7 +4963,7 @@ void ggml_compute_forward_get_rows(
 
 // ggml_compute_forward_diag_mask_inf
 
-void ggml_compute_forward_diag_mask_inf_f32(
+static void ggml_compute_forward_diag_mask_inf_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -4994,7 +4999,7 @@ void ggml_compute_forward_diag_mask_inf_f32(
     }
 }
 
-void ggml_compute_forward_diag_mask_inf(
+static void ggml_compute_forward_diag_mask_inf(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5017,7 +5022,7 @@ void ggml_compute_forward_diag_mask_inf(
 
 // ggml_compute_forward_soft_max
 
-void ggml_compute_forward_soft_max_f32(
+static void ggml_compute_forward_soft_max_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -5066,9 +5071,9 @@ void ggml_compute_forward_soft_max_f32(
                 p[i] = 0.0;
             } else {
                 //const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max);
-                ggml_fp16_t s = ggml_fp32_to_fp16(p[i] - max);
+                ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max);
                 memcpy(&ss, &s, sizeof(ss));
-                const float val = ggml_fp16_to_fp32(table_exp_f16[ss]);
+                const float val = GGML_FP16_TO_FP32(table_exp_f16[ss]);
                 sum += val;
                 p[i] = val;
             }
@@ -5088,7 +5093,7 @@ void ggml_compute_forward_soft_max_f32(
     }
 }
 
-void ggml_compute_forward_soft_max(
+static void ggml_compute_forward_soft_max(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -5110,7 +5115,7 @@ void ggml_compute_forward_soft_max(
 
 // ggml_compute_forward_rope
 
-void ggml_compute_forward_rope_f32(
+static void ggml_compute_forward_rope_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5167,7 +5172,7 @@ void ggml_compute_forward_rope_f32(
     }
 }
 
-void ggml_compute_forward_rope(
+static void ggml_compute_forward_rope(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5190,7 +5195,7 @@ void ggml_compute_forward_rope(
 
 // ggml_compute_forward_conv_1d_1s
 
-void ggml_compute_forward_conv_1d_1s_f16_f32(
+static void ggml_compute_forward_conv_1d_1s_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5272,7 +5277,7 @@ void ggml_compute_forward_conv_1d_1s_f16_f32(
                 const float * const src = (float *)((char *) src1->data + i11*nb11);
                 ggml_fp16_t * dst_data = wdata;
                 for (int i10 = 0; i10 < ne10; i10++) {
-                    dst_data[(i10 + nh)*ew0 + i11] = ggml_fp32_to_fp16(src[i10]);
+                    dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
                 }
             }
         }
@@ -5310,7 +5315,7 @@ void ggml_compute_forward_conv_1d_1s_f16_f32(
     }
 }
 
-void ggml_compute_forward_conv_1d_1s_f32(
+static void ggml_compute_forward_conv_1d_1s_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5430,7 +5435,7 @@ void ggml_compute_forward_conv_1d_1s_f32(
     }
 }
 
-void ggml_compute_forward_conv_1d_1s(
+static void ggml_compute_forward_conv_1d_1s(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5456,7 +5461,7 @@ void ggml_compute_forward_conv_1d_1s(
 
 // ggml_compute_forward_conv_1d_2s
 
-void ggml_compute_forward_conv_1d_2s_f16_f32(
+static void ggml_compute_forward_conv_1d_2s_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5538,7 +5543,7 @@ void ggml_compute_forward_conv_1d_2s_f16_f32(
                 const float * const src = (float *)((char *) src1->data + i11*nb11);
                 ggml_fp16_t * dst_data = wdata;
                 for (int i10 = 0; i10 < ne10; i10++) {
-                    dst_data[(i10 + nh)*ew0 + i11] = ggml_fp32_to_fp16(src[i10]);
+                    dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
                 }
             }
         }
@@ -5576,7 +5581,7 @@ void ggml_compute_forward_conv_1d_2s_f16_f32(
     }
 }
 
-void ggml_compute_forward_conv_1d_2s_f32(
+static void ggml_compute_forward_conv_1d_2s_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5696,7 +5701,7 @@ void ggml_compute_forward_conv_1d_2s_f32(
     }
 }
 
-void ggml_compute_forward_conv_1d_2s(
+static void ggml_compute_forward_conv_1d_2s(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5722,7 +5727,7 @@ void ggml_compute_forward_conv_1d_2s(
 
 // ggml_compute_forward_flash_attn
 
-void ggml_compute_forward_flash_attn_f32(
+static void ggml_compute_forward_flash_attn_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * q,
         const struct ggml_tensor * k,
@@ -5875,9 +5880,9 @@ void ggml_compute_forward_flash_attn_f32(
                     S[i] = 0.0;
                 } else {
                     //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
-                    ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max);
+                    ggml_fp16_t s = GGML_FP32_TO_FP16(S[i] - max);
                     memcpy(&ss, &s, sizeof(ss));
-                    const float val = ggml_fp16_to_fp32(table_exp_f16[ss]);
+                    const float val = GGML_FP16_TO_FP32(table_exp_f16[ss]);
                     sum += val;
                     S[i] = val;
                 }
@@ -5903,7 +5908,7 @@ void ggml_compute_forward_flash_attn_f32(
     }
 }
 
-void ggml_compute_forward_flash_attn_f16(
+static void ggml_compute_forward_flash_attn_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * q,
         const struct ggml_tensor * k,
@@ -6056,9 +6061,9 @@ void ggml_compute_forward_flash_attn_f16(
                     S[i] = 0.0;
                 } else {
                     //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
-                    ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max);
+                    ggml_fp16_t s = GGML_FP32_TO_FP16(S[i] - max);
                     memcpy(&ss, &s, sizeof(ss));
-                    const float val = ggml_fp16_to_fp32(table_exp_f16[ss]);
+                    const float val = GGML_FP16_TO_FP32(table_exp_f16[ss]);
                     sum += val;
                     S[i] = val;
                 }
@@ -6073,7 +6078,7 @@ void ggml_compute_forward_flash_attn_f16(
         ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
 
         for (int i = 0; i < M; i++) {
-            S16[i] = ggml_fp32_to_fp16(S[i]);
+            S16[i] = GGML_FP32_TO_FP16(S[i]);
         }
 
         for (int ic = 0; ic < nev1; ++ic) {
@@ -6090,7 +6095,7 @@ void ggml_compute_forward_flash_attn_f16(
     }
 }
 
-void ggml_compute_forward_flash_attn(
+static void ggml_compute_forward_flash_attn(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * q,
         const struct ggml_tensor * k,
@@ -6118,7 +6123,7 @@ void ggml_compute_forward_flash_attn(
 
 // ggml_compute_forward_flash_ff
 
-void ggml_compute_forward_flash_ff_f16(
+static void ggml_compute_forward_flash_ff_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * a,  // F16
         const struct ggml_tensor * b0, // F16 fc_w
@@ -6271,7 +6276,7 @@ void ggml_compute_forward_flash_ff_f16(
         ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
 
         for (int i = 0; i < M; i++) {
-            S16[i] = ggml_fp32_to_fp16(S[i]);
+            S16[i] = GGML_FP32_TO_FP16(S[i]);
         }
 
         ggml_vec_gelu_f16(neb01, S16, S16);
@@ -6298,7 +6303,7 @@ void ggml_compute_forward_flash_ff_f16(
     }
 }
 
-void ggml_compute_forward_flash_ff(
+static void ggml_compute_forward_flash_ff(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * a,
         const struct ggml_tensor * b0,
@@ -6327,7 +6332,7 @@ void ggml_compute_forward_flash_ff(
 
 /////////////////////////////////
 
-void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
     assert(params);
 
     switch (tensor->op) {
@@ -6470,12 +6475,12 @@ void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tenso
             {
                 GGML_ASSERT(false);
             } break;
-    };
+    }
 }
 
 ////////////////////////////////////////////////////////////////////////////////
 
-void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
+static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
     struct ggml_tensor * src0 = tensor->src0;
     struct ggml_tensor * src1 = tensor->src1;
 
@@ -6716,10 +6721,10 @@ void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tenso
             {
                 GGML_ASSERT(false);
             } break;
-    };
+    }
 }
 
-void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
+static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
     if (node->grad == NULL) {
         // this usually happens when we generate intermediate nodes from constants in the backward pass
         // it can also happen during forward pass, if the user performs computations with constants
@@ -6770,7 +6775,7 @@ void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node)
     }
 }
 
-void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
+static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
     if (!expand) {
         cgraph->n_nodes = 0;
         cgraph->n_leafs = 0;
@@ -6881,6 +6886,11 @@ typedef int ggml_lock_t;
 
 #define GGML_LOCK_INITIALIZER 0
 
+typedef pthread_t ggml_thread_t;
+
+#define ggml_thread_create pthread_create
+#define ggml_thread_join   pthread_join
+
 #else
 
 //typedef pthread_spinlock_t ggml_lock_t;
@@ -6899,6 +6909,11 @@ typedef int ggml_lock_t;
 
 #define GGML_LOCK_INITIALIZER 0
 
+typedef pthread_t ggml_thread_t;
+
+#define ggml_thread_create pthread_create
+#define ggml_thread_join   pthread_join
+
 #endif
 
 struct ggml_compute_state_shared {
@@ -6913,7 +6928,7 @@ struct ggml_compute_state_shared {
 };
 
 struct ggml_compute_state {
-    pthread_t thrd;
+    ggml_thread_t thrd;
 
     struct ggml_compute_params params;
     struct ggml_tensor * node;
@@ -6921,16 +6936,7 @@ struct ggml_compute_state {
     struct ggml_compute_state_shared * shared;
 };
 
-// function used by each compute thread
-void * ggml_graph_compute_one(void * data) {
-    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
-
-    ggml_compute_forward(&state->params, state->node);
-
-    return NULL;
-}
-
-thread_ret_t ggml_graph_compute_thread(void * data) {
+static thread_ret_t ggml_graph_compute_thread(void * data) {
     struct ggml_compute_state * state = (struct ggml_compute_state *) data;
 
     const int n_threads = state->shared->n_threads;
@@ -7010,7 +7016,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 .node   = NULL,
                 .shared = &state_shared,
             };
-            int rc = pthread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
+            int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
             assert(rc == 0);
             UNUSED(rc);
         }
@@ -7185,7 +7191,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     {
                         assert(false);
                     } break;
-            };
+            }
         }
 
         if (cgraph->work != NULL && work_size > cgraph->work_size) {
@@ -7354,7 +7360,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
         atomic_store(&state_shared.has_work, true);
 
         for (int j = 0; j < n_threads - 1; j++) {
-            int rc = pthread_join(workers[j].thrd, NULL);
+            int rc = ggml_thread_join(workers[j].thrd, NULL);
             assert(rc == 0);
             UNUSED(rc);
         }
@@ -7432,7 +7438,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
 }
 
 // check if node is part of the graph
-bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
     if (cgraph == NULL) {
         return true;
     }
@@ -7446,7 +7452,7 @@ bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor
     return false;
 }
 
-struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * parent = cgraph->nodes[i];
 
@@ -7575,7 +7581,7 @@ label=\"<x>CONST %d [%d, %d]\"; ]\n",
 
 ////////////////////////////////////////////////////////////////////////////////
 
-void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) {
+static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) {
     int i = 0;
     for (int p = 0; p < np; ++p) {
         const int ne = ggml_nelements(ps[p]) ;
@@ -7586,7 +7592,7 @@ void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float *
     }
 }
 
-void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) {
+static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) {
     int i = 0;
     for (int p = 0; p < np; ++p) {
         const int ne = ggml_nelements(ps[p]) ;
@@ -7597,7 +7603,7 @@ void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) {
     }
 }
 
-void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
+static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
     int i = 0;
     for (int p = 0; p < np; ++p) {
         const int ne = ggml_nelements(ps[p]) ;
@@ -7614,7 +7620,7 @@ void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
 //   ref: https://arxiv.org/pdf/1412.6980.pdf
 //
 
-enum ggml_opt_result ggml_opt_adam(
+static enum ggml_opt_result ggml_opt_adam(
         struct ggml_context * ctx,
         struct ggml_opt_params params,
         struct ggml_tensor * f,
@@ -7907,7 +7913,7 @@ static enum ggml_opt_result linesearch_backtracking(
     return GGML_LINESEARCH_FAIL;
 }
 
-enum ggml_opt_result ggml_opt_lbfgs(
+static enum ggml_opt_result ggml_opt_lbfgs(
         struct ggml_context * ctx,
         struct ggml_opt_params params,
         struct ggml_tensor * f,
@@ -8270,6 +8276,14 @@ int ggml_cpu_has_avx512(void) {
 #endif
 }
 
+int ggml_cpu_has_fma(void) {
+#if defined(__FMA__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_neon(void) {
 #if defined(__ARM_NEON)
     return 1;
@@ -8278,6 +8292,22 @@ int ggml_cpu_has_neon(void) {
 #endif
 }
 
+int ggml_cpu_has_arm_fma(void) {
+#if defined(__ARM_FEATURE_FMA)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_f16c(void) {
+#if defined(__F16C__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_fp16_va(void) {
 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
     return 1;