]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : use flash attention (#2152)
authorGeorgi Gerganov <redacted>
Wed, 15 May 2024 06:38:19 +0000 (09:38 +0300)
committerGitHub <redacted>
Wed, 15 May 2024 06:38:19 +0000 (09:38 +0300)
* whisper : use flash attention in the encoder

* whisper : add kv_pad

* whisper : remove extra backend instance (huh?)

* whisper : use FA for cross-attention

* whisper : use FA for self-attention

* whisper : simplify encoder FA

* whisper : add flash_attn runtime parameter

* scripts : add bench log

* scripts : add M1 Pro bench log

13 files changed:
examples/bench/bench.cpp
examples/command/command.cpp
examples/lsp/lsp.cpp
examples/main/main.cpp
examples/server/server.cpp
examples/stream/stream.cpp
examples/talk-llama/talk-llama.cpp
examples/talk/talk.cpp
examples/wchess/wchess.cmd/wchess.cmd.cpp
scripts/bench-all-gg.txt [new file with mode: 0644]
scripts/bench-all.sh
whisper.cpp
whisper.h

index b77621ac884de50b27476630d6efede5c9abc763..cac9385c82f272f5735b21226f5f9d2929c59c60 100644 (file)
@@ -12,7 +12,8 @@ struct whisper_params {
 
     std::string model = "models/ggml-base.en.bin";
 
-    bool use_gpu = true;
+    bool use_gpu    = true;
+    bool flash_attn = false;
 };
 
 void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -25,10 +26,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             whisper_print_usage(argc, argv, params);
             exit(0);
         }
-        else if (arg == "-t"  || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
-        else if (arg == "-m"  || arg == "--model")   { params.model     = argv[++i]; }
-        else if (arg == "-w"  || arg == "--what")    { params.what      = atoi(argv[++i]); }
-        else if (arg == "-ng" || arg == "--no-gpu")  { params.use_gpu   = false; }
+        else if (arg == "-t"  || arg == "--threads")    { params.n_threads  = std::stoi(argv[++i]); }
+        else if (arg == "-m"  || arg == "--model")      { params.model      = argv[++i]; }
+        else if (arg == "-w"  || arg == "--what")       { params.what       = atoi(argv[++i]); }
+        else if (arg == "-ng" || arg == "--no-gpu")     { params.use_gpu    = false; }
+        else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             whisper_print_usage(argc, argv, params);
@@ -49,6 +51,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -m FNAME, --model FNAME [%-7s] model path\n",                                  params.model.c_str());
     fprintf(stderr, "  -w N,     --what N      [%-7d] what to benchmark:\n",                          params.what);
     fprintf(stderr, "  -ng,      --no-gpu      [%-7s] disable GPU\n",                                 params.use_gpu ? "false" : "true");
+    fprintf(stderr, "  -fa,      --flash-attn  [%-7s] enable flash attention\n",                      params.flash_attn ? "true" : "false");
     fprintf(stderr, "                           %-7s  0 - whisper\n",                                 "");
     fprintf(stderr, "                           %-7s  1 - memcpy\n",                                  "");
     fprintf(stderr, "                           %-7s  2 - ggml_mul_mat\n",                            "");
@@ -59,7 +62,9 @@ int whisper_bench_full(const whisper_params & params) {
     // whisper init
 
     struct whisper_context_params cparams = whisper_context_default_params();
-    cparams.use_gpu = params.use_gpu;
+
+    cparams.use_gpu    = params.use_gpu;
+    cparams.flash_attn = params.flash_attn;
 
     struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
 
index ec749d602472e144e7dd6596e4dc6f7220b3c80c..cd6cc02399456fcdd861d7bb15805bee877dce1b 100644 (file)
@@ -44,6 +44,7 @@ struct whisper_params {
     bool print_energy  = false;
     bool no_timestamps = true;
     bool use_gpu       = true;
+    bool flash_attn    = false;
 
     std::string language  = "en";
     std::string model     = "models/ggml-base.en.bin";
@@ -80,6 +81,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-ps"  || arg == "--print-special") { params.print_special = true; }
         else if (arg == "-pe"  || arg == "--print-energy")  { params.print_energy  = true; }
         else if (arg == "-ng"  || arg == "--no-gpu")        { params.use_gpu       = false; }
+        else if (arg == "-fa"  || arg == "--flash-attn")    { params.flash_attn    = true; }
         else if (arg == "-l"   || arg == "--language")      { params.language      = argv[++i]; }
         else if (arg == "-m"   || arg == "--model")         { params.model         = argv[++i]; }
         else if (arg == "-f"   || arg == "--file")          { params.fname_out     = argv[++i]; }
@@ -118,6 +120,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -ps,        --print-special  [%-7s] print special tokens\n",                        params.print_special ? "true" : "false");
     fprintf(stderr, "  -pe,        --print-energy   [%-7s] print sound energy (for debugging)\n",          params.print_energy ? "true" : "false");
     fprintf(stderr, "  -ng,        --no-gpu         [%-7s] disable GPU\n",                                 params.use_gpu ? "false" : "true");
+    fprintf(stderr, "  -fa,        --flash-attn     [%-7s] flash attention\n",                             params.flash_attn ? "true" : "false");
     fprintf(stderr, "  -l LANG,    --language LANG  [%-7s] spoken language\n",                             params.language.c_str());
     fprintf(stderr, "  -m FNAME,   --model FNAME    [%-7s] model path\n",                                  params.model.c_str());
     fprintf(stderr, "  -f FNAME,   --file FNAME     [%-7s] text output file name\n",                       params.fname_out.c_str());
@@ -696,7 +699,9 @@ int main(int argc, char ** argv) {
     // whisper init
 
     struct whisper_context_params cparams = whisper_context_default_params();
-    cparams.use_gpu = params.use_gpu;
+
+    cparams.use_gpu    = params.use_gpu;
+    cparams.flash_attn = params.flash_attn;
 
     struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
 
index e5f8360f83dae26baf16478d196392232b07bb63..3df54266a251d6d316d68f0e611f1f85ec0b2b46 100644 (file)
@@ -31,6 +31,7 @@ struct whisper_params {
     bool print_special = false;
     bool print_energy  = false;
     bool use_gpu       = true;
+    bool flash_attn    = false;
 
     std::string language  = "en";
     std::string model     = "models/ggml-base.en.bin";
@@ -74,6 +75,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-ps"  || arg == "--print-special") { params.print_special = true; }
         else if (arg == "-pe"  || arg == "--print-energy")  { params.print_energy  = true; }
         else if (arg == "-ng"  || arg == "--no-gpu")        { params.use_gpu       = false; }
+        else if (arg == "-fa"  || arg == "--flash-attn")    { params.flash_attn    = true; }
         else if (arg == "-l"   || arg == "--language")      { params.language      = argv[++i]; }
         else if (arg == "-m"   || arg == "--model")         { params.model         = argv[++i]; }
         else {
@@ -105,6 +107,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -ps,        --print-special  [%-7s] print special tokens\n",                        params.print_special ? "true" : "false");
     fprintf(stderr, "  -pe,        --print-energy   [%-7s] print sound energy (for debugging)\n",          params.print_energy ? "true" : "false");
     fprintf(stderr, "  -ng,        --no-gpu         [%-7s] disable GPU\n",                                 params.use_gpu ? "false" : "true");
+    fprintf(stderr, "  -fa,        --flash-attn     [%-7s] flash attention\n",                             params.flash_attn ? "true" : "false");
     fprintf(stderr, "  -l LANG,    --language LANG  [%-7s] spoken language\n",                             params.language.c_str());
     fprintf(stderr, "  -m FNAME,   --model FNAME    [%-7s] model path\n",                                  params.model.c_str());
     fprintf(stderr, "\n");
@@ -436,7 +439,10 @@ int main(int argc, char ** argv) {
 
     // whisper init
     struct whisper_context_params cparams = whisper_context_default_params();
-    cparams.use_gpu = params.use_gpu;
+
+    cparams.use_gpu    = params.use_gpu;
+    cparams.flash_attn = params.flash_attn;
+
     struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
     // init audio
 
index d11c1c3f81b0c76cd8000f2a817217f9db709613..45eb17fe7f327aada7bce233a60b695b04599bd2 100644 (file)
@@ -70,6 +70,7 @@ struct whisper_params {
     bool no_timestamps   = false;
     bool log_score       = false;
     bool use_gpu         = true;
+    bool flash_attn      = false;
 
     std::string language  = "en";
     std::string prompt;
@@ -168,7 +169,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-dtw"  || arg == "--dtw")             { params.dtw             = argv[++i]; }
         else if (arg == "-ls"   || arg == "--log-score")       { params.log_score       = true; }
         else if (arg == "-ng"   || arg == "--no-gpu")          { params.use_gpu         = false; }
-        else if (                  arg == "--suppress-regex")  { params.suppress_regex = argv[++i]; }
+        else if (arg == "-fa"   || arg == "--flash-attn")      { params.flash_attn      = true; }
+        else if (                  arg == "--suppress-regex")  { params.suppress_regex  = argv[++i]; }
         else if (                  arg == "--grammar")         { params.grammar         = argv[++i]; }
         else if (                  arg == "--grammar-rule")    { params.grammar_rule    = argv[++i]; }
         else if (                  arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
@@ -234,6 +236,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -dtw MODEL --dtw MODEL         [%-7s] compute token-level timestamps\n",                 params.dtw.c_str());
     fprintf(stderr, "  -ls,       --log-score         [%-7s] log best decoder scores of tokens\n",              params.log_score?"true":"false");
     fprintf(stderr, "  -ng,       --no-gpu            [%-7s] disable GPU\n",                                    params.use_gpu ? "false" : "true");
+    fprintf(stderr, "  -fa,       --flash-attn        [%-7s] flash attention\n",                                params.flash_attn ? "true" : "false");
     fprintf(stderr, "  --suppress-regex REGEX         [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
     fprintf(stderr, "  --grammar GRAMMAR              [%-7s] GBNF grammar to guide decoding\n",                 params.grammar.c_str());
     fprintf(stderr, "  --grammar-rule RULE            [%-7s] top-level GBNF grammar rule name\n",               params.grammar_rule.c_str());
@@ -977,7 +980,9 @@ int main(int argc, char ** argv) {
     // whisper init
 
     struct whisper_context_params cparams = whisper_context_default_params();
-    cparams.use_gpu = params.use_gpu;
+
+    cparams.use_gpu    = params.use_gpu;
+    cparams.flash_attn = params.flash_attn;
 
     if (!params.dtw.empty()) {
         cparams.dtw_token_timestamps = true;
index e3b96698228b2402ad7e365afdff874c5e1f2938..c78b3026e18d7b871730f09ad3dc6199270c4c4f 100644 (file)
@@ -75,6 +75,7 @@ struct whisper_params {
     bool print_progress  = false;
     bool no_timestamps   = false;
     bool use_gpu         = true;
+    bool flash_attn      = false;
 
     std::string language        = "en";
     std::string prompt          = "";
@@ -178,6 +179,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
         else if (arg == "-oved" || arg == "--ov-e-device")     { params.openvino_encode_device = argv[++i]; }
         else if (arg == "-dtw"  || arg == "--dtw")             { params.dtw             = argv[++i]; }
         else if (arg == "-ng"   || arg == "--no-gpu")          { params.use_gpu         = false; }
+        else if (arg == "-fa"   || arg == "--flash-attn")      { params.flash_attn      = true; }
         // server params
         else if (                  arg == "--port")            { sparams.port        = std::stoi(argv[++i]); }
         else if (                  arg == "--host")            { sparams.hostname    = argv[++i]; }
@@ -502,7 +504,10 @@ int main(int argc, char ** argv) {
     }
     // whisper init
     struct whisper_context_params cparams = whisper_context_default_params();
-    cparams.use_gpu = params.use_gpu;
+
+    cparams.use_gpu    = params.use_gpu;
+    cparams.flash_attn = params.flash_attn;
+
     if (!params.dtw.empty()) {
         cparams.dtw_token_timestamps = true;
         cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;
index b82e379dc612fc5b46d2de68f873da676d95064e..60c1b0894e44edbb09f32eb13927b61f4df22716 100644 (file)
@@ -36,6 +36,7 @@ struct whisper_params {
     bool tinydiarize   = false;
     bool save_audio    = false; // save audio to wav file
     bool use_gpu       = true;
+    bool flash_attn    = false;
 
     std::string language  = "en";
     std::string model     = "models/ggml-base.en.bin";
@@ -72,6 +73,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-tdrz" || arg == "--tinydiarize")   { params.tinydiarize   = true; }
         else if (arg == "-sa"   || arg == "--save-audio")    { params.save_audio    = true; }
         else if (arg == "-ng"   || arg == "--no-gpu")        { params.use_gpu       = false; }
+        else if (arg == "-fa"   || arg == "--flash-attn")    { params.flash_attn    = true; }
 
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
@@ -109,6 +111,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -tdrz,    --tinydiarize   [%-7s] enable tinydiarize (requires a tdrz model)\n",     params.tinydiarize ? "true" : "false");
     fprintf(stderr, "  -sa,      --save-audio    [%-7s] save the recorded audio to a file\n",              params.save_audio ? "true" : "false");
     fprintf(stderr, "  -ng,      --no-gpu        [%-7s] disable GPU inference\n",                          params.use_gpu ? "false" : "true");
+    fprintf(stderr, "  -fa,      --flash-attn    [%-7s] flash attention during inference\n",               params.flash_attn ? "true" : "false");
     fprintf(stderr, "\n");
 }
 
@@ -153,7 +156,9 @@ int main(int argc, char ** argv) {
     }
 
     struct whisper_context_params cparams = whisper_context_default_params();
-    cparams.use_gpu = params.use_gpu;
+
+    cparams.use_gpu    = params.use_gpu;
+    cparams.flash_attn = params.flash_attn;
 
     struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
 
index 838d6f56357674aa9cf518bf546531225caee994..4aab62b9a6f1cfa4872ea986eff4e4ab2fb58ac1 100644 (file)
@@ -66,6 +66,7 @@ struct whisper_params {
     bool no_timestamps  = true;
     bool verbose_prompt = false;
     bool use_gpu        = true;
+    bool flash_attn     = false;
 
     std::string person      = "Georgi";
     std::string bot_name    = "LLaMA";
@@ -105,6 +106,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-pe"  || arg == "--print-energy")   { params.print_energy   = true; }
         else if (arg == "-vp"  || arg == "--verbose-prompt") { params.verbose_prompt = true; }
         else if (arg == "-ng"  || arg == "--no-gpu")         { params.use_gpu        = false; }
+        else if (arg == "-fa"  || arg == "--flash-attn")     { params.flash_attn     = true; }
         else if (arg == "-p"   || arg == "--person")         { params.person         = argv[++i]; }
         else if (arg == "-bn"   || arg == "--bot-name")      { params.bot_name       = argv[++i]; }
         else if (arg == "--session")                         { params.path_session   = argv[++i]; }
@@ -123,7 +125,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             }
         }
         else if (arg == "-f"   || arg == "--file")          { params.fname_out     = argv[++i]; }
-        else if (arg == "-ng"  || arg == "--no-gpu")        { params.use_gpu       = false; }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             whisper_print_usage(argc, argv, params);
@@ -154,6 +155,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -pe,      --print-energy   [%-7s] print sound energy (for debugging)\n",          params.print_energy ? "true" : "false");
     fprintf(stderr, "  -vp,      --verbose-prompt [%-7s] print prompt at start\n",                       params.verbose_prompt ? "true" : "false");
     fprintf(stderr, "  -ng,      --no-gpu         [%-7s] disable GPU\n",                                 params.use_gpu ? "false" : "true");
+    fprintf(stderr, "  -fa,      --flash-attn     [%-7s] flash attention\n",                             params.flash_attn ? "true" : "false");
     fprintf(stderr, "  -p NAME,  --person NAME    [%-7s] person name (for prompt selection)\n",          params.person.c_str());
     fprintf(stderr, "  -bn NAME, --bot-name NAME  [%-7s] bot name (to display)\n",                       params.bot_name.c_str());
     fprintf(stderr, "  -w TEXT,  --wake-command T [%-7s] wake-up command to listen for\n",               params.wake_cmd.c_str());
@@ -285,7 +287,9 @@ int main(int argc, char ** argv) {
     // whisper init
 
     struct whisper_context_params cparams = whisper_context_default_params();
-    cparams.use_gpu = params.use_gpu;
+
+    cparams.use_gpu    = params.use_gpu;
+    cparams.flash_attn = params.flash_attn;
 
     struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
     if (!ctx_wsp) {
@@ -316,6 +320,7 @@ int main(int argc, char ** argv) {
     lcparams.n_ctx      = 2048;
     lcparams.seed       = 1;
     lcparams.n_threads  = params.n_threads;
+    lcparams.flash_attn = params.flash_attn;
 
     struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lcparams);
 
index c1c6f8ba0b2c324a31b44c5115a0ac3247ed7dec..3e34e5724ff33420ae24d7bedeaf184f6c806dc0 100644 (file)
@@ -32,6 +32,7 @@ struct whisper_params {
     bool print_energy  = false;
     bool no_timestamps = true;
     bool use_gpu       = true;
+    bool flash_attn    = false;
 
     std::string person    = "Santa";
     std::string language  = "en";
@@ -64,6 +65,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-ps"  || arg == "--print-special") { params.print_special = true; }
         else if (arg == "-pe"  || arg == "--print-energy")  { params.print_energy  = true; }
         else if (arg == "-ng"  || arg == "--no-gpu")        { params.use_gpu       = false; }
+        else if (arg == "-fa"  || arg == "--flash-attn")    { params.flash_attn    = true; }
         else if (arg == "-p"   || arg == "--person")        { params.person        = argv[++i]; }
         else if (arg == "-l"   || arg == "--language")      { params.language      = argv[++i]; }
         else if (arg == "-mw"  || arg == "--model-whisper") { params.model_wsp     = argv[++i]; }
@@ -99,6 +101,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -ps,      --print-special [%-7s] print special tokens\n",                        params.print_special ? "true" : "false");
     fprintf(stderr, "  -pe,      --print-energy  [%-7s] print sound energy (for debugging)\n",          params.print_energy ? "true" : "false");
     fprintf(stderr, "  -ng,      --no-gpu        [%-7s] disable GPU\n",                                 params.use_gpu ? "false" : "true");
+    fprintf(stderr, "  -fa,      --flash-attn    [%-7s] flash attention\n",                             params.flash_attn ? "true" : "false");
     fprintf(stderr, "  -p NAME,  --person NAME   [%-7s] person name (for prompt selection)\n",          params.person.c_str());
     fprintf(stderr, "  -l LANG,  --language LANG [%-7s] spoken language\n",                             params.language.c_str());
     fprintf(stderr, "  -mw FILE, --model-whisper [%-7s] whisper model file\n",                          params.model_wsp.c_str());
@@ -188,7 +191,9 @@ int main(int argc, char ** argv) {
 
     // whisper init
     struct whisper_context_params cparams = whisper_context_default_params();
-    cparams.use_gpu = params.use_gpu;
+
+    cparams.use_gpu    = params.use_gpu;
+    cparams.flash_attn = params.flash_attn;
 
     struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
 
index f66b1765f5b3011eabad13dcb5fbfc57b7feb2e6..09e53f13172b8c5a4c7213bff5c57f89c499f336 100644 (file)
@@ -32,6 +32,7 @@ struct whisper_params {
     bool print_energy  = false;
     bool no_timestamps = true;
     bool use_gpu       = true;
+    bool flash_attn    = false;
 
     std::string language  = "en";
     std::string model     = "models/ggml-base.en.bin";
@@ -61,6 +62,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -ps,        --print-special  [%-7s] print special tokens\n",                        params.print_special ? "true" : "false");
     fprintf(stderr, "  -pe,        --print-energy   [%-7s] print sound energy (for debugging)\n",          params.print_energy ? "true" : "false");
     fprintf(stderr, "  -ng,        --no-gpu         [%-7s] disable GPU\n",                                 params.use_gpu ? "false" : "true");
+    fprintf(stderr, "  -fa,        --flash-attn     [%-7s] flash attention during decoding\n",             params.flash_attn ? "true" : "false");
     fprintf(stderr, "  -l LANG,    --language LANG  [%-7s] spoken language\n",                             params.language.c_str());
     fprintf(stderr, "  -m FNAME,   --model FNAME    [%-7s] model path\n",                                  params.model.c_str());
     fprintf(stderr, "  -f FNAME,   --file FNAME     [%-7s] text output file name\n",                       params.fname_out.c_str());
@@ -92,6 +94,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-ps"  || arg == "--print-special") { params.print_special = true; }
         else if (arg == "-pe"  || arg == "--print-energy")  { params.print_energy  = true; }
         else if (arg == "-ng"  || arg == "--no-gpu")        { params.use_gpu       = false; }
+        else if (arg == "-fa"  || arg == "--flash-attn")    { params.flash_attn    = true; }
         else if (arg == "-l"   || arg == "--language")      { params.language      = argv[++i]; }
         else if (arg == "-m"   || arg == "--model")         { params.model         = argv[++i]; }
         else if (arg == "-f"   || arg == "--file")          { params.fname_out     = argv[++i]; }
@@ -183,7 +186,9 @@ int main(int argc, char ** argv) {
     // whisper init
 
     struct whisper_context_params cparams = whisper_context_default_params();
-    cparams.use_gpu = params.use_gpu;
+
+    cparams.use_gpu    = params.use_gpu;
+    cparams.flash_attn = params.flash_attn;
 
     struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
     if (!ctx) {
diff --git a/scripts/bench-all-gg.txt b/scripts/bench-all-gg.txt
new file mode 100644 (file)
index 0000000..6fd5605
--- /dev/null
@@ -0,0 +1,298 @@
+## M1 Pro
+
+make -j && ./scripts/bench-all.sh 8
+
+Running memcpy benchmark
+
+memcpy:   39.10 GB/s (heat-up)
+memcpy:   44.75 GB/s ( 1 thread)
+memcpy:   44.78 GB/s ( 1 thread)
+memcpy:   44.97 GB/s ( 2 thread)
+memcpy:   48.04 GB/s ( 3 thread)
+memcpy:   50.55 GB/s ( 4 thread)
+memcpy:   55.20 GB/s ( 5 thread)
+memcpy:   65.60 GB/s ( 6 thread)
+memcpy:   70.64 GB/s ( 7 thread)
+memcpy:   73.34 GB/s ( 8 thread)
+sum:    -5120002535.000000
+
+
+make -j && ./scripts/bench-all.sh 1 0 0
+
+Running ggml_mul_mat benchmark with 1 threads
+
+  64 x   64: Q4_0   237.1 GFLOPS (128 runs) | Q4_1   168.6 GFLOPS (128 runs)
+  64 x   64: Q5_0   136.4 GFLOPS (128 runs) | Q5_1   135.6 GFLOPS (128 runs) | Q8_0   243.1 GFLOPS (128 runs)
+  64 x   64: F16    140.4 GFLOPS (128 runs) | F32    316.6 GFLOPS (128 runs)
+ 128 x  128: Q4_0   496.6 GFLOPS (128 runs) | Q4_1   348.6 GFLOPS (128 runs)
+ 128 x  128: Q5_0   273.2 GFLOPS (128 runs) | Q5_1   274.1 GFLOPS (128 runs) | Q8_0   505.1 GFLOPS (128 runs)
+ 128 x  128: F16    300.4 GFLOPS (128 runs) | F32    653.9 GFLOPS (128 runs)
+ 256 x  256: Q4_0   791.7 GFLOPS (128 runs) | Q4_1   615.3 GFLOPS (128 runs)
+ 256 x  256: Q5_0   651.0 GFLOPS (128 runs) | Q5_1   674.7 GFLOPS (128 runs) | Q8_0   803.1 GFLOPS (128 runs)
+ 256 x  256: F16    869.6 GFLOPS (128 runs) | F32    957.2 GFLOPS (128 runs)
+ 512 x  512: Q4_0   973.3 GFLOPS (128 runs) | Q4_1   897.9 GFLOPS (128 runs)
+ 512 x  512: Q5_0  1078.8 GFLOPS (128 runs) | Q5_1   998.4 GFLOPS (128 runs) | Q8_0   752.4 GFLOPS (128 runs)
+ 512 x  512: F16    892.5 GFLOPS (128 runs) | F32   1399.6 GFLOPS (128 runs)
+1024 x 1024: Q4_0  1402.7 GFLOPS (128 runs) | Q4_1  1218.5 GFLOPS (128 runs)
+1024 x 1024: Q5_0  1444.8 GFLOPS (128 runs) | Q5_1  1444.7 GFLOPS (128 runs) | Q8_0  1395.7 GFLOPS (128 runs)
+1024 x 1024: F16   1524.1 GFLOPS (128 runs) | F32   1726.6 GFLOPS (128 runs)
+2048 x 2048: Q4_0  1479.4 GFLOPS ( 87 runs) | Q4_1  1378.5 GFLOPS ( 81 runs)
+2048 x 2048: Q5_0  1454.6 GFLOPS ( 85 runs) | Q5_1  1462.9 GFLOPS ( 86 runs) | Q8_0  1483.2 GFLOPS ( 87 runs)
+2048 x 2048: F16   1488.0 GFLOPS ( 87 runs) | F32   1538.2 GFLOPS ( 90 runs)
+4096 x 4096: Q4_0  1509.7 GFLOPS ( 11 runs) | Q4_1  1433.0 GFLOPS ( 11 runs)
+4096 x 4096: Q5_0  1422.4 GFLOPS ( 11 runs) | Q5_1  1437.0 GFLOPS ( 11 runs) | Q8_0  1523.0 GFLOPS ( 12 runs)
+4096 x 4096: F16   1551.3 GFLOPS ( 12 runs) | F32   1451.0 GFLOPS ( 11 runs)
+
+|    CPU | Config |         Model |  Th |  FA |    Enc. |    Dec. |    Bch5 |      PP |  Commit |
+|    --- |    --- |           --- | --- | --- |     --- |     --- |     --- |     --- |     --- |
+| M1 Pro |  METAL |          tiny |   1 |   0 |   39.21 |    1.74 |    0.61 |    0.04 | 22c96b4 |
+| M1 Pro |  METAL |          base |   1 |   0 |   70.76 |    2.60 |    0.93 |    0.06 | 22c96b4 |
+| M1 Pro |  METAL |         small |   1 |   0 |  217.28 |    6.42 |    2.14 |    0.17 | 22c96b4 |
+| M1 Pro |  METAL |        medium |   1 |   0 |  596.74 |   14.43 |    4.75 |    0.45 | 22c96b4 |
+
+
+make -j && ./scripts/bench-all.sh 1 1 1
+
+|    CPU | Config |         Model |  Th |  FA |    Enc. |    Dec. |    Bch5 |      PP |  Commit |
+|    --- |    --- |           --- | --- | --- |     --- |     --- |     --- |     --- |     --- |
+| M1 Pro |  METAL |          tiny |   1 |   1 |   30.77 |    1.59 |    0.54 |    0.03 | 22c96b4 |
+| M1 Pro |  METAL |          base |   1 |   1 |   60.42 |    2.29 |    0.81 |    0.05 | 22c96b4 |
+| M1 Pro |  METAL |         small |   1 |   1 |  183.82 |    5.12 |    1.81 |    0.14 | 22c96b4 |
+| M1 Pro |  METAL |        medium |   1 |   1 |  517.92 |   11.60 |    4.01 |    0.38 | 22c96b4 |
+
+
+## M2 Ultra
+
+make -j && ./scripts/bench-all.sh 8
+
+Running memcpy benchmark
+
+memcpy:   46.58 GB/s (heat-up)
+memcpy:   54.16 GB/s ( 1 thread)
+memcpy:   54.23 GB/s ( 1 thread)
+memcpy:   99.63 GB/s ( 2 thread)
+memcpy:  140.59 GB/s ( 3 thread)
+memcpy:  176.52 GB/s ( 4 thread)
+memcpy:  158.90 GB/s ( 5 thread)
+memcpy:  163.00 GB/s ( 6 thread)
+memcpy:  189.69 GB/s ( 7 thread)
+memcpy:  197.15 GB/s ( 8 thread)
+sum:    -5120002007.000000
+
+
+make -j && ./scripts/bench-all.sh 1
+
+Running ggml_mul_mat benchmark with 1 threads
+
+  64 x   64: Q4_0   245.8 GFLOPS (128 runs) | Q4_1   168.6 GFLOPS (128 runs)
+  64 x   64: Q5_0   115.7 GFLOPS (128 runs) | Q5_1   125.9 GFLOPS (128 runs) | Q8_0   215.8 GFLOPS (128 runs)
+  64 x   64: F16    139.5 GFLOPS (128 runs) | F32    337.2 GFLOPS (128 runs)
+ 128 x  128: Q4_0   494.8 GFLOPS (128 runs) | Q4_1   350.4 GFLOPS (128 runs)
+ 128 x  128: Q5_0   257.1 GFLOPS (128 runs) | Q5_1   261.4 GFLOPS (128 runs) | Q8_0   509.4 GFLOPS (128 runs)
+ 128 x  128: F16    302.3 GFLOPS (128 runs) | F32    672.8 GFLOPS (128 runs)
+ 256 x  256: Q4_0   795.7 GFLOPS (128 runs) | Q4_1   663.7 GFLOPS (128 runs)
+ 256 x  256: Q5_0   737.8 GFLOPS (128 runs) | Q5_1   757.6 GFLOPS (128 runs) | Q8_0   827.7 GFLOPS (128 runs)
+ 256 x  256: F16    872.6 GFLOPS (128 runs) | F32    956.3 GFLOPS (128 runs)
+ 512 x  512: Q4_0  1188.0 GFLOPS (128 runs) | Q4_1  1085.0 GFLOPS (128 runs)
+ 512 x  512: Q5_0  1421.1 GFLOPS (128 runs) | Q5_1  1454.9 GFLOPS (128 runs) | Q8_0  1191.4 GFLOPS (128 runs)
+ 512 x  512: F16   1577.4 GFLOPS (128 runs) | F32   1982.0 GFLOPS (128 runs)
+1024 x 1024: Q4_0  2342.6 GFLOPS (128 runs) | Q4_1  1955.8 GFLOPS (128 runs)
+1024 x 1024: Q5_0  2306.7 GFLOPS (128 runs) | Q5_1  2217.0 GFLOPS (128 runs) | Q8_0  2230.7 GFLOPS (128 runs)
+1024 x 1024: F16   2593.8 GFLOPS (128 runs) | F32   3269.0 GFLOPS (128 runs)
+2048 x 2048: Q4_0  3735.7 GFLOPS (128 runs) | Q4_1  3205.3 GFLOPS (128 runs)
+2048 x 2048: Q5_0  3584.5 GFLOPS (128 runs) | Q5_1  3621.7 GFLOPS (128 runs) | Q8_0  3622.3 GFLOPS (128 runs)
+2048 x 2048: F16   3763.6 GFLOPS (128 runs) | F32   4153.3 GFLOPS (128 runs)
+4096 x 4096: Q4_0  3891.1 GFLOPS ( 29 runs) | Q4_1  3554.0 GFLOPS ( 26 runs)
+4096 x 4096: Q5_0  3753.1 GFLOPS ( 28 runs) | Q5_1  3750.1 GFLOPS ( 28 runs) | Q8_0  3768.5 GFLOPS ( 28 runs)
+4096 x 4096: F16   3864.2 GFLOPS ( 29 runs) | F32   3970.5 GFLOPS ( 29 runs)
+
+
+make -j && ./scripts/bench-all.sh 1 1 0
+
+|      CPU | Config |         Model |  Th |  FA |    Enc. |    Dec. |    Bch5 |      PP |  Commit |
+|      --- |    --- |           --- | --- | --- |     --- |     --- |     --- |     --- |     --- |
+| M2 ULTRA |  METAL |          tiny |   1 |   0 |   12.32 |    1.35 |    0.49 |    0.01 | 22c96b4 |
+| M2 ULTRA |  METAL |     tiny-q5_0 |   1 |   0 |   11.65 |    1.30 |    0.51 |    0.01 | 22c96b4 |
+| M2 ULTRA |  METAL |     tiny-q5_1 |   1 |   0 |   12.08 |    1.30 |    0.51 |    0.01 | 22c96b4 |
+| M2 ULTRA |  METAL |          base |   1 |   0 |   17.58 |    1.90 |    0.76 |    0.02 | 22c96b4 |
+| M2 ULTRA |  METAL |     base-q5_0 |   1 |   0 |   18.89 |    1.86 |    0.79 |    0.02 | 22c96b4 |
+| M2 ULTRA |  METAL |     base-q5_1 |   1 |   0 |   20.69 |    1.88 |    0.79 |    0.02 | 22c96b4 |
+| M2 ULTRA |  METAL |         small |   1 |   0 |   49.32 |    3.85 |    1.71 |    0.05 | 22c96b4 |
+| M2 ULTRA |  METAL |    small-q5_0 |   1 |   0 |   54.91 |    3.81 |    1.82 |    0.06 | 22c96b4 |
+| M2 ULTRA |  METAL |    small-q5_1 |   1 |   0 |   54.92 |    3.81 |    1.79 |    0.06 | 22c96b4 |
+| M2 ULTRA |  METAL |        medium |   1 |   0 |  134.34 |    8.04 |    3.82 |    0.13 | 22c96b4 |
+| M2 ULTRA |  METAL |   medium-q5_0 |   1 |   0 |  151.68 |    7.59 |    4.07 |    0.14 | 22c96b4 |
+| M2 ULTRA |  METAL |   medium-q5_1 |   1 |   0 |  151.58 |    7.67 |    4.07 |    0.14 | 22c96b4 |
+| M2 ULTRA |  METAL |    medium-dis |   1 |   0 |  120.82 |    1.07 |    0.41 |    0.02 | 22c96b4 |
+| M2 ULTRA |  METAL |      large-v2 |   1 |   0 |  235.63 |   12.27 |    5.85 |    0.22 | 22c96b4 |
+| M2 ULTRA |  METAL | large-v2-q5_0 |   1 |   0 |  273.38 |   11.17 |    6.40 |    0.26 | 22c96b4 |
+| M2 ULTRA |  METAL | large-v2-q5_1 |   1 |   0 |  272.44 |   11.32 |    6.29 |    0.26 | 22c96b4 |
+| M2 ULTRA |  METAL |  large-v2-dis |   1 |   0 |  212.51 |    1.20 |    0.47 |    0.02 | 22c96b4 |
+
+
+make -j && ./scripts/bench-all.sh 1 1 1
+
+|      CPU | Config |         Model |  Th |  FA |    Enc. |    Dec. |    Bch5 |      PP |  Commit |
+|      --- |    --- |           --- | --- | --- |     --- |     --- |     --- |     --- |     --- |
+| M2 ULTRA |  METAL |          tiny |   1 |   1 |    9.07 |    1.33 |    0.45 |    0.01 | 22c96b4 |
+| M2 ULTRA |  METAL |     tiny-q5_0 |   1 |   1 |    9.74 |    1.33 |    0.47 |    0.01 | 22c96b4 |
+| M2 ULTRA |  METAL |     tiny-q5_1 |   1 |   1 |    8.93 |    1.31 |    0.46 |    0.01 | 22c96b4 |
+| M2 ULTRA |  METAL |          base |   1 |   1 |   15.75 |    1.87 |    0.71 |    0.02 | 22c96b4 |
+| M2 ULTRA |  METAL |     base-q5_0 |   1 |   1 |   17.04 |    1.83 |    0.74 |    0.02 | 22c96b4 |
+| M2 ULTRA |  METAL |     base-q5_1 |   1 |   1 |   17.17 |    1.83 |    0.74 |    0.02 | 22c96b4 |
+| M2 ULTRA |  METAL |         small |   1 |   1 |   42.33 |    3.64 |    1.60 |    0.05 | 22c96b4 |
+| M2 ULTRA |  METAL |    small-q5_0 |   1 |   1 |   47.61 |    3.63 |    1.70 |    0.05 | 22c96b4 |
+| M2 ULTRA |  METAL |    small-q5_1 |   1 |   1 |   47.70 |    3.66 |    1.68 |    0.05 | 22c96b4 |
+| M2 ULTRA |  METAL |        medium |   1 |   1 |  114.42 |    7.53 |    3.55 |    0.11 | 22c96b4 |
+| M2 ULTRA |  METAL |   medium-q5_0 |   1 |   1 |  132.63 |    7.02 |    3.77 |    0.13 | 22c96b4 |
+| M2 ULTRA |  METAL |   medium-q5_1 |   1 |   1 |  132.28 |    7.10 |    3.76 |    0.13 | 22c96b4 |
+| M2 ULTRA |  METAL |    medium-dis |   1 |   1 |  102.34 |    1.01 |    0.42 |    0.01 | 22c96b4 |
+| M2 ULTRA |  METAL |      large-v2 |   1 |   1 |  203.01 |   11.03 |    5.45 |    0.20 | 22c96b4 |
+| M2 ULTRA |  METAL | large-v2-q5_0 |   1 |   1 |  240.05 |   10.18 |    5.98 |    0.23 | 22c96b4 |
+| M2 ULTRA |  METAL | large-v2-q5_1 |   1 |   1 |  239.22 |   10.23 |    5.87 |    0.23 | 22c96b4 |
+| M2 ULTRA |  METAL |  large-v2-dis |   1 |   1 |  181.14 |    1.14 |    0.48 |    0.02 | 22c96b4 |
+
+
+
+## Ryzen 9 5950X + RTX 2060
+
+make -j && ./scripts/bench-all.sh 8 0 0
+
+Running memcpy benchmark
+
+memcpy:   12.36 GB/s (heat-up)
+memcpy:   12.33 GB/s ( 1 thread)
+memcpy:   12.38 GB/s ( 1 thread)
+memcpy:   14.48 GB/s ( 2 thread)
+memcpy:   15.00 GB/s ( 3 thread)
+memcpy:   14.77 GB/s ( 4 thread)
+memcpy:   14.60 GB/s ( 5 thread)
+memcpy:   14.57 GB/s ( 6 thread)
+memcpy:   14.34 GB/s ( 7 thread)
+memcpy:   14.40 GB/s ( 8 thread)
+sum:    -5119998076.000000
+
+Running ggml_mul_mat benchmark with 8 threads
+
+  64 x   64: Q4_0     3.1 GFLOPS (128 runs) | Q4_1     3.1 GFLOPS (128 runs)
+  64 x   64: Q5_0     3.0 GFLOPS (128 runs) | Q5_1     2.9 GFLOPS (128 runs) | Q8_0     3.1 GFLOPS (128 runs)
+  64 x   64: F16      3.0 GFLOPS (128 runs) | F32      3.0 GFLOPS (128 runs)
+ 128 x  128: Q4_0    21.1 GFLOPS (128 runs) | Q4_1    20.3 GFLOPS (128 runs)
+ 128 x  128: Q5_0    20.6 GFLOPS (128 runs) | Q5_1    20.4 GFLOPS (128 runs) | Q8_0    22.1 GFLOPS (128 runs)
+ 128 x  128: F16     21.7 GFLOPS (128 runs) | F32     21.7 GFLOPS (128 runs)
+ 256 x  256: Q4_0   105.7 GFLOPS (128 runs) | Q4_1    94.4 GFLOPS (128 runs)
+ 256 x  256: Q5_0    94.8 GFLOPS (128 runs) | Q5_1    87.5 GFLOPS (128 runs) | Q8_0   107.2 GFLOPS (128 runs)
+ 256 x  256: F16     95.1 GFLOPS (128 runs) | F32     94.3 GFLOPS (128 runs)
+ 512 x  512: Q4_0   214.7 GFLOPS (128 runs) | Q4_1   189.8 GFLOPS (128 runs)
+ 512 x  512: Q5_0   187.7 GFLOPS (128 runs) | Q5_1   176.2 GFLOPS (128 runs) | Q8_0   252.2 GFLOPS (128 runs)
+ 512 x  512: F16    220.8 GFLOPS (128 runs) | F32    218.3 GFLOPS (128 runs)
+1024 x 1024: Q4_0   333.7 GFLOPS (128 runs) | Q4_1   305.8 GFLOPS (128 runs)
+1024 x 1024: Q5_0   283.2 GFLOPS (128 runs) | Q5_1   268.2 GFLOPS (125 runs) | Q8_0   394.1 GFLOPS (128 runs)
+1024 x 1024: F16    355.0 GFLOPS (128 runs) | F32    313.0 GFLOPS (128 runs)
+2048 x 2048: Q4_0   395.0 GFLOPS ( 23 runs) | Q4_1   380.6 GFLOPS ( 23 runs)
+2048 x 2048: Q5_0   336.6 GFLOPS ( 20 runs) | Q5_1   318.4 GFLOPS ( 19 runs) | Q8_0   482.6 GFLOPS ( 29 runs)
+2048 x 2048: F16    424.5 GFLOPS ( 25 runs) | F32    337.7 GFLOPS ( 20 runs)
+4096 x 4096: Q4_0   412.8 GFLOPS (  4 runs) | Q4_1   405.1 GFLOPS (  3 runs)
+4096 x 4096: Q5_0   346.0 GFLOPS (  3 runs) | Q5_1   334.6 GFLOPS (  3 runs) | Q8_0   502.6 GFLOPS (  4 runs)
+4096 x 4096: F16    412.5 GFLOPS (  4 runs) | F32    274.0 GFLOPS (  3 runs)
+
+|           CPU | Config |         Model |  Th |  FA |    Enc. |    Dec. |    Bch5 |      PP |  Commit |
+|           --- |    --- |           --- | --- | --- |     --- |     --- |     --- |     --- |     --- |
+| Ryzen 9 5950X |   AVX2 |          tiny |   8 |   0 |  195.29 |    1.57 |    0.51 |    0.26 | 22c96b4 |
+| Ryzen 9 5950X |   AVX2 |     tiny-q5_0 |   8 |   0 |  213.33 |    1.10 |    0.50 |    0.30 | 22c96b4 |
+| Ryzen 9 5950X |   AVX2 |     tiny-q5_1 |   8 |   0 |  219.38 |    1.18 |    0.53 |    0.32 | 22c96b4 |
+| Ryzen 9 5950X |   AVX2 |          base |   8 |   0 |  424.85 |    3.71 |    1.03 |    0.46 | 22c96b4 |
+| Ryzen 9 5950X |   AVX2 |     base-q5_0 |   8 |   0 |  473.61 |    1.81 |    0.82 |    0.52 | 22c96b4 |
+| Ryzen 9 5950X |   AVX2 |     base-q5_1 |   8 |   0 |  484.14 |    1.92 |    0.85 |    0.56 | 22c96b4 |
+| Ryzen 9 5950X |   AVX2 |         small |   8 |   0 | 1458.32 |   12.66 |    3.09 |    1.26 | 22c96b4 |
+| Ryzen 9 5950X |   AVX2 |    small-q5_0 |   8 |   0 | 1673.22 |    6.42 |    2.18 |    1.45 | 22c96b4 |
+| Ryzen 9 5950X |   AVX2 |    small-q5_1 |   8 |   0 | 1724.78 |    6.72 |    2.32 |    1.52 | 22c96b4 |
+| Ryzen 9 5950X |   AVX2 |        medium |   8 |   0 | 4333.87 |   36.80 |    8.56 |    3.37 | 22c96b4 |
+| Ryzen 9 5950X |   AVX2 |   medium-q5_0 |   8 |   0 | 5194.09 |   19.21 |    5.71 |    3.97 | 22c96b4 |
+| Ryzen 9 5950X |   AVX2 |   medium-q5_1 |   8 |   0 | 5450.39 |   20.01 |    5.99 |    4.17 | 22c96b4 |
+| Ryzen 9 5950X |   AVX2 |    medium-dis |   8 |   0 | 3995.19 |    5.08 |    1.21 |    0.55 | 22c96b4 |
+| Ryzen 9 5950X |   AVX2 |      large-v2 |   8 |   0 | 8056.16 |   69.74 |   16.11 |    6.13 | 22c96b4 |
+| Ryzen 9 5950X |   AVX2 | large-v2-q5_0 |   8 |   0 | 9799.58 |   35.16 |   10.49 |    7.28 | 22c96b4 |
+| Ryzen 9 5950X |   AVX2 | large-v2-q5_1 |   8 |   0 |      ms |   36.74 |   11.02 |    7.65 | 22c96b4 |
+| Ryzen 9 5950X |   AVX2 |  large-v2-dis |   8 |   0 | 7490.03 |    7.40 |    1.70 |    0.72 | 22c96b4 |
+
+
+WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 0
+
+|      GPU |    Config |         Model |  Th |  FA |    Enc. |    Dec. |    Bch5 |      PP |  Commit |
+|      --- |       --- |           --- | --- | --- |     --- |     --- |     --- |     --- |     --- |
+| RTX 2060 | AVX2 CUDA |          tiny |   8 |   0 |   12.54 |    0.93 |    0.29 |    0.02 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |     tiny-q5_0 |   8 |   0 |   12.73 |    0.98 |    0.24 |    0.02 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |     tiny-q5_1 |   8 |   0 |   12.72 |    0.99 |    0.24 |    0.02 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |          base |   8 |   0 |   24.14 |    1.28 |    0.41 |    0.03 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |     base-q5_0 |   8 |   0 |   24.58 |    1.38 |    0.35 |    0.03 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |     base-q5_1 |   8 |   0 |   24.58 |    1.37 |    0.35 |    0.03 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |         small |   8 |   0 |   74.70 |    2.91 |    0.84 |    0.07 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |    small-q5_0 |   8 |   0 |   76.12 |    2.84 |    0.77 |    0.08 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |    small-q5_1 |   8 |   0 |   76.14 |    2.84 |    0.76 |    0.08 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |        medium |   8 |   0 |  200.69 |    6.46 |    1.83 |    0.17 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |   medium-q5_0 |   8 |   0 |  204.80 |    5.90 |    1.65 |    0.19 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |   medium-q5_1 |   8 |   0 |  205.61 |    5.85 |    1.61 |    0.19 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |    medium-dis |   8 |   0 |  186.17 |    0.86 |    0.24 |    0.02 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |      large-v2 |   8 |   0 |  347.22 |   10.36 |    2.82 |    0.29 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA | large-v2-q5_0 |   8 |   0 |  357.06 |    8.81 |    2.58 |    0.34 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA | large-v2-q5_1 |   8 |   0 |  356.97 |    8.62 |    2.49 |    0.33 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |  large-v2-dis |   8 |   0 |  318.05 |    1.03 |    0.34 |    0.04 | 22c96b4 |
+
+
+WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 1
+
+|      GPU |    Config |         Model |  Th |  FA |    Enc. |    Dec. |    Bch5 |      PP |  Commit |
+|      --- |       --- |           --- | --- | --- |     --- |     --- |     --- |     --- |     --- |
+| RTX 2060 | AVX2 CUDA |          tiny |   8 |   1 |    7.21 |    0.76 |    0.29 |    0.02 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |     tiny-q5_0 |   8 |   1 |    7.42 |    0.82 |    0.18 |    0.02 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |     tiny-q5_1 |   8 |   1 |    7.38 |    0.82 |    0.18 |    0.02 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |          base |   8 |   1 |   13.49 |    1.04 |    0.36 |    0.02 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |     base-q5_0 |   8 |   1 |   13.94 |    1.13 |    0.26 |    0.03 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |     base-q5_1 |   8 |   1 |   13.94 |    1.14 |    0.26 |    0.03 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |         small |   8 |   1 |   42.81 |    2.33 |    0.69 |    0.05 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |    small-q5_0 |   8 |   1 |   44.43 |    2.25 |    0.59 |    0.06 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |    small-q5_1 |   8 |   1 |   44.11 |    2.24 |    0.58 |    0.06 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |        medium |   8 |   1 |  115.47 |    5.17 |    1.45 |    0.11 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |   medium-q5_0 |   8 |   1 |  120.37 |    4.63 |    1.25 |    0.13 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |   medium-q5_1 |   8 |   1 |  120.28 |    4.55 |    1.21 |    0.13 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |    medium-dis |   8 |   1 |  101.69 |    0.75 |    0.20 |    0.02 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |      large-v2 |   8 |   1 |  205.67 |    8.49 |    2.19 |    0.18 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA | large-v2-q5_0 |   8 |   1 |  214.07 |    6.88 |    1.94 |    0.22 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA | large-v2-q5_1 |   8 |   1 |  213.98 |    6.70 |    1.86 |    0.22 | 22c96b4 |
+| RTX 2060 | AVX2 CUDA |  large-v2-dis |   8 |   1 |  176.71 |    0.91 |    0.31 |    0.03 | 22c96b4 |
+
+
+
+
+# V100
+
+WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 0
+
+|  GPU |    Config |         Model |  Th |  FA |    Enc. |    Dec. |    Bch5 |      PP |  Commit |
+|  --- |       --- |           --- | --- | --- |     --- |     --- |     --- |     --- |     --- |
+| V100 | AVX2 CUDA |          tiny |   1 |   0 |    6.21 |    1.11 |    0.30 |    0.02 | 22c96b4 |
+| V100 | AVX2 CUDA |     tiny-q5_1 |   1 |   0 |    5.97 |    1.10 |    0.26 |    0.02 | 22c96b4 |
+| V100 | AVX2 CUDA |          base |   1 |   0 |   10.95 |    1.47 |    0.42 |    0.03 | 22c96b4 |
+| V100 | AVX2 CUDA |     base-q5_1 |   1 |   0 |   11.13 |    1.53 |    0.36 |    0.03 | 22c96b4 |
+| V100 | AVX2 CUDA |         small |   1 |   0 |   31.57 |    2.96 |    0.84 |    0.05 | 22c96b4 |
+| V100 | AVX2 CUDA |    small-q5_1 |   1 |   0 |   32.19 |    3.14 |    0.75 |    0.05 | 22c96b4 |
+| V100 | AVX2 CUDA |        medium |   1 |   0 |   85.88 |    6.49 |    1.80 |    0.10 | 22c96b4 |
+| V100 | AVX2 CUDA |   medium-q5_0 |   1 |   0 |   87.53 |    5.82 |    1.37 |    0.10 | 22c96b4 |
+| V100 | AVX2 CUDA |      large-v2 |   1 |   0 |  142.23 |    8.92 |    2.62 |    0.15 | 22c96b4 |
+
+
+WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 1
+
+|  GPU |    Config |         Model |  Th |  FA |    Enc. |    Dec. |    Bch5 |      PP |  Commit |
+|  --- |       --- |           --- | --- | --- |     --- |     --- |     --- |     --- |     --- |
+| V100 | AVX2 CUDA |          tiny |   1 |   1 |    3.96 |    0.82 |    0.24 |    0.02 | 22c96b4 |
+| V100 | AVX2 CUDA |     tiny-q5_1 |   1 |   1 |    4.05 |    0.85 |    0.18 |    0.02 | 22c96b4 |
+| V100 | AVX2 CUDA |          base |   1 |   1 |    7.21 |    1.16 |    0.36 |    0.02 | 22c96b4 |
+| V100 | AVX2 CUDA |     base-q5_1 |   1 |   1 |    7.39 |    1.21 |    0.26 |    0.02 | 22c96b4 |
+| V100 | AVX2 CUDA |         small |   1 |   1 |   19.81 |    2.41 |    0.71 |    0.04 | 22c96b4 |
+| V100 | AVX2 CUDA |    small-q5_1 |   1 |   1 |   20.50 |    2.31 |    0.51 |    0.04 | 22c96b4 |
+| V100 | AVX2 CUDA |        medium |   1 |   1 |   56.02 |    4.89 |    1.44 |    0.07 | 22c96b4 |
+| V100 | AVX2 CUDA |   medium-q5_0 |   1 |   1 |   57.85 |    4.73 |    1.09 |    0.08 | 22c96b4 |
+| V100 | AVX2 CUDA |      large-v2 |   1 |   1 |   92.73 |    7.18 |    2.14 |    0.10 | 22c96b4 |
+
index 6939dafaca042f3032473f414c729875223e2a68..8a857c67b6cd4c875e2e7e7806320462c6e2f321 100755 (executable)
@@ -2,7 +2,7 @@
 
 # Helper script to run the bench tool on all models and print the results in share-able format
 
-printf "Usage: ./bench.sh [n_threads] [encoder-only]\n"
+printf "Usage: ./bench.sh [n_threads] [encoder-only] [flash-attn]\n"
 
 if [ -z "$1" ]; then
     n_threads=4
@@ -11,12 +11,19 @@ else
 fi
 
 encoder_only=0
-if [ -z "$2" ]; then
+if [ -z "$2" ] || [ "$2" -eq 0 ]; then
     encoder_only=0
 else
     encoder_only=$2
 fi
 
+fattn=""
+if [ -z "$3" ] || [ "$3" -eq 0 ]; then
+    fattn=""
+else
+    fattn="-fa"
+fi
+
 models=(                                                                                                    \
       "tiny"     "tiny-q4_0"     "tiny-q4_1"     "tiny-q5_0"     "tiny-q5_1"     "tiny-q8_0"                \
       "base"     "base-q4_0"     "base-q4_1"     "base-q5_0"     "base-q5_1"     "base-q8_0"                \
@@ -44,13 +51,19 @@ if [ "$encoder_only" -eq 0 ]; then
     printf "\n"
 fi
 
-printf "| %6s | %6s | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "Bch5" "PP" "Commit"
-printf "| %6s | %6s | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---"
+if [ "$fattn" == "-fa" ]; then
+    fattn_i=1
+else
+    fattn_i=0
+fi
+
+printf "| %6s | %6s | %16s | %13s | %3s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "FA" "Enc." "Dec." "Bch5" "PP" "Commit"
+printf "| %6s | %6s | %16s | %13s | %3s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---"
 
 for model in "${models[@]}"; do
     # actual run
     # store stderr output in a variable in order to parse it later
-    output=$(./bench -m ./models/ggml-$model.bin -t $n_threads 2>&1)
+    output=$(./bench -m ./models/ggml-$model.bin -t $n_threads $fattn 2>&1)
     ret=$?
 
     # parse the output:
@@ -95,6 +108,6 @@ for model in "${models[@]}"; do
     commit=$(git rev-parse --short HEAD)
 
     if [ $ret -eq 0 ]; then
-        printf "| <todo> | <todo> | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit"
+        printf "| <todo> | <todo> | %16s | %13s | %3s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$fattn_i" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit"
     fi
 done
index ff4223daf429ea416261f6fea3256fa2cdf203df..84aec8238cdb42d19cab3ef2e97a5aa0b91a2a94 100644 (file)
@@ -809,14 +809,15 @@ struct whisper_state {
     // shared between all decoders
     whisper_kv_cache kv_cross;
 
+    // padded buffer for flash-attention
+    whisper_kv_cache kv_pad;
+
     whisper_mel mel;
 
     whisper_batch batch;
 
     whisper_decoder decoders[WHISPER_MAX_DECODERS];
 
-    ggml_backend_t backend = nullptr;
-
     // ggml-alloc:
     // - stores meta info about the intermediate tensors into the `meta` buffers
     // - stores the actual tensor data into the `data` buffers
@@ -902,14 +903,12 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
 }
 
 static bool kv_cache_init(
-        const struct whisper_hparams & hparams,
              struct whisper_kv_cache & cache,
                       ggml_backend_t   backend,
                            ggml_type   wtype,
+                             int64_t   n_text_state,
+                             int64_t   n_text_layer,
                                  int   n_ctx) {
-    const int64_t n_text_state = hparams.n_text_state;
-    const int64_t n_text_layer = hparams.n_text_layer;
-
     const int64_t n_mem      = n_text_layer*n_ctx;
     const int64_t n_elements = n_text_state*n_mem;
 
@@ -941,6 +940,8 @@ static bool kv_cache_init(
         return false;
     }
 
+    ggml_backend_buffer_clear(cache.buffer, 0);
+
     return true;
 }
 
@@ -1068,6 +1069,26 @@ static void whisper_kv_cache_seq_cp(
     }
 }
 
+static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) {
+    if (!wctx.params.flash_attn) {
+        return 1u;
+    }
+
+#ifdef GGML_USE_METAL
+    if (ggml_backend_is_metal(wctx.backend)) {
+        return 32u;
+    }
+#endif
+
+#ifdef GGML_USE_CUDA
+    if (ggml_backend_is_cuda(wctx.backend)) {
+        return 256u;
+    }
+#endif
+
+    return 1u;
+}
+
 // [EXPERIMENTAL] Token-level timestamps with DTW
 static bool aheads_masks_init(
         const whisper_context_params & cparams,
@@ -1872,6 +1893,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     const int n_head  = hparams.n_audio_head;
     const int n_layer = hparams.n_audio_layer;
 
+    const int n_state_head = n_state/n_head;
+
+    auto & kv_pad = wstate.kv_pad;
+
+    WHISPER_ASSERT(!!kv_pad.ctx);
+
+    const int n_ctx_pad = GGML_PAD(n_ctx, 256);
+
     struct ggml_init_params params = {
         /*.mem_size   =*/ wstate.alloc_encode.meta.size(),
         /*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
@@ -1884,7 +1913,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
     struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
 
-    const float KQscale = 1.0f/sqrtf(float(n_state)/n_head);
+    const float KQscale = 1.0f/sqrtf(float(n_state_head));
 
     // ===================================================================
     // NOTE: experimenting with partial evaluation of the encoder (ignore)
@@ -1934,14 +1963,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
             Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
 
-            //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state)/n_head, -0.25));
+            //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25));
 
             // note: no bias for Key
             struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
                     layer.attn_k_w,
                     cur);
 
-            //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state)/n_head, -0.25));
+            //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25));
 
             struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
                     layer.attn_v_w,
@@ -1955,38 +1984,61 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                 ggml_permute(ctx0,
                         ggml_cpy(ctx0,
                             Qcur,
-                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
+                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
-            struct ggml_tensor * K =
-                ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Kcur,
-                            ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
-                        0, 2, 1, 3);
-
-            // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            if (wctx.params.flash_attn) {
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0)));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0)));
 
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
+                struct ggml_tensor * K =
+                    ggml_view_3d(ctx0, kv_pad.k,
+                            n_state_head, n_ctx_pad, n_head,
+                            ggml_element_size(kv_pad.k)*n_state,
+                            ggml_element_size(kv_pad.k)*n_state_head,
+                            0);
 
-            struct ggml_tensor * V =
-                ggml_cpy(ctx0,
-                        ggml_permute(ctx0,
-                            ggml_reshape_3d(ctx0,
-                                Vcur,
-                                n_state/n_head, n_head, n_ctx),
-                            1, 2, 0, 3),
-                        ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
-                        );
+                struct ggml_tensor * V =
+                    ggml_view_3d(ctx0, kv_pad.v,
+                            n_state_head, n_ctx_pad, n_head,
+                            ggml_element_size(kv_pad.v)*n_state,
+                            ggml_element_size(kv_pad.v)*n_state_head,
+                            0);
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+                cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f);
 
-            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-
-            cur = ggml_cpy(ctx0,
-                    KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
+                cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
+            } else {
+                struct ggml_tensor * K =
+                    ggml_permute(ctx0,
+                            ggml_cpy(ctx0,
+                                Kcur,
+                                ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)),
+                            0, 2, 1, 3);
+
+                // K * Q
+                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+                struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
+
+                struct ggml_tensor * V =
+                    ggml_cpy(ctx0,
+                            ggml_permute(ctx0,
+                                ggml_reshape_3d(ctx0,
+                                    Vcur,
+                                    n_state_head, n_head, n_ctx),
+                                1, 2, 0, 3),
+                            ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head)
+                            );
+
+                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+
+                struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+                cur = ggml_cpy(ctx0,
+                        KQV_merged,
+                        ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
+            }
         }
 
         // projection
@@ -2085,6 +2137,10 @@ static struct ggml_cgraph * whisper_build_graph_cross(
     const int n_state = hparams.n_audio_state;
     const int n_head  = hparams.n_audio_head;
 
+    const int n_state_head = n_state/n_head;
+
+    const int n_ctx_pad = GGML_PAD(n_ctx, 256);
+
     struct ggml_init_params params = {
         /*.mem_size   =*/ wstate.alloc_cross.meta.size(),
         /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
@@ -2097,18 +2153,18 @@ static struct ggml_cgraph * whisper_build_graph_cross(
 
     struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
 
-    const float  Kscale = pow(float(n_state) / n_head, -0.25);
+    const float  Kscale = pow(float(n_state_head), -0.25);
 
     for (int il = 0; il < model.hparams.n_text_layer; ++il) {
         auto & layer = model.layers_decoder[il];
 
-        struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
+        struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
                 layer.cross_attn_k_w,
                 cur);
 
         Kcross = ggml_scale(ctx0, Kcross, Kscale);
 
-        struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
+        struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
                 layer.cross_attn_v_w,
                 cur);
 
@@ -2116,15 +2172,25 @@ static struct ggml_cgraph * whisper_build_graph_cross(
                     Vcross,
                     layer.cross_attn_v_b);
 
-        Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
+        struct ggml_tensor * k;
+        struct ggml_tensor * v;
 
-        struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k,
-                n_state*n_ctx,
-                (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
+        if (wctx.params.flash_attn) {
+            k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
+                    (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));
 
-        struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
-                (   n_ctx)*ggml_element_size(wstate.kv_cross.v),
-                (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
+            v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx,
+                    (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad));
+        } else {
+            Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
+
+            k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
+                    (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
+
+            v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
+                    (   n_ctx)*ggml_element_size(wstate.kv_cross.v),
+                    (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
+        }
 
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
@@ -2195,7 +2261,7 @@ static bool whisper_encode_internal(
         }
 
         if (!whisper_encode_external(wstate)) {
-            if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+            if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
                 return false;
             }
         } else {
@@ -2218,7 +2284,7 @@ static bool whisper_encode_internal(
             return false;
         }
 
-        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
             return false;
         }
     }
@@ -2234,7 +2300,7 @@ static bool whisper_encode_internal(
             return false;
         }
 
-        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
             return false;
         }
     }
@@ -2263,11 +2329,15 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     const int n_head  = hparams.n_text_head;
     const int n_layer = hparams.n_text_layer;
 
+    const int n_state_head = n_state/n_head;
+
     const int n_tokens    = batch.n_tokens;
     const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
 
-    const int32_t n_kv     = worst_case ? n_ctx            : kv_self.n;
-    const int32_t kv_head  = worst_case ? n_ctx - n_tokens : kv_self.head;
+    const int n_audio_ctx_pad = GGML_PAD(n_audio_ctx, 256);
+
+    const int32_t n_kv    = worst_case ? n_ctx            : kv_self.n;
+    const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
 
     //WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
 
@@ -2289,12 +2359,14 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     ggml_set_name(position, "position");
     ggml_set_input(position);
 
-    const float KQscale = pow(float(n_state)/n_head, -0.25);
+    const float KQscale = pow(float(n_state_head), -0.25);
 
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1);
     ggml_set_name(KQ_mask, "KQ_mask");
     ggml_set_input(KQ_mask);
 
+    struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16);
+
     // token encoding + position encoding
     struct ggml_tensor * cur =
         ggml_add(ctx0,
@@ -2350,12 +2422,25 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                             Vcur,
                             layer.attn_v_b);
 
-                Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
+                struct ggml_tensor * k;
+                struct ggml_tensor * v;
+
+                if (wctx.params.flash_attn) {
+                    k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
+                            (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
+
+                    v = ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state,
+                            (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head));
+                } else {
+                    Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
 
-                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
-                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
-                        (   n_ctx)*ggml_element_size(kv_self.v),
-                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
+                    k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
+                            (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
+
+                    v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
+                            (   n_ctx)*ggml_element_size(kv_self.v),
+                            (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
+                }
 
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
@@ -2365,35 +2450,48 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
+                        ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_state/n_head, n_kv, n_head,
+                        n_state_head, n_kv, n_head,
                         ggml_element_size(kv_self.k)*n_state,
-                        ggml_element_size(kv_self.k)*n_state/n_head,
+                        ggml_element_size(kv_self.k)*n_state_head,
                         ggml_element_size(kv_self.k)*n_state*n_ctx*il);
 
-            // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            if (wctx.params.flash_attn) {
+                struct ggml_tensor * V =
+                    ggml_view_3d(ctx0, kv_self.v,
+                            n_state_head, n_kv, n_head,
+                            ggml_element_size(kv_self.v)*n_state,
+                            ggml_element_size(kv_self.v)*n_state_head,
+                            ggml_element_size(kv_self.v)*n_state*n_ctx*il);
+
+                cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f);
+
+                cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
+            } else {
+                // K * Q
+                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);
+                struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);
 
-            struct ggml_tensor * V =
-                ggml_view_3d(ctx0, kv_self.v,
-                        n_kv, n_state/n_head, n_head,
-                        n_ctx*ggml_element_size(kv_self.v),
-                        n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
-                        n_ctx*ggml_element_size(kv_self.v)*n_state*il);
+                struct ggml_tensor * V =
+                    ggml_view_3d(ctx0, kv_self.v,
+                            n_kv, n_state_head, n_head,
+                            n_ctx*ggml_element_size(kv_self.v),
+                            n_ctx*ggml_element_size(kv_self.v)*n_state_head,
+                            n_ctx*ggml_element_size(kv_self.v)*n_state*il);
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
-            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+                struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            cur = ggml_cpy(ctx0,
-                    KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+                cur = ggml_cpy(ctx0,
+                        KQV_merged,
+                        ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+            }
         }
 
         // projection
@@ -2432,80 +2530,77 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                         Qcur,
                         layer.cross_attn_q_b);
 
-            Qcur = ggml_scale(ctx0, Qcur, KQscale);
-
-            // Kcross is already scaled
-            struct ggml_tensor * Kcross =
-                ggml_view_3d(ctx0, wstate.kv_cross.k,
-                        n_state/n_head, n_audio_ctx, n_head,
-                        ggml_element_size(wstate.kv_cross.k)*n_state,
-                        ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
-                        ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
-
-            //struct ggml_tensor * Vcross =
-            //    ggml_reshape_3d(ctx0,
-            //            ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
-            //            n_state/n_head, n_head, n_audio_ctx);
-
-            //struct ggml_tensor * V_trans =
-            //    ggml_cpy(ctx0,
-            //            ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
-            //            ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
-
-            struct ggml_tensor * V =
-                ggml_view_3d(ctx0, wstate.kv_cross.v,
-                        n_audio_ctx, n_state/n_head, n_head,
-                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
-                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
-                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
-
-            // ------
-
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
+                        ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
                         0, 2, 1, 3);
 
-            // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
-
-            //struct ggml_tensor * KQ_scaled =
-            //    ggml_scale(ctx0,
-            //            KQ,
-            //            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
-            //            );
+            if (wctx.params.flash_attn) {
+                struct ggml_tensor * Kcross =
+                    ggml_view_3d(ctx0, wstate.kv_cross.k,
+                            n_state_head, n_audio_ctx_pad, n_head,
+                            ggml_element_size(wstate.kv_cross.k)*n_state,
+                            ggml_element_size(wstate.kv_cross.k)*n_state_head,
+                            ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il);
 
-            // no masking for cross-attention
-            //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+                struct ggml_tensor * Vcross =
+                    ggml_view_3d(ctx0, wstate.kv_cross.v,
+                            n_state_head, n_audio_ctx_pad, n_head,
+                            ggml_element_size(wstate.kv_cross.v)*n_state,
+                            ggml_element_size(wstate.kv_cross.v)*n_state_head,
+                            ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);
 
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
+                cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f);
 
-            // [EXPERIMENTAL] Token-level timestamps with DTW
-            if (wctx.params.dtw_token_timestamps) {
-                if (wstate.aheads_masks.m[il] != nullptr) {
-                    struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
-                    aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
-                    aheads_KQs = ggml_cont(ctx0, aheads_KQs);
-                    aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
-                    aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
-                    aheads_KQs = ggml_cont(ctx0, aheads_KQs);
-                    aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
-                    if (aheads_cross_QKs == NULL) {
-                        aheads_cross_QKs = aheads_KQs;
-                    } else {
-                        aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs);
+                cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
+            } else {
+                struct ggml_tensor * Kcross =
+                    ggml_view_3d(ctx0, wstate.kv_cross.k,
+                            n_state_head, n_audio_ctx, n_head,
+                            ggml_element_size(wstate.kv_cross.k)*n_state,
+                            ggml_element_size(wstate.kv_cross.k)*n_state_head,
+                            ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
+
+                struct ggml_tensor * Vcross =
+                    ggml_view_3d(ctx0, wstate.kv_cross.v,
+                            n_audio_ctx, n_state_head, n_head,
+                            n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
+                            n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state_head,
+                            n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
+
+                // ------
+
+                // K * Q
+                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
+
+                struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
+
+                // [EXPERIMENTAL] Token-level timestamps with DTW
+                if (wctx.params.dtw_token_timestamps) {
+                    if (wstate.aheads_masks.m[il] != nullptr) {
+                        struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
+                        aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
+                        aheads_KQs = ggml_cont(ctx0, aheads_KQs);
+                        aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
+                        aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
+                        aheads_KQs = ggml_cont(ctx0, aheads_KQs);
+                        aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
+                        if (aheads_cross_QKs == NULL) {
+                            aheads_cross_QKs = aheads_KQs;
+                        } else {
+                            aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs);
+                        }
                     }
                 }
-            }
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcross, KQ_soft_max);
 
-            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+                struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            // cur = KQV_merged.contiguous().view(n_state, n_tokens)
-            cur = ggml_cpy(ctx0,
-                    KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+                cur = ggml_cpy(ctx0,
+                        KQV_merged,
+                        ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+            }
         }
 
         // projection
@@ -2638,7 +2733,9 @@ static bool whisper_decode_internal(
             return false;
         }
 
-        kv_self.n = whisper_kv_cache_cell_max(kv_self);
+        const uint32_t pad = whisper_kv_cache_get_padding(wctx);
+        kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad)));
+
         //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
         //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
     }
@@ -2672,9 +2769,10 @@ static bool whisper_decode_internal(
             struct ggml_tensor * KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask");
 
             auto & kv_self = wstate.kv_self;
-            const int32_t n_kv     = kv_self.n;
 
-            wstate.inp_mask.resize(n_kv*n_tokens);
+            const int32_t n_kv = kv_self.n;
+
+            wstate.inp_mask.resize(ggml_nelements(KQ_mask));
 
             float * data = wstate.inp_mask.data();
             memset(data, 0, ggml_nbytes(KQ_mask));
@@ -2690,6 +2788,12 @@ static bool whisper_decode_internal(
                         }
                     }
                 }
+
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                    }
+                }
             }
 
             ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
@@ -2697,7 +2801,7 @@ static bool whisper_decode_internal(
 
         logits = gf->nodes[gf->n_nodes - 1];
 
-        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
             return false;
         }
     }
@@ -3144,18 +3248,14 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     whisper_state * state = new whisper_state;
 
-    state->backend = whisper_backend_init(ctx->params);
-    if (!state->backend) {
-        WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
-        whisper_free_state(state);
-        return nullptr;
-    }
-
     // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
     // in theory, there can be a case where this is not enough, but in practice it should always be enough
     const int factor = 3;
 
-    if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
+    if (!kv_cache_init(state->kv_self, ctx->backend, ctx->itype,
+                ctx->model.hparams.n_text_state,
+                ctx->model.hparams.n_text_layer,
+                GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
         WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
@@ -3166,7 +3266,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
-    if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
+    if (!kv_cache_init(state->kv_cross, ctx->backend, ctx->itype,
+                ctx->model.hparams.n_text_state,
+                ctx->model.hparams.n_text_layer,
+                GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
         WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
@@ -3177,6 +3280,20 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
+    if (!kv_cache_init(state->kv_pad, ctx->backend, ctx->itype,
+                ctx->model.hparams.n_audio_state,
+                1,
+                GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
+        WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
+        whisper_free_state(state);
+        return nullptr;
+    }
+
+    {
+        const size_t memory_size = ggml_nbytes(state->kv_pad.k) + ggml_nbytes(state->kv_pad.v);
+        WHISPER_LOG_INFO("%s: kv pad  size  = %7.2f MB\n", __func__, memory_size / 1e6);
+    }
+
     // [EXPERIMENTAL] Token-level timestamps with DTW
     if (ctx->params.dtw_token_timestamps) {
         if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
@@ -3347,6 +3464,7 @@ int whisper_ctx_init_openvino_encoder(
 struct whisper_context_params whisper_context_default_params() {
     struct whisper_context_params result = {
         /*.use_gpu              =*/ true,
+        /*.flash_attn           =*/ false,
         /*.gpu_device           =*/ 0,
 
         /*.dtw_token_timestamps =*/ false,
@@ -3445,6 +3563,16 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
 struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
     ggml_time_init();
 
+    if (params.flash_attn && params.dtw_token_timestamps) {
+        WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
+        params.dtw_token_timestamps = false;
+    }
+
+    WHISPER_LOG_INFO("%s: use gpu    = %d\n", __func__, params.use_gpu);
+    WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
+    WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
+    WHISPER_LOG_INFO("%s: dtw        = %d\n", __func__, params.dtw_token_timestamps);
+
     whisper_context * ctx = new whisper_context;
     ctx->params = params;
 
@@ -3533,6 +3661,7 @@ void whisper_free_state(struct whisper_state * state) {
     if (state) {
         kv_cache_free(state->kv_self);
         kv_cache_free(state->kv_cross);
+        kv_cache_free(state->kv_pad);
 
 #ifdef WHISPER_USE_COREML
         if (state->ctx_coreml != nullptr) {
@@ -3555,8 +3684,6 @@ void whisper_free_state(struct whisper_state * state) {
         ggml_gallocr_free(state->alloc_cross.alloc);
         ggml_gallocr_free(state->alloc_decode.alloc);
 
-        ggml_backend_free(state->backend);
-
         // [EXPERIMENTAL] Token-level timestamps with DTW
         aheads_masks_free(state->aheads_masks);
 
index 6a875d3bbb9d34e67b6efe15e9eadc99511d5344..9c7c58d874b0c6be7fd7bb23ff1ba79cd0b9dc4b 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -113,6 +113,7 @@ extern "C" {
 
     struct whisper_context_params {
         bool  use_gpu;
+        bool  flash_attn;
         int   gpu_device;  // CUDA device
 
         // [EXPERIMENTAL] Token-level timestamps with DTW