]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : add Flash Attention (#5021)
authorGeorgi Gerganov <redacted>
Tue, 30 Apr 2024 09:16:08 +0000 (12:16 +0300)
committerGitHub <redacted>
Tue, 30 Apr 2024 09:16:08 +0000 (12:16 +0300)
* ggml : add ggml_flash_attn_ext API

* ggml : fix GQA support in ggml_flash_attn_ext

* ggml : online attention (CPU)

* metal : initial implementation

* metal : f16 precision

* metal : reduce branches

* metal : specialize for head size

* wip : 8 rows per simd group

* wip : 4 rows per simd group

* wip : template for rows per warp

* metal : parallelize across KV size

* metal : parallel reduce across heads

* metal : efficient flash_attn_f16 implementation

* metal : avoid redundant loads of the attention

* metal : scale and mask in matrix form

* metal : fix comment

* llama : avoid ggml_cast, use F32 query

* metal : add parallel reduce version (disabled)

* metal : move output into local memory + optimize

- the result from each simdgroup now stays in the registers
- significantly reduced SRAM usage
- more efficient skipping of -INF blocks
- avoid simdgroup barrier in hot loop
- add comments

* metal : add tests, fix scaling, support C > 32

* metal : improve precision

* ggml : fix f16 mad

* metal : minor

* metal : support Q > 8

* tests : add ATTN tests

* metal : disable buffer allocation logs

* tests : more

* metal : faster inner loop for C == 32

* metal : fix array initialization

* tests : ifdef

* ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext

* ggml : fix ggml_soft_max mask requirement

* cuda : fix soft_max to use correct mask size

* cuda : add flash_attn kernel (wip)

* metal : optimize softmax for C > 32

* metal : optimize softmax

* tests : minor fix

* cuda : avoid zeroing fragments

* tests : update dims

* cuda : fix __hisinf() result check

* cuda : avoid warp_reduce for smax

* cuda : use int instead of int64_t

Noticeably improves performance (thanks to Johannes)

* cuda : make loops use the same loop values

Thanks Johannes again for the tip

* cuda : unroll some of the loops

* cuda : avoid __hisinf branches

* cuda : use half2 in softmax

* cuda : switch to 1 warp for bs > 16

* cuda : speed-up reduce part of the kernel

* cuda : unroll Q*K^T loop

* cuda : fix -INF block check

* cuda : simplify softmax

* cuda : fix matrix names

* cuda : minor

* llama : adapt to F16 KQ_pos

* llama : adapt new models to F16 KQ_mask

* ggml : fix F16 store (ARM NEON)

* llama : fix type of KQ_mask and KQ_pos

* ggml : fix CPU soft_max

* tests : add hs=256

* cuda : fix build

* metal : improve perf via smaller int registers

* cuda : adapt soft_max to F16 mask and pos

* CUDA: faster FlashAttention, kernel for bs == 1

* 16 cols for Phi-2

* no vec for hs, no hs==256 ncols==32 for Volta

* adjust kernel selection logic

* 4 warps, 256 stride for all D

* no ncols == 64

* Multiple parallel blocks for batch size 1

* fix compile warnings

* fix excessive KQ_b loads

* fix cmake build

* fix KV cache padding, NaN from INFINITY (#6438)

* llama : flash_attn cparam + fix defrag

* server: support flash_attn param

* server: bench: enable flash_attn param

* CUDA: refactor host code, dyn. par. blocks

* fix flash_attn_vec_f16 race condition

* flush softmax exp below threshold to 0

* store temp KQ in registers

* Calculate KQ as FP32 if KQV has GGML_PREC_F32

* Add __hgt2_mask implementation for CUDA 11

* fix KQ FP32 precision fpr parallel_blocks > 1

* llama-bench : add -fa,--flash-attn arg

* metal : add BS=1 kernel for flash attention (#6508)

* metal : add BS=1 kernel for flash attention (wip)

* metal : support more than 1 warps

* metal : opts

* metal : opt

* metal : switch to parallel reduce

* metal : reduce registers

* metal : simplify

* metal : initial FA vec kernel

* metal : use F32 attention accumulators

* batched-bench : add fattn arg

* llama : simplify llama_build_kv_store

ggml-ci

* llama : adapt build_olmo to changes

* ggml : fix arm fp16 store on windows

* metal : clean-up

* metal : clean-up kernel code

* metal : minor

* tests : remove benchmarks

ggml-ci

* ggml : fix avx512 const correctness

ggml-ci

* ggml : fix soft_max with bias on CPU

ggml-ci

* common : print --flash-attn in help

* ggml : fix num dimensions in ggml_flash_attn_ext

* llama : force disable flash attention for incompatible models

* ggml : ggml_soft_max support F16/F32 mask/pos

ggml-ci

* cuda : uint -> uint32_t

* cuda : "constexpr dim3" -> "const dim3"

ggml-ci

* cuda : try to fix __hgt2_mask

ggml-ci

* ggml : add TODO's for F16/F32 mask/pos support in other backends

* llama : replace bool need_kq_pos with use_alibi

* llama : prep ALiBi support for BERT models

ggml-ci

* llama : fix n_batch requirements

ggml-ci

* cont

* server : add help for --flash-attn arg

* llama : disable FA for AMD

* tests : remove TMP_ATTN_BENCH

ggml-ci

* llama : support save/load state with FA enabled

ggml-ci

* ci : add CUDA save-load-state tests

ggml-ci

* llama : llama_kv_cache_clear zeroes data + fix save-load seq

ggml-ci

* llama : fix copy-paste errors, add TODO

* llama : disallow incompatible states

* llama : update llama_state_get_size after v_trans field

* metal : remove tmp log

* llama : add static reminder for llama_state_get_size

* metal : fix max nsg

ggml-ci

* ci : fix arg order

ggml-ci

---------

Co-authored-by: Johannes Gäßler <redacted>
Co-authored-by: Pierrick HYMBERT <redacted>
22 files changed:
ci/run.sh
common/common.cpp
common/common.h
examples/batched-bench/batched-bench.cpp
examples/llama-bench/llama-bench.cpp
examples/server/bench/bench.py
examples/server/server.cpp
ggml-cuda.cu
ggml-cuda/common.cuh
ggml-cuda/fattn.cu [new file with mode: 0644]
ggml-cuda/fattn.cuh [new file with mode: 0644]
ggml-cuda/softmax.cu
ggml-kompute.cpp
ggml-metal.m
ggml-metal.metal
ggml-sycl.cpp
ggml-vulkan.cpp
ggml.c
ggml.h
llama.cpp
llama.h
tests/test-backend-ops.cpp

index a1cd9908f65fadbdd9a86eb90211494c015758c5..bf21b6b31c52d540df60f1a7f637c0c8a5f301d4 100755 (executable)
--- a/ci/run.sh
+++ b/ci/run.sh
@@ -336,7 +336,8 @@ function gg_run_open_llama_3b_v2 {
 
     (time ./bin/imatrix --model ${model_f16} -f ${wiki_test_60} -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
 
-    (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/save-load-state     --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/save-load-state -fa --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
 
     function check_ppl {
         qnt="$1"
@@ -517,7 +518,10 @@ function gg_run_open_llama_7b_v2 {
 
     (time ./bin/imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
 
-    (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/save-load-state     -ngl 10 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/save-load-state -fa -ngl 10 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/save-load-state     -ngl 99 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/save-load-state -fa -ngl 99 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
 
     function check_ppl {
         qnt="$1"
index 099d0356fa4b8895c4460e4e17e273cc4a9482a9..243b88abf1aab44dcce3be44b2e10322ceab4e61 100644 (file)
@@ -947,6 +947,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.cont_batching = true;
         return true;
     }
+    if (arg == "-fa" || arg == "--flash-attn") {
+        params.flash_attn = true;
+        return true;
+    }
     if (arg == "--color") {
         params.use_color = true;
         return true;
@@ -1513,6 +1517,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -ns N, --sequences N  number of sequences to decode (default: %d)\n", params.n_sequences);
     printf("  -ps N, --p-split N    speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
     printf("  -cb, --cont-batching  enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
+    printf("  -fa, --flash-attn     enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled");
     printf("  --mmproj MMPROJ_FILE  path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
     printf("  --image IMAGE_FILE    path to an image file. use with multimodal models. Specify multiple times for batching\n");
     if (llama_supports_mlock()) {
@@ -1885,6 +1890,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.offload_kqv       = !params.no_kv_offload;
+    cparams.flash_attn        = params.flash_attn;
 
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
     cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
@@ -2707,6 +2713,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed);
     fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
     fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
+    fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
     fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
 
     const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
index 8afdf2bdf189b41855b2210c83683fd2cb8fb266..0618eea7429b1b150b9f2c3f01c5f296d1f7a8a7 100644 (file)
@@ -150,6 +150,7 @@ struct gpt_params {
     bool multiline_input   = false; // reverse the usage of `\`
     bool simple_io         = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching     = true;  // insert new sequences for decoding on-the-fly
+    bool flash_attn        = false; // flash attention
 
     bool input_prefix_bos  = false; // prefix BOS to user inputs, preceding input_prefix
     bool ignore_eos        = false; // ignore generated EOS tokens
index 1e34de620a41b8fdfe6481e2b61eea99cf74e4dc..2924d8116f44f16e3a40b2134a49ea6e1d7542a4 100644 (file)
@@ -32,7 +32,7 @@ int main(int argc, char ** argv) {
     gpt_params params;
 
     if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n" , argv[0]);
+        printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [FATTN] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n" , argv[0]);
         printf("  <PP>, <TG> and PL are comma-separated lists of numbers without spaces\n\n");
         printf("  example: %s ggml-model-f16.gguf 2048 2048 512 0 999 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]);
         return 1 ;
@@ -41,6 +41,7 @@ int main(int argc, char ** argv) {
     int n_kv_max     = 2048;
     int n_batch      = 2048;
     int n_ubatch     = 512;
+    bool flash_attn  = false;
     int is_pp_shared = 0;
     int n_gpu_layers = 0;
 
@@ -66,23 +67,27 @@ int main(int argc, char ** argv) {
     }
 
     if (argc >= 6) {
-        is_pp_shared = std::atoi(argv[5]);
+        flash_attn = std::atoi(argv[5]);
     }
 
     if (argc >= 7) {
-        n_gpu_layers = std::atoi(argv[6]);
+        is_pp_shared = std::atoi(argv[6]);
     }
 
     if (argc >= 8) {
-        n_pp = parse_list(argv[7]);
+        n_gpu_layers = std::atoi(argv[7]);
     }
 
     if (argc >= 9) {
-        n_tg = parse_list(argv[8]);
+        n_pp = parse_list(argv[8]);
     }
 
     if (argc >= 10) {
-        n_pl = parse_list(argv[9]);
+        n_tg = parse_list(argv[9]);
+    }
+
+    if (argc >= 11) {
+        n_pl = parse_list(argv[10]);
     }
 
     // init LLM
@@ -108,10 +113,11 @@ int main(int argc, char ** argv) {
 
     llama_context_params ctx_params = llama_context_default_params();
 
-    ctx_params.seed      = 1234;
-    ctx_params.n_ctx     = n_kv_max;
-    ctx_params.n_batch   = n_batch;
-    ctx_params.n_ubatch  = n_ubatch;
+    ctx_params.seed       = 1234;
+    ctx_params.n_ctx      = n_kv_max;
+    ctx_params.n_batch    = n_batch;
+    ctx_params.n_ubatch   = n_ubatch;
+    ctx_params.flash_attn = flash_attn;
 
     ctx_params.n_threads       = params.n_threads;
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
@@ -169,7 +175,7 @@ int main(int argc, char ** argv) {
     }
 
     LOG_TEE("\n");
-    LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
+    LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, flash_attn, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
     LOG_TEE("\n");
 
     LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP",     "TG",     "B",    "N_KV",     "T_PP s",   "S_PP t/s", "T_TG s",   "S_TG t/s", "T s",      "S t/s");
index 8b532c8b6a98a59ed4b5a6a2a5eb9ae3eb29a565..95c3095dd04da3ffd310eb40fcbc17586e54807e 100644 (file)
@@ -174,6 +174,7 @@ struct cmd_params {
     std::vector<llama_split_mode> split_mode;
     std::vector<int> main_gpu;
     std::vector<bool> no_kv_offload;
+    std::vector<bool> flash_attn;
     std::vector<std::vector<float>> tensor_split;
     std::vector<bool> use_mmap;
     std::vector<bool> embeddings;
@@ -195,6 +196,7 @@ static const cmd_params cmd_params_defaults = {
     /* split_mode    */ {LLAMA_SPLIT_MODE_LAYER},
     /* main_gpu      */ {0},
     /* no_kv_offload */ {false},
+    /* flash_attn    */ {false},
     /* tensor_split  */ {std::vector<float>(llama_max_devices(), 0.0f)},
     /* use_mmap      */ {true},
     /* embeddings    */ {false},
@@ -220,6 +222,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -sm, --split-mode <none|layer|row>  (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
     printf("  -mg, --main-gpu <i>                 (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
     printf("  -nkvo, --no-kv-offload <0|1>        (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
+    printf("  -fa, --flash-attn <0|1>             (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
     printf("  -mmp, --mmap <0|1>                  (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
     printf("  -embd, --embeddings <0|1>           (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
     printf("  -ts, --tensor-split <ts0/ts1/..>    (default: 0)\n");
@@ -393,6 +396,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             }
             auto p = split<bool>(argv[i], split_delim);
             params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end());
+        } else if (arg == "-fa" || arg == "--flash-attn") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            auto p = split<bool>(argv[i], split_delim);
+            params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end());
         } else if (arg == "-mmp" || arg == "--mmap") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -477,6 +487,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.split_mode.empty())   { params.split_mode = cmd_params_defaults.split_mode; }
     if (params.main_gpu.empty())     { params.main_gpu = cmd_params_defaults.main_gpu; }
     if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
+    if (params.flash_attn.empty())   { params.flash_attn = cmd_params_defaults.flash_attn; }
     if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
     if (params.use_mmap.empty())     { params.use_mmap = cmd_params_defaults.use_mmap; }
     if (params.embeddings.empty())   { params.embeddings = cmd_params_defaults.embeddings; }
@@ -498,6 +509,7 @@ struct cmd_params_instance {
     llama_split_mode split_mode;
     int main_gpu;
     bool no_kv_offload;
+    bool flash_attn;
     std::vector<float> tensor_split;
     bool use_mmap;
     bool embeddings;
@@ -532,6 +544,7 @@ struct cmd_params_instance {
         cparams.type_k = type_k;
         cparams.type_v = type_v;
         cparams.offload_kqv = !no_kv_offload;
+        cparams.flash_attn = flash_attn;
         cparams.embeddings = embeddings;
 
         return cparams;
@@ -554,6 +567,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & tk : params.type_k)
     for (const auto & tv : params.type_v)
     for (const auto & nkvo : params.no_kv_offload)
+    for (const auto & fa : params.flash_attn)
     for (const auto & nt : params.n_threads) {
         for (const auto & n_prompt : params.n_prompt) {
             if (n_prompt == 0) {
@@ -572,6 +586,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .split_mode   = */ sm,
                 /* .main_gpu     = */ mg,
                 /* .no_kv_offload= */ nkvo,
+                /* .flash_attn   = */ fa,
                 /* .tensor_split = */ ts,
                 /* .use_mmap     = */ mmp,
                 /* .embeddings   = */ embd,
@@ -596,6 +611,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .split_mode   = */ sm,
                 /* .main_gpu     = */ mg,
                 /* .no_kv_offload= */ nkvo,
+                /* .flash_attn   = */ fa,
                 /* .tensor_split = */ ts,
                 /* .use_mmap     = */ mmp,
                 /* .embeddings   = */ embd,
@@ -633,6 +649,7 @@ struct test {
     llama_split_mode split_mode;
     int main_gpu;
     bool no_kv_offload;
+    bool flash_attn;
     std::vector<float> tensor_split;
     bool use_mmap;
     bool embeddings;
@@ -657,6 +674,7 @@ struct test {
         split_mode = inst.split_mode;
         main_gpu = inst.main_gpu;
         no_kv_offload = inst.no_kv_offload;
+        flash_attn = inst.flash_attn;
         tensor_split = inst.tensor_split;
         use_mmap = inst.use_mmap;
         embeddings = inst.embeddings;
@@ -731,7 +749,7 @@ struct test {
             "n_batch", "n_ubatch",
             "n_threads", "type_k", "type_v",
             "n_gpu_layers", "split_mode",
-            "main_gpu", "no_kv_offload",
+            "main_gpu", "no_kv_offload", "flash_attn",
             "tensor_split", "use_mmap", "embeddings",
             "n_prompt", "n_gen", "test_time",
             "avg_ns", "stddev_ns",
@@ -753,7 +771,7 @@ struct test {
         }
         if (field == "cuda" || field == "opencl"  || field == "vulkan" || field == "kompute" || field == "metal" ||
             field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
-            field == "use_mmap" || field == "embeddings") {
+            field == "flash_attn" || field == "use_mmap" || field == "embeddings") {
             return BOOL;
         }
         if (field == "avg_ts" || field == "stddev_ts") {
@@ -787,7 +805,7 @@ struct test {
             std::to_string(n_batch), std::to_string(n_ubatch),
             std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
             std::to_string(n_gpu_layers), split_mode_str(split_mode),
-            std::to_string(main_gpu), std::to_string(no_kv_offload),
+            std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
             tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
             std::to_string(n_prompt), std::to_string(n_gen), test_time,
             std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -955,6 +973,9 @@ struct markdown_printer : public printer {
         if (field == "no_kv_offload") {
             return "nkvo";
         }
+        if (field == "flash_attn") {
+            return "fa";
+        }
         if (field == "use_mmap") {
             return "mmap";
         }
@@ -1001,6 +1022,9 @@ struct markdown_printer : public printer {
         if (params.no_kv_offload.size() > 1 || params.no_kv_offload != cmd_params_defaults.no_kv_offload) {
             fields.emplace_back("no_kv_offload");
         }
+        if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) {
+            fields.emplace_back("flash_attn");
+        }
         if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
             fields.emplace_back("tensor_split");
         }
index 6ca637bddc3a12eec2e397ea048198951767b787..86c5de101445c162c6150b25c73fbde1255afe0a 100644 (file)
@@ -268,6 +268,7 @@ def start_server_background(args):
     server_args.extend(['--defrag-thold', "0.1"])
     server_args.append('--cont-batching')
     server_args.append('--metrics')
+    server_args.append('--flash-attn')
     server_args.extend(['--log-format', "text"])
     args = [str(arg) for arg in [server_path, *server_args]]
     print(f"bench: starting server with: {' '.join(args)}")
index 01453af2c97e3337c9aa099c175ee2c0edc1d90b..f60530cf3db5617692c4280f0a9c082d8cb69258 100644 (file)
@@ -2377,6 +2377,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
     printf("  --embeddings              enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
     printf("  -np N, --parallel N       number of slots for process requests (default: %d)\n", params.n_parallel);
     printf("  -cb, --cont-batching      enable continuous batching (a.k.a dynamic batching) (default: enabled)\n");
+    printf("  -fa, --flash-attn         enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled");
     printf("  -spf FNAME, --system-prompt-file FNAME\n");
     printf("                            set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
     printf("  -ctk TYPE, --cache-type-k TYPE\n");
@@ -2742,6 +2743,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
             params.embedding = true;
         } else if (arg == "-cb" || arg == "--cont-batching") {
             params.cont_batching = true;
+        } else if (arg == "-fa" || arg == "--flash-attn") {
+            params.flash_attn = true;
         } else if (arg == "-np" || arg == "--parallel") {
             if (++i >= argc) {
                 invalid_param = true;
index d277104d1217703b9922b0909937af84260c791c..c30554f0cbe378e9b1102a56c9f697504395b955 100644 (file)
@@ -14,6 +14,7 @@
 #include "ggml-cuda/cpy.cuh"
 #include "ggml-cuda/diagmask.cuh"
 #include "ggml-cuda/dmmv.cuh"
+#include "ggml-cuda/fattn.cuh"
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
 #include "ggml-cuda/mmq.cuh"
@@ -140,6 +141,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.devices[id].cc = 100*prop.major + 10*prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         info.devices[id].smpb = prop.sharedMemPerBlock;
+        info.devices[id].nsm  = prop.multiProcessorCount;
     }
 
     for (int id = 0; id < info.device_count; ++id) {
@@ -2290,6 +2292,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ARGSORT:
             ggml_cuda_op_argsort(ctx, dst);
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_cuda_flash_attn_ext(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2564,6 +2569,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
+        case GGML_OP_FLASH_ATTN_EXT:
             return true;
         default:
             return false;
index 481065b2a3484b350f258b981a5b4efae291477a..156eba6d1ef74d780ccaa74e72eea9428ccac7e8 100644 (file)
 #define CC_PASCAL     600
 #define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA      700
+#define CC_AMPERE     800
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA1      (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2      (CC_OFFSET_AMD + 1030)
@@ -271,7 +272,6 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
     return a;
 }
 
-#ifdef GGML_CUDA_F16
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 #pragma unroll
@@ -284,7 +284,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
    NO_DEVICE_CODE;
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 }
-#endif // GGML_CUDA_F16
 
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
@@ -294,19 +293,26 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }
 
-//static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-//#pragma unroll
-//    for (int mask = 16; mask > 0; mask >>= 1) {
-//        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-//    }
-//    return x;
-//#else
-//    GGML_UNUSED(x);
-//    NO_DEVICE_CODE;
-//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-//}
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+#pragma unroll
+   for (int mask = 16; mask > 0; mask >>= 1) {
+       x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+   }
+   return x;
+#else
+   GGML_UNUSED(x);
+   NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+}
 
+#if CUDART_VERSION < 12000
+static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
+    const uint32_t mask_low  = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
+    const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
+    return mask_low | mask_high;
+}
+#endif // CUDART_VERSION < 12000
 
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
@@ -391,6 +397,11 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 }
 #endif // defined(GGML_USE_HIPBLAS)
 
+#define FP16_AVAILABLE     defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
+    defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
+
+#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+
 // TODO: move to ggml-common.h
 static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
@@ -404,6 +415,7 @@ struct ggml_cuda_device_info {
 
     struct cuda_device_info {
         int     cc;                 // compute capability
+        int     nsm;                // number of streaming multiprocessors
         size_t  smpb;               // max. shared memory per block
         bool    vmm;                // virtual memory support
         size_t  vmm_granularity;    // granularity of virtual memory
diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
new file mode 100644 (file)
index 0000000..df1e800
--- /dev/null
@@ -0,0 +1,944 @@
+#include "common.cuh"
+#include "fattn.cuh"
+
+#include <cstdint>
+
+#if FP16_MMA_AVAILABLE
+#include <mma.h>
+#endif
+
+#define FATTN_KQ_STRIDE       256
+#define HALF_MAX_HALF         __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
+#define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
+
+template<int D, int parallel_blocks> // D == head size
+__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
+static __global__ void flash_attn_vec_ext_f16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+#if FP16_AVAILABLE
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    const int ic = blockIdx.x / parallel_blocks; // Index of the Q/QKV column to work on.
+    const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.y              + nb01*ic);
+    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.y / gqa_ratio));
+    const half   * V_h   = (const half   *) (V    + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const half   * maskh = (const half   *)  mask + ne11*ic;
+
+    const int stride_KV  = nb11 / sizeof(half);
+    const int stride_KV2 = nb11 / sizeof(half2);
+
+    constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
+    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    __builtin_assume(tid < nwarps*WARP_SIZE);
+
+    __shared__ half KQ[nwarps*WARP_SIZE];
+    KQ[tid] = -INFINITY;
+    half2 * KQ2 = (half2 *) KQ;
+
+    half kqmax = -HALF_MAX_HALF;
+    half kqsum = 0.0f;
+
+    __shared__ half kqmax_shared[WARP_SIZE];
+    __shared__ half kqsum_shared[WARP_SIZE];
+    if (threadIdx.y == 0) {
+        kqmax_shared[threadIdx.x] = -HALF_MAX_HALF;
+        kqsum_shared[threadIdx.x] = 0.0f;
+    }
+    __syncthreads();
+
+    // Convert Q to half2 and store in registers:
+    half2 Q_h2[(D/2 + WARP_SIZE - 1) / WARP_SIZE];
+#pragma unroll
+    for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+        const int i = i0 + threadIdx.x;
+        if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+            break;
+        }
+
+        Q_h2[i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(Q_f2[i].x, Q_f2[i].y);
+    }
+
+    half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value.
+
+    const int k_start  = parallel_blocks == 1 ? 0 : ip*D;
+    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+        // Calculate KQ tile and keep track of new maximum KQ values:
+        half kqmax_new = kqmax;
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
+            const int i_KQ = i_KQ_0 + threadIdx.y;
+
+            if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
+                break;
+            }
+
+            half2 sum2 = make_half2(0.0f, 0.0f);
+#pragma unroll
+            for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+                const int k_KQ = k_KQ_0 + threadIdx.x;
+                if (k_KQ_0 + WARP_SIZE > D/2 && k_KQ >= D/2) {
+                    break;
+                }
+
+                const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
+                sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
+            }
+
+            sum2 = warp_reduce_sum(sum2);
+            half sum = __low2half(sum2) + __high2half(sum2);
+            sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f);
+            kqmax_new = __hmax(kqmax_new, sum);
+            if (threadIdx.x == 0) {
+                KQ[i_KQ] = sum;
+            }
+        }
+
+        kqmax_new = warp_reduce_max(kqmax_new);
+        if (threadIdx.x == 0) {
+            kqmax_shared[threadIdx.y] = kqmax_new;
+        }
+        __syncthreads();
+        kqmax_new = kqmax_shared[threadIdx.x];
+        kqmax_new = warp_reduce_max(kqmax_new);
+
+        const half KQ_max_scale = hexp(kqmax - kqmax_new);
+        kqmax = kqmax_new;
+
+        const half val = hexp(KQ[tid] - kqmax);
+        kqsum = kqsum*KQ_max_scale + val;
+        KQ[tid] = val;
+
+        VKQ *= __half2half2(KQ_max_scale);
+
+        __syncthreads();
+
+        if (tid < D) {
+#pragma unroll
+            for (int k0 = 0; k0 < D; k0 += 2) {
+                if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
+                    break;
+                }
+
+                half2 V_k;
+                reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
+                reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
+                VKQ += V_k*KQ2[k0/2];
+            }
+        }
+
+        __syncthreads();
+    }
+
+    if (tid >= D) {
+        kqsum = 0.0f;
+    }
+
+    kqsum = warp_reduce_sum(kqsum);
+    if (threadIdx.x == 0) {
+        kqsum_shared[threadIdx.y] = kqsum;
+    }
+    __syncthreads();
+    kqsum = kqsum_shared[threadIdx.x];
+    kqsum = warp_reduce_sum(kqsum);
+
+    if (tid >= D) {
+        return;
+    }
+
+    half dst_val = (__low2half(VKQ) + __high2half(VKQ));
+    if (parallel_blocks == 1) {
+        dst_val /= kqsum;
+    }
+    dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val;
+
+    if (parallel_blocks == 1 || tid != 0) {
+        return;
+    }
+    dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum);
+#else
+   NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+}
+
+// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
+__launch_bounds__(nwarps*WARP_SIZE, 1)
+static __global__ void flash_attn_ext_f16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+#if FP16_MMA_AVAILABLE
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
+    const int ip  =        blockIdx.x % parallel_blocks;  // Index in group of blocks running for the same column in parallel.
+
+    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
+    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
+    constexpr int frag_m = ncols == 8 ? 32 : 16;
+    constexpr int frag_n = ncols == 8 ?  8 : 16;
+    static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
+
+    constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel.
+    constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
+    static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
+
+    // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
+    constexpr int D_padded = D + 8;
+    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
+    constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y              + nb01*ic0);
+    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
+    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0;
+    const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);
+
+    const int stride_Q  = nb01 / sizeof(float);
+    const int stride_KV = nb11 / sizeof(half);
+
+    frag_b Q_b[D/16][ncols/frag_n];
+
+    // A single buffer for temporarily holding tiles of KQ and VKQ parts:
+    constexpr int mem_KQ = ncols*kqs_padded*kqar;
+    constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
+    __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
+    float * KQ_f = (float *) KQ;
+    half2 * KQ2 = (half2 *) KQ;
+
+    float    KQ_rowsum_f[ncols/nwarps] = {0.0f};
+    float       KQ_max_f[ncols/nwarps];
+    float KQ_max_scale_f[ncols/nwarps] = {0.0f};
+
+#pragma unroll
+    for (int j = 0; j < ncols/nwarps; ++j) {
+        KQ_max_f[j] = -FLT_MAX/2.0f;
+    }
+
+    half2    KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+    half2       KQ_max_h2[ncols/nwarps];
+    half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+
+#pragma unroll
+    for (int j = 0; j < ncols/nwarps; ++j) {
+        KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
+    }
+
+    __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
+    half2 * VKQ2 = (half2 *) VKQ;
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+#pragma unroll
+        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                break;
+            }
+            VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
+        }
+    }
+
+    // Convert Q to half and apply scale, temporarily store in KQ:
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D && i >= D) {
+                break;
+            }
+            KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+        }
+    }
+
+    __syncthreads();
+
+    // Load Q into tensor core fragments/registers since it will be used frequently:
+#pragma unroll
+    for (int i0 = 0; i0 < D; i0 += 16) {
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+        }
+    }
+
+    __syncthreads();
+
+    // Iterate over ne11 == previous tokens:
+    for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
+        // Calculate tile of KQ:
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
+            frag_c_KQ KQ_c[ncols/frag_n];
+#pragma unroll
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+            }
+#pragma unroll
+            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
+                frag_a_K K_a;
+                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                }
+            }
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+            }
+        }
+
+        __syncthreads();
+
+        // Calculate softmax for each KQ column using the current max. value.
+        // The divisor is stored in KQ_rowsum and will be applied at the end.
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (std::is_same<KQ_acc_t, float>::value) {
+                float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+                }
+
+                float KQ_max_new = KQ_max_f[j0/nwarps];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
+                }
+                KQ_max_new = warp_reduce_max(KQ_max_new);
+
+                const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
+                KQ_max_scale_f[j0/nwarps] = expf(diff);
+                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                    KQ_max_scale_f[j0/nwarps] = 0.0f;
+                }
+                KQ_max_f[j0/nwarps] = KQ_max_new;
+
+                float KQ_rowsum_add = 0.0f;
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
+                    KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
+                    if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                        KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
+                    }
+                    KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
+                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
+                }
+                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+                KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
+            } else {
+                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+                }
+
+                half2 KQ_max_new = KQ_max_h2[j0/nwarps];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                    KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
+                }
+                KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
+                const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
+                KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
+                const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+                *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
+                KQ_max_h2[j0/nwarps] = KQ_max_new;
+
+                half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
+                    KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
+                    const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+                    *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
+                    KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
+                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
+                }
+                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+                KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
+            }
+        }
+
+        __syncthreads();
+
+        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+#pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+                nvcuda::wmma::load_matrix_sync(
+                    KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
+                    KQ + j0*(kqar*kqs_padded) + k,
+                    kqar*kqs_padded);
+            }
+        }
+
+        frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
+#pragma unroll
+        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
+#pragma unroll
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+            }
+
+#pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+
+                frag_a_V v_a;
+                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                }
+            }
+        }
+
+        __syncthreads();
+
+        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync(
+                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
+                    VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
+                    D_padded, nvcuda::wmma::mem_col_major);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            half2 VKQ_scale;
+            if (std::is_same<KQ_acc_t, float>::value) {
+                VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
+            } else {
+                VKQ_scale = KQ_max_scale_h2[j0/nwarps];
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+                if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                    break;
+                }
+
+                half2 VKQ_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+                for (int l = 0; l < VKQ_ratio; ++l) {
+                    VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
+                }
+                VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
+            }
+        }
+
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j_VKQ = j0 + threadIdx.y;
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+
+        float KQ_rowsum_j;
+        if (std::is_same<KQ_acc_t, float>::value) {
+            KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
+        } else {
+            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D && i >= D) {
+                break;
+            }
+            float dst_val = VKQ[j_VKQ*D_padded + i];
+            if (parallel_blocks == 1) {
+                dst_val /= KQ_rowsum_j;
+            }
+            dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
+        }
+
+        if (parallel_blocks == 1 || threadIdx.x != 0) {
+            continue;
+        }
+
+        float2 dst_meta_val;
+        if (std::is_same<KQ_acc_t, float>::value) {
+            dst_meta_val.x = KQ_max_f[j0/nwarps];
+        } else {
+            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
+        }
+        dst_meta_val.y = KQ_rowsum_j;
+        dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
+    }
+#else
+   NO_DEVICE_CODE;
+#endif // FP16_MMA_AVAILABLE
+}
+
+template<int D, int parallel_blocks> // D == head size
+__launch_bounds__(D, 1)
+static __global__ void flash_attn_combine_results(
+        const float  * __restrict__ VKQ_parts,
+        const float2 * __restrict__ VKQ_meta,
+        float * __restrict__ dst) {
+#if FP16_AVAILABLE
+    VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
+    VKQ_meta  += parallel_blocks   * gridDim.y*blockIdx.x;
+    dst       +=                 D * gridDim.y*blockIdx.x;
+
+    const int tid = threadIdx.x;
+    __builtin_assume(tid < D);
+
+    __shared__ float2 meta[parallel_blocks];
+    if (tid < 2*parallel_blocks) {
+        ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
+    }
+
+    __syncthreads();
+
+    float kqmax = meta[0].x;
+#pragma unroll
+    for (int l = 1; l < parallel_blocks; ++l) {
+        kqmax = max(kqmax, meta[l].x);
+    }
+
+    float VKQ_numerator   = 0.0f;
+    float VKQ_denominator = 0.0f;
+#pragma unroll
+    for (int l = 0; l < parallel_blocks; ++l) {
+        const float diff = meta[l].x - kqmax;
+        const float KQ_max_scale = expf(diff);
+        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
+        *((uint32_t *) &KQ_max_scale) &= ftz_mask;
+
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
+        VKQ_denominator += KQ_max_scale * meta[l].y;
+    }
+
+    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
+#else
+   NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+}
+
+constexpr int get_max_power_of_2(int x) {
+    return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
+}
+
+static_assert(get_max_power_of_2(1) == 1, "Test failed.");
+static_assert(get_max_power_of_2(2) == 2, "Test failed.");
+static_assert(get_max_power_of_2(4) == 4, "Test failed.");
+static_assert(get_max_power_of_2(6) == 2, "Test failed.");
+
+// Number of VKQ rows calculated in parallel:
+constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
+    return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
+}
+
+static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed.");
+static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
+static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed.");
+static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
+
+template <int D, int parallel_blocks> void launch_fattn_vec_f16(
+        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
+        ggml_cuda_pool & pool, cudaStream_t main_stream
+) {
+    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+
+    if (parallel_blocks > 1) {
+        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+    }
+
+    constexpr int  nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
+    const     dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const     dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
+    const     int  shmem = 0;
+
+    float scale;
+    memcpy(&scale, KQV->op_params, sizeof(float));
+
+    flash_attn_vec_ext_f16<D, parallel_blocks>
+        <<<blocks_num, block_dim, shmem, main_stream>>> (
+                (const char *) Q->data,
+                (const char *) K->data,
+                (const char *) V->data,
+                mask ? ((const char *) mask->data) : nullptr,
+                parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+                scale,
+                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
+                Q->nb[1], Q->nb[2], Q->nb[3],
+                K->nb[1], K->nb[2], K->nb[3],
+                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                );
+    CUDA_CHECK(cudaGetLastError());
+
+    if (parallel_blocks == 1) {
+        return;
+    }
+
+    const dim3 block_dim_combine(D, 1, 1);
+    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+    const int  shmem_combine = 0;
+
+    flash_attn_combine_results<D, parallel_blocks>
+        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    CUDA_CHECK(cudaGetLastError());
+}
+
+template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename KQ_acc_t> void launch_fattn_f16_impl(
+        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
+        ggml_cuda_pool & pool, cudaStream_t main_stream
+) {
+    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+
+    if (parallel_blocks > 1) {
+        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+    }
+
+    constexpr int  frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16;
+    const     dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const     dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
+    const     int  shmem = 0;
+
+    float scale;
+    memcpy(&scale, KQV->op_params, sizeof(float));
+
+    flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
+        <<<blocks_num, block_dim, shmem, main_stream>>> (
+                (const char *) Q->data,
+                (const char *) K->data,
+                (const char *) V->data,
+                mask ? ((const char *) mask->data) : nullptr,
+                (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+                scale,
+                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
+                Q->nb[1], Q->nb[2], Q->nb[3],
+                K->nb[1], K->nb[2], K->nb[3],
+                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                );
+    CUDA_CHECK(cudaGetLastError());
+
+    if ((parallel_blocks) == 1) {
+        return;
+    }
+
+    const dim3 block_dim_combine(D, 1, 1);
+    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+    const int  shmem_combine = 0;
+
+    flash_attn_combine_results<D, parallel_blocks>
+        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    CUDA_CHECK(cudaGetLastError());
+}
+
+template <int D, int cols_per_block, int nwarps, typename KQ_acc_t> void launch_fattn_f16(
+        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
+        const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream
+) {
+    const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
+
+    if (4*blocks_num_pb1 < 2*nsm) {
+        launch_fattn_f16_impl<D, cols_per_block, nwarps, 4, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+        return;
+    }
+    if (2*blocks_num_pb1 < 2*nsm) {
+        launch_fattn_f16_impl<D, cols_per_block, nwarps, 2, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+        return;
+    }
+    launch_fattn_f16_impl<D, cols_per_block, nwarps, 1, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+}
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+
+    const ggml_tensor * mask = dst->src[3];
+
+    ggml_tensor * KQV = dst;
+
+    GGML_ASSERT(Q->type == GGML_TYPE_F32);
+    GGML_ASSERT(K->type == GGML_TYPE_F16);
+    GGML_ASSERT(V->type == GGML_TYPE_F16);
+    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+                                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+
+    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+
+    ggml_cuda_set_device(ctx.device);
+
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+
+    const int32_t precision = KQV->op_params[1];
+
+    if (precision != GGML_PREC_DEFAULT) {
+        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
+            constexpr int cols_per_block = 16;
+            constexpr int nwarps         =  4;
+            switch (Q->ne[0]) {
+                case 64:
+                    launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 80:
+                    launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 96:
+                    launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 112:
+                    launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 128:
+                    launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 256:
+                    launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+        } else {
+            constexpr int cols_per_block = 32;
+            constexpr int nwarps         =  4;
+            switch (Q->ne[0]) {
+                case 64:
+                    launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 80:
+                    launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 96:
+                    launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 112:
+                    launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 128:
+                    launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                // case 256:
+                //     launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                //     break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+        }
+        return;
+    }
+
+    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
+        constexpr int parallel_blocks = 4;
+        switch (Q->ne[0]) {
+            case 64:
+                launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                break;
+            case 128:
+                launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                break;
+            case 256:
+                launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                break;
+            default:
+                GGML_ASSERT(false);
+                break;
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
+        constexpr int cols_per_block = 8;
+        constexpr int nwarps         = 4;
+        switch (Q->ne[0]) {
+            case 64:
+                launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            case 96:
+                launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            case 128:
+                launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            case 256:
+                launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            default:
+                GGML_ASSERT(false);
+                break;
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 32) {
+        constexpr int cols_per_block = 16;
+        constexpr int nwarps         =  4;
+        switch (Q->ne[0]) {
+            case 64:
+                launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            case 80:
+                launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            case 96:
+                launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            case 112:
+                launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            case 128:
+                launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            case 256:
+                launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            default:
+                GGML_ASSERT(false);
+                break;
+        }
+        return;
+    }
+
+    constexpr int cols_per_block = 32;
+    constexpr int nwarps         =  4;
+    switch (Q->ne[0]) {
+        case 64:
+            launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        case 80:
+            launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        case 96:
+            launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        case 112:
+            launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        case 128:
+            launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        case 256:
+            launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+    return;
+}
diff --git a/ggml-cuda/fattn.cuh b/ggml-cuda/fattn.cuh
new file mode 100644 (file)
index 0000000..ad3ca7a
--- /dev/null
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index fa8f987cf7c1d4e3c26fe60bca485556f9d85e02..6ed225999bddfc914118a6fda3536222047bf6aa 100644 (file)
@@ -1,7 +1,17 @@
 #include "softmax.cuh"
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+template <typename T>
+static __device__ __forceinline__ float t2f32(T val) {
+    return (float) val;
+}
+
+template <>
+__device__ float __forceinline__ t2f32<half>(half val) {
+    return __half2float(val);
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
 
     const int tid  = threadIdx.x;
@@ -43,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
         const int64_t ix = (int64_t)rowx*ncols + col;
         const int64_t iy = (int64_t)rowy*ncols + col;
 
-        const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
+        const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);
@@ -114,7 +124,8 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
     }
 }
 
-static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+template<typename T>
+static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth,     1, 1);
@@ -167,15 +178,19 @@ static void soft_max_f32_cuda(const float * x, const float * mask, const float *
 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
     const float * src0_d = (const float *)src0->data;
-    const float * src1_d = src1 ? (const float *)src1->data : nullptr;
+    const void  * src1_d = src1 ? (const void *)src1->data : nullptr;
+
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
 
     const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
@@ -188,14 +203,25 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
     // positions tensor
-    float * src2_dd = nullptr;
+    void * src2_d = nullptr;
 
-    ggml_tensor * src2 = dst->src[2];
     const bool use_src2 = src2 != nullptr;
 
     if (use_src2) {
-        src2_dd = (float *)src2->data;
+        src2_d = (void *)src2->data;
     }
 
-    soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+
+    if (use_f16) {
+        const half * src1_dd = (const half *)src1_d;
+        const half * src2_dd = (const half *)src2_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    } else {
+        const float * src1_dd = (const float *)src1_d;
+        const float * src2_dd = (const float *)src2_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    }
 }
index 407062e6fd47625d6cb2f78e17649f64387d2866..9a469821d804214afb3d99bb83a7cfa31e093349 100644 (file)
@@ -1427,6 +1427,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
         for (int i = node_start; i < node_end; ++i) {
             struct ggml_tensor * src0 = gf->nodes[i]->src[0];
             struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+            struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2);
             struct ggml_tensor * dst = gf->nodes[i];
             GGML_ASSERT(dst->data != nullptr);
 
@@ -1559,6 +1560,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                     {
                         float scale;
                         memcpy(&scale, dst->op_params, sizeof(float));
+
+#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
+#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
+                        GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
+                        GGML_ASSERT(src2 == nullptr);
+
                         ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
                     } break;
                 case GGML_OP_DIAG_MASK_INF:
index 9cb4219885cb7da018a6bad7fc1d14ab2dc5b8f9..c6d580b8462d01848bfe0ed0ef23676bd9603a7b 100644 (file)
@@ -46,8 +46,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
     GGML_METAL_KERNEL_TYPE_SILU,
     GGML_METAL_KERNEL_TYPE_SILU_4,
-    GGML_METAL_KERNEL_TYPE_SOFT_MAX,
-    GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
     GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
     GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
@@ -177,6 +179,14 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
     GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -443,7 +453,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         }
 
         /*
-            GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+            GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
                     (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
                     (int) kernel->pipeline.threadExecutionWidth); \
         */
@@ -459,172 +469,182 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
                 return NULL; \
             } \
         } else { \
-            GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \
+            GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
         }
 
         // simd_sum and simd_max requires MTLGPUFamilyApple7
 
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                       add,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,                   add_row,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,                       mul,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,                   mul_row,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                       div,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,                   div_row,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE,                     scale,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,                   scale_4,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP,                     clamp,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH,                      tanh,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,                      relu,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,                      gelu,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4,                    gelu_4,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK,                gelu_quick,             true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,              gelu_quick_4,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                      silu,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,                    silu_4,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX,                  soft_max,               ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,                soft_max_4,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,             diag_mask_inf,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,           diag_mask_inf_8,        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,              get_rows_f32,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,              get_rows_f16,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,             get_rows_q4_0,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,             get_rows_q4_1,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,             get_rows_q5_0,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,             get_rows_q5_1,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,             get_rows_q8_0,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,             get_rows_q2_K,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,             get_rows_q3_K,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,             get_rows_q4_K,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,             get_rows_q5_K,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,             get_rows_q6_K,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,          get_rows_iq2_xxs,       true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,           get_rows_iq2_xs,        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,          get_rows_iq3_xxs,       true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,            get_rows_iq3_s,         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,            get_rows_iq2_s,         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,            get_rows_iq1_s,         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,            get_rows_iq1_m,         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,           get_rows_iq4_nl,        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,           get_rows_iq4_xs,        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,              get_rows_i32,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                  rms_norm,               ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                group_norm,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                      norm,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,            mul_mv_f32_f32,         ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,            mul_mv_f16_f16,         ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,            mul_mv_f16_f32,         ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,       mul_mv_f16_f32_1row,    ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,         mul_mv_f16_f32_l4,      ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,           mul_mv_q4_0_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,           mul_mv_q4_1_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,           mul_mv_q5_0_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,           mul_mv_q5_1_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,           mul_mv_q8_0_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,           mul_mv_q2_K_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,           mul_mv_q3_K_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,           mul_mv_q4_K_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,           mul_mv_q5_K_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,           mul_mv_q6_K_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,        mul_mv_iq2_xxs_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,         mul_mv_iq2_xs_f32,      ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,        mul_mv_iq3_xxs_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,          mul_mv_iq3_s_f32,       ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,          mul_mv_iq2_s_f32,       ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,          mul_mv_iq1_s_f32,       ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,          mul_mv_iq1_m_f32,       ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,         mul_mv_iq4_nl_f32,      ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,         mul_mv_iq4_xs_f32,      ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,         mul_mv_id_f32_f32,      ctx->support_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,         mul_mv_id_f16_f16,      ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,         mul_mv_id_f16_f32,      ctx->support_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,    mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,      mul_mv_id_f16_f32_l4,   ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,        mul_mv_id_q4_0_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,        mul_mv_id_q4_1_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,        mul_mv_id_q5_0_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,        mul_mv_id_q5_1_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,        mul_mv_id_q8_0_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,        mul_mv_id_q2_K_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,        mul_mv_id_q3_K_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,        mul_mv_id_q4_K_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,        mul_mv_id_q5_K_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,        mul_mv_id_q6_K_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,     mul_mv_id_iq2_xxs_f32,  ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,      mul_mv_id_iq2_xs_f32,   ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,     mul_mv_id_iq3_xxs_f32,  ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,       mul_mv_id_iq3_s_f32,    ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,       mul_mv_id_iq2_s_f32,    ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,       mul_mv_id_iq1_s_f32,    ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,       mul_mv_id_iq1_m_f32,    ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,      mul_mv_id_iq4_nl_f32,   ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,      mul_mv_id_iq4_xs_f32,   ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,            mul_mm_f32_f32,         ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,            mul_mm_f16_f32,         ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,           mul_mm_q4_0_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,           mul_mm_q4_1_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,           mul_mm_q5_0_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,           mul_mm_q5_1_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,           mul_mm_q8_0_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,           mul_mm_q2_K_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,           mul_mm_q3_K_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,           mul_mm_q4_K_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,           mul_mm_q5_K_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,           mul_mm_q6_K_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,        mul_mm_iq2_xxs_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,         mul_mm_iq2_xs_f32,      ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,        mul_mm_iq3_xxs_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,          mul_mm_iq3_s_f32,       ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,          mul_mm_iq2_s_f32,       ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,          mul_mm_iq1_s_f32,       ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,          mul_mm_iq1_m_f32,       ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,         mul_mm_iq4_nl_f32,      ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,         mul_mm_iq4_xs_f32,      ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,         mul_mm_id_f32_f32,      ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,         mul_mm_id_f16_f32,      ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,        mul_mm_id_q4_0_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,        mul_mm_id_q4_1_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,        mul_mm_id_q5_0_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,        mul_mm_id_q5_1_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,        mul_mm_id_q8_0_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,        mul_mm_id_q2_K_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,        mul_mm_id_q3_K_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,        mul_mm_id_q4_K_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,        mul_mm_id_q5_K_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,        mul_mm_id_q6_K_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,     mul_mm_id_iq2_xxs_f32,  ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,      mul_mm_id_iq2_xs_f32,   ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,     mul_mm_id_iq3_xxs_f32,  ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,       mul_mm_id_iq3_s_f32,    ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,       mul_mm_id_iq2_s_f32,    ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,       mul_mm_id_iq1_s_f32,    ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,       mul_mm_id_iq1_m_f32,    ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,      mul_mm_id_iq4_nl_f32,   ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,      mul_mm_id_iq4_xs_f32,   ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32,                  rope_f32,               true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16,                  rope_f16,               true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32,                 alibi_f32,              true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                im2col_f16,             true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                im2col_f32,             true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,               upscale_f32,            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                   pad_f32,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,    timestep_embedding_f32, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                arange_f32,             true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,       argsort_f32_i32_asc,    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,      argsort_f32_i32_desc,   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,            leaky_relu_f32,         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,               cpy_f32_f16,            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,               cpy_f32_f32,            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,              cpy_f32_q8_0,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,              cpy_f32_q4_0,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,              cpy_f32_q4_1,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,              cpy_f32_q5_0,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,              cpy_f32_q5_1,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,            cpy_f32_iq4_nl,         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,               cpy_f16_f16,            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,               cpy_f16_f32,            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                    concat,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                       sqr,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                  sum_rows,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                           add,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,                       add_row,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,                           mul,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,                       mul_row,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                           div,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,                       div_row,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE,                         scale,                          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,                       scale_4,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP,                         clamp,                          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH,                          tanh,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,                          relu,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,                          gelu,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4,                        gelu_4,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK,                    gelu_quick,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,                  gelu_quick_4,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                          silu,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,                        silu_4,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,                  soft_max_f16,                   ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,                soft_max_f16_4,                 ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,                  soft_max_f32,                   ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,                soft_max_f32_4,                 ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,                 diag_mask_inf,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,               diag_mask_inf_8,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,                  get_rows_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,                  get_rows_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,                 get_rows_q4_0,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,                 get_rows_q4_1,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,                 get_rows_q5_0,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,                 get_rows_q5_1,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,                 get_rows_q8_0,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,                 get_rows_q2_K,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,                 get_rows_q3_K,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,                 get_rows_q4_K,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,                 get_rows_q5_K,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,                 get_rows_q6_K,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,              get_rows_iq2_xxs,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,               get_rows_iq2_xs,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,              get_rows_iq3_xxs,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,                get_rows_iq3_s,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,                get_rows_iq2_s,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,                get_rows_iq1_s,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,                get_rows_iq1_m,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,               get_rows_iq4_nl,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,               get_rows_iq4_xs,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,                  get_rows_i32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                      rms_norm,                       ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                    group_norm,                     ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                          norm,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                mul_mv_f32_f32,                 ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,                mul_mv_f16_f16,                 ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,                mul_mv_f16_f32,                 ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,           mul_mv_f16_f32_1row,            ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,             mul_mv_f16_f32_l4,              ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,               mul_mv_q4_0_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,               mul_mv_q4_1_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,               mul_mv_q5_0_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,               mul_mv_q5_1_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,               mul_mv_q8_0_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,               mul_mv_q2_K_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,               mul_mv_q3_K_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,               mul_mv_q4_K_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,               mul_mv_q5_K_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,               mul_mv_q6_K_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,            mul_mv_iq2_xxs_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,             mul_mv_iq2_xs_f32,              ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,            mul_mv_iq3_xxs_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,              mul_mv_iq3_s_f32,               ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,              mul_mv_iq2_s_f32,               ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,              mul_mv_iq1_s_f32,               ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,              mul_mv_iq1_m_f32,               ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,             mul_mv_iq4_nl_f32,              ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,             mul_mv_iq4_xs_f32,              ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,             mul_mv_id_f32_f32,              ctx->support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,             mul_mv_id_f16_f16,              ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,             mul_mv_id_f16_f32,              ctx->support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,        mul_mv_id_f16_f32_1row,         ctx->support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,          mul_mv_id_f16_f32_l4,           ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,            mul_mv_id_q4_0_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,            mul_mv_id_q4_1_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,            mul_mv_id_q5_0_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,            mul_mv_id_q5_1_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,            mul_mv_id_q8_0_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,            mul_mv_id_q2_K_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,            mul_mv_id_q3_K_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,            mul_mv_id_q4_K_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,            mul_mv_id_q5_K_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,            mul_mv_id_q6_K_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,         mul_mv_id_iq2_xxs_f32,          ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,          mul_mv_id_iq2_xs_f32,           ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,         mul_mv_id_iq3_xxs_f32,          ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,           mul_mv_id_iq3_s_f32,            ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,           mul_mv_id_iq2_s_f32,            ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,           mul_mv_id_iq1_s_f32,            ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,           mul_mv_id_iq1_m_f32,            ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,          mul_mv_id_iq4_nl_f32,           ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,          mul_mv_id_iq4_xs_f32,           ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,                mul_mm_f32_f32,                 ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,                mul_mm_f16_f32,                 ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,               mul_mm_q4_0_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,               mul_mm_q4_1_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,               mul_mm_q5_0_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,               mul_mm_q5_1_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,               mul_mm_q8_0_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,               mul_mm_q2_K_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,               mul_mm_q3_K_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,               mul_mm_q4_K_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,               mul_mm_q5_K_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,               mul_mm_q6_K_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,            mul_mm_iq2_xxs_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,             mul_mm_iq2_xs_f32,              ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,            mul_mm_iq3_xxs_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,              mul_mm_iq3_s_f32,               ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,              mul_mm_iq2_s_f32,               ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,              mul_mm_iq1_s_f32,               ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,              mul_mm_iq1_m_f32,               ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,             mul_mm_iq4_nl_f32,              ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,             mul_mm_iq4_xs_f32,              ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,             mul_mm_id_f32_f32,              ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,             mul_mm_id_f16_f32,              ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,            mul_mm_id_q4_0_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,            mul_mm_id_q4_1_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,            mul_mm_id_q5_0_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,            mul_mm_id_q5_1_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,            mul_mm_id_q8_0_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,            mul_mm_id_q2_K_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,            mul_mm_id_q3_K_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,            mul_mm_id_q4_K_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,            mul_mm_id_q5_K_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,            mul_mm_id_q6_K_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,         mul_mm_id_iq2_xxs_f32,          ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,          mul_mm_id_iq2_xs_f32,           ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,         mul_mm_id_iq3_xxs_f32,          ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,           mul_mm_id_iq3_s_f32,            ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,           mul_mm_id_iq2_s_f32,            ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,           mul_mm_id_iq1_s_f32,            ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,           mul_mm_id_iq1_m_f32,            ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,          mul_mm_id_iq4_nl_f32,           ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,          mul_mm_id_iq4_xs_f32,           ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32,                      rope_f32,                       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16,                      rope_f16,                       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32,                     alibi_f32,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                    im2col_f16,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                    im2col_f32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                       pad_f32,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,        timestep_embedding_f32,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                    arange_f32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,           argsort_f32_i32_asc,            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,          argsort_f32_i32_desc,           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,                leaky_relu_f32,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,        flash_attn_ext_f16_h64,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,        flash_attn_ext_f16_h80,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,        flash_attn_ext_f16_h96,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                   cpy_f32_f16,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                   cpy_f32_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,                  cpy_f32_q8_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,                  cpy_f32_q4_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,                  cpy_f32_q4_1,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,                  cpy_f32_q5_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,                  cpy_f32_q5_1,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,                cpy_f32_iq4_nl,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                   cpy_f16_f16,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,                   cpy_f16_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                        concat,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                           sqr,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                      sum_rows,                       true);
     }
 
     [metal_library release];
@@ -743,6 +763,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
         case GGML_OP_LEAKY_RELU:
+        case GGML_OP_FLASH_ATTN_EXT:
             return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
@@ -1326,20 +1347,33 @@ static enum ggml_status ggml_metal_graph_compute(
                     } break;
                 case GGML_OP_SOFT_MAX:
                     {
+                        GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
+                        GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
+
                         int nth = 32; // SIMD width
 
                         id<MTLComputePipelineState> pipeline = nil;
 
+                        const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+
                         if (ne00%4 == 0) {
                             while (nth < ne00/4 && nth < 256) {
                                 nth *= 2;
                             }
-                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
+                            if (use_f16) {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
+                            } else {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
+                            }
                         } else {
                             while (nth < ne00 && nth < 1024) {
                                 nth *= 2;
                             }
-                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
+                            if (use_f16) {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
+                            } else {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
+                            }
                         }
 
                         float scale;
@@ -2503,6 +2537,161 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
+                case GGML_OP_FLASH_ATTN_EXT:
+                    {
+                        GGML_ASSERT(ne00 % 4 == 0);
+                        GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                        struct ggml_tensor * src3 = gf->nodes[i]->src[3];
+
+                        GGML_ASSERT(ggml_are_same_shape(src1, src2));
+                        GGML_ASSERT(src3);
+
+                        size_t offs_src3 = 0;
+
+                        id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+
+                        GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
+                        GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
+                                "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
+
+                        const int64_t  ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
+                        const int64_t  ne31 = src3 ? src3->ne[1] : 0;
+                        const int64_t  ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
+                        const int64_t  ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
+
+                        const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
+                        const uint64_t nb31 = src3 ? src3->nb[1] : 0;
+                        const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
+                        const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
+
+                        const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+
+                        float scale;
+                        memcpy(&scale, dst->op_params, sizeof(float));
+
+                        id<MTLComputePipelineState> pipeline = nil;
+
+                        bool use_vec_kernel = false;
+
+                        if (ne01 >= 4 || (ne00%128 != 0)) {
+                            switch (ne00) {
+                                case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
+                                case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
+                                case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
+                                case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
+                                case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
+                                case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                                default:
+                                          {
+                                              GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                              GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                                              GGML_ASSERT(false && "add template specialization for this size");
+                                          }
+                            }
+                        } else {
+                            use_vec_kernel = true;
+
+                            switch (ne00) {
+                                case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+                                case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                                default:
+                                          {
+                                              GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                              GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                                              GGML_ASSERT(false && "add template specialization for this size");
+                                          }
+                            }
+                        }
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                        [encoder setBuffer:id_src1 offset:offs_src1        atIndex:1];
+                        [encoder setBuffer:id_src2 offset:offs_src2        atIndex:2];
+                        [encoder setBuffer:id_src3 offset:offs_src3        atIndex:3];
+                        [encoder setBuffer:id_dst  offset:offs_dst         atIndex:4];
+                        [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:5];
+                        [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:6];
+                        [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:7];
+                        [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:8];
+                        [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:9];
+                        [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:10];
+                        [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:11];
+                        [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:12];
+                        [encoder setBytes:&ne10    length:sizeof( int64_t) atIndex:13];
+                        [encoder setBytes:&ne11    length:sizeof( int64_t) atIndex:14];
+                        [encoder setBytes:&ne12    length:sizeof( int64_t) atIndex:15];
+                        [encoder setBytes:&ne13    length:sizeof( int64_t) atIndex:16];
+                        [encoder setBytes:&nb10    length:sizeof(uint64_t) atIndex:17];
+                        [encoder setBytes:&nb11    length:sizeof(uint64_t) atIndex:18];
+                        [encoder setBytes:&nb12    length:sizeof(uint64_t) atIndex:19];
+                        [encoder setBytes:&nb13    length:sizeof(uint64_t) atIndex:20];
+                        [encoder setBytes:&ne31    length:sizeof( int64_t) atIndex:21];
+                        [encoder setBytes:&nb31    length:sizeof(uint64_t) atIndex:22];
+                        [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:23];
+                        [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:24];
+                        [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:25];
+                        [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:26];
+                        [encoder setBytes:&scale   length:sizeof(   float) atIndex:27];
+
+                        if (!use_vec_kernel) {
+                            // half8x8 kernel
+                            const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
+                            const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+                            GGML_ASSERT(nqptg <= 32);
+                            GGML_ASSERT(nqptg  % 8  == 0);
+                            GGML_ASSERT(ncpsg  % 32 == 0);
+
+                            int64_t nsgmax = 2;
+
+                            while (true) {
+                                const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
+                                if (smem > ctx->device.maxThreadgroupMemoryLength) {
+                                    break;
+                                }
+                                nsgmax *= 2;
+                            }
+                            nsgmax /= 2;
+
+                            // simdgroups per threadgroup (a.k.a. warps)
+                            const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
+
+                            const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+
+                            //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+                            GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                        } else {
+                            // half1x4 kernel
+                            const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
+                            const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+                            GGML_ASSERT(nqptg <= 32);
+                            GGML_ASSERT(nqptg  % 1  == 0);
+                            GGML_ASSERT(ncpsg  % 32 == 0);
+
+                            // simdgroups per threadgroup (a.k.a. warps)
+                            const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+
+                            int64_t nsg = 1;
+                            while (nsg <= nsgt) {
+                                nsg *= 2;
+                            }
+                            nsg /= 2;
+
+                            const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
+
+                            //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+                            GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                        }
+                    } break;
                 case GGML_OP_DUP:
                 case GGML_OP_CPY:
                 case GGML_OP_CONT:
@@ -2706,10 +2895,13 @@ GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backe
     UNUSED(buft);
 }
 
-static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
+static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
+#ifndef GGML_METAL_NDEBUG
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
     if (@available(macOS 10.12, iOS 16.0, *)) {
-        GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
+        GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)",
+                __func__,
+                size_aligned / 1024.0 / 1024.0,
                 device.currentAllocatedSize / 1024.0 / 1024.0,
                 device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
 
@@ -2719,10 +2911,15 @@ static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
             GGML_METAL_LOG_INFO("\n");
         }
     } else {
-        GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
+        GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
+                __func__,
+                size_aligned / 1024.0 / 1024.0,
+                device.currentAllocatedSize / 1024.0 / 1024.0);
     }
+#endif
 #endif
     UNUSED(device);
+    UNUSED(size_aligned);
 }
 
 GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -2756,8 +2953,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff
         return NULL;
     }
 
-    GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
-    ggml_backend_metal_log_allocated_size(device);
+    //ggml_backend_metal_log_allocated_size(device, size_aligned);
 
     return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
 }
@@ -2844,7 +3040,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
             return false;
         }
 
-        GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
+        ggml_backend_metal_log_allocated_size(device, size_aligned);
 
         ++ctx->n_buffers;
     } else {
@@ -2867,7 +3063,8 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
                 return false;
             }
 
-            GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i);
+            ggml_backend_metal_log_allocated_size(device, size_step_aligned);
+
             if (i + size_step < size) {
                 GGML_METAL_LOG_INFO("\n");
             }
@@ -2876,8 +3073,6 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
         }
     }
 
-    ggml_backend_metal_log_allocated_size(device);
-
     return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
 }
 
index 191880af17942e85a7900d8ebcc65fb26329a740..3d4276ae02b9e5e5b4e7d794fe1ab83e9d458ea7 100644 (file)
@@ -352,11 +352,12 @@ kernel void kernel_sum_rows(
     dst_row[0] = row_sum;
 }
 
+template<typename T>
 kernel void kernel_soft_max(
-        device const float * src0,
-        device const float * src1,
-        device const float * src2,
-        device       float * dst,
+        device const  char * src0,
+        device const  char * src1,
+        device const  char * src2,
+        device        char * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
@@ -375,10 +376,10 @@ kernel void kernel_soft_max(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float * psrc0 =         src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-    device const float * pmask = src1 != src0 ? src1                               + i01*ne00 : nullptr;
-    device const float * ppos  = src2 != src0 ? src2                                          : nullptr;
-    device       float * pdst  =         dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const     T * pmask = src1 != src0 ? (device const    T *) src1         + i01*ne00 : nullptr;
+    device const     T * ppos  = src2 != src0 ? (device const    T *) src2                    : nullptr;
+    device       float * pdst  = (device       float *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     float slope = 0.0f;
 
@@ -456,11 +457,12 @@ kernel void kernel_soft_max(
     }
 }
 
+template<typename T>
 kernel void kernel_soft_max_4(
-        device const float * src0,
-        device const float * src1,
-        device const float * src2,
-        device       float * dst,
+        device const  char * src0,
+        device const  char * src1,
+        device const  char * src2,
+        device        char * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
@@ -479,10 +481,10 @@ kernel void kernel_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float4 * psrc4 =                (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 +                                      i01*ne00) : nullptr;
-    device const float4 * ppos  = src2 != src0 ? (device const float4 *)(src2)                                                 : nullptr;
-    device       float4 * pdst4 =                (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+    device const      T * pmask = src1 != src0 ? (device const     T *) src1         + i01*ne00/4 : nullptr;
+    device const      T * ppos  = src2 != src0 ? (device const     T *) src2                      : nullptr;
+    device       float4 * pdst4 = (device       float4 *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
 
     float slope = 0.0f;
 
@@ -499,7 +501,7 @@ kernel void kernel_soft_max_4(
     float4 lmax4 = -INFINITY;
 
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
     }
 
     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -525,7 +527,7 @@ kernel void kernel_soft_max_4(
     // parallel sum
     float4 lsum4 = 0.0f;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
@@ -562,6 +564,14 @@ kernel void kernel_soft_max_4(
     }
 }
 
+typedef decltype(kernel_soft_max<float>)    kernel_soft_max_t;
+typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
+
+template [[host_name("kernel_soft_max_f16")]]   kernel kernel_soft_max_t   kernel_soft_max<half>;
+template [[host_name("kernel_soft_max_f32")]]   kernel kernel_soft_max_t   kernel_soft_max<float>;
+template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
+template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
+
 kernel void kernel_diag_mask_inf(
         device const float * src0,
         device       float * dst,
@@ -2084,6 +2094,632 @@ kernel void kernel_leaky_relu_f32(
     dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
 }
 
+typedef void (flash_attn_ext_f16_t)(
+        device const  char * q,
+        device const  char * k,
+        device const  char * v,
+        device const  char * mask,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant   int64_t & ne31,
+        constant  uint64_t & nb31,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant     float & scale,
+        threadgroup   half * shared,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        uint3    ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]);
+
+// ref: https://arxiv.org/pdf/2307.08691.pdf
+template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_f16(
+        device const  char * q,
+        device const  char * k,
+        device const  char * v,
+        device const  char * mask,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant   int64_t & ne31,
+        constant  uint64_t & nb31,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant     float & scale,
+        threadgroup   half * shared [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        uint3    ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const short iq3 = tgpig[2];
+    const short iq2 = tgpig[1];
+    const short iq1 = tgpig[0]*Q;
+
+    const short D4 = D/4;
+    const short D8 = D/8;
+    const short Q8 = Q/8;
+    const short NW = N_SIMDWIDTH;
+    const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+    const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
+    const short TF = T/2;        // shared memory size per query in (float)
+    const short T4 = T/4;        // shared memory size per query in (half4)
+
+    threadgroup half  * sq  = (threadgroup half  *) (shared +              0*D); // holds the query data
+    threadgroup half4 * sq4 = (threadgroup half4 *) (shared +              0*D); // same as above but in half4
+    threadgroup float * ss  = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+
+    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    simdgroup_half8x8 lo[D8];
+
+    // load heads from Q to shared memory
+    for (short j = sgitg; j < Q; j += nsg) {
+        device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+
+        for (short i = tiisg; i < D4; i += NW) {
+            if (iq1 + j < ne01) {
+                sq4[j*T4 + i] = (half4) q4[i];
+            } else {
+                sq4[j*T4 + i] = 0.0h;
+            }
+        }
+    }
+
+    // zero out lo
+    for (short i = 0; i < D8; ++i) {
+        lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
+    }
+
+    // zero out shared memory SH
+    for (short j = 0; j < Q; ++j) {
+        for (short i = tiisg; i < SH; i += NW) {
+            ss[j*TF + i] = 0.0f;
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        float S[Q] = { [0 ... Q-1] = 0.0h };
+        float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
+
+        // assume K and V are same shape
+        const short ne22 = ne12;
+        const short ne23 = ne13;
+
+        const uint nb21 = nb11;
+        const uint nb22 = nb12;
+        const uint nb23 = nb13;
+
+        // broadcast
+        const short rk2 = ne02/ne12;
+        const short rk3 = ne03/ne13;
+
+        const short rv2 = ne02/ne22;
+        const short rv3 = ne03/ne23;
+
+        // k indices
+        const short ik2 = iq2/rk2;
+        const short ik3 = iq3/rk3;
+
+        // v indices
+        const short iv2 = iq2/rv2;
+        const short iv3 = iq3/rv3;
+
+        // load the queries from shared memory into local memory
+        simdgroup_half8x8 mq[D8];
+
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_load(mq[i], sq + i*8, T);
+        }
+
+        // pointer to the mask
+        device const half * mp = (device const half *) (mask + iq1*nb31);
+
+        // prepare diagonal scale matrix
+        simdgroup_float8x8 mscale(scale);
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= ne11) {
+                break;
+            }
+
+            // Q*K^T
+            {
+                for (short cc = 0; cc < C/8; ++cc) {
+                    simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
+
+                    device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+                    for (short i = 0; i < D8; ++i) {
+                        simdgroup_half8x8 mk;
+                        simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+
+                        simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+                    }
+
+                    // mqk = mqk*scale + mask
+                    simdgroup_half8x8 mm;
+                    simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
+                    simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
+
+                    simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
+                }
+            }
+
+            // used to detect blocks full of -INF
+            float smax = -INFINITY;
+
+            // online softmax
+            {
+                float ms[Q];
+
+                for (short j = 0; j < Q; ++j) {
+                    const short p = tiisg;
+
+                    const float m = M[j];
+                    const float s = ss[j*TF + p];
+
+                    smax = simd_max(max(smax, s));
+                    M[j] = simd_max(max(M[j], s));
+
+                                ms[j] = exp(m - M[j]);
+                    const float vs    = exp(s - M[j]);
+
+                    S[j] = S[j]*ms[j] + simd_sum(vs);
+
+                    // the P matrix from the paper (Q rows, C columns)
+                    ss[j*TF + p] = vs;
+                }
+
+                // create a QxQ diagonal matrix for rescaling the output
+                if (tiisg < Q) {
+                    ss[tiisg*TF + C + tiisg] = ms[tiisg];
+                }
+            }
+
+            // skip -INF blocks
+            if (smax == -INFINITY) {
+                continue;
+            }
+
+            // O = diag(ms)*O
+            {
+                simdgroup_float8x8 mm;
+                simdgroup_load(mm, ss + C, TF, 0, false);
+
+                for (short i = 0; i < D8; ++i) {
+                    simdgroup_multiply(lo[i], mm, lo[i]);
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+                for (short cc = 0; cc < C/8; ++cc) {
+                    device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+                    for (short i = 0; i < D8; ++i) {
+                        simdgroup_half8x8 mk;
+                        simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
+
+                        simdgroup_float8x8 mv;
+                        simdgroup_load(mv, ss + 8*cc, TF, 0, false);
+
+                        simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
+                    }
+                }
+            }
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        for (short j = 0; j < Q; ++j) {
+            if (tiisg == 0) {
+                ss[j*TF + 0] = S[j];
+                ss[j*TF + 1] = M[j];
+            }
+        }
+    }
+
+    // reduce the warps sequentially
+    for (short sg = 1; sg < nsg; ++sg) {
+        float S = { 0.0h };
+        float M = { -FLT_MAX/2 };
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // each simdgroup stores its output to shared memory, reusing sq
+        if (sgitg == sg) {
+            for (short i = 0; i < D8; ++i) {
+                simdgroup_store(lo[i], sq + i*8, T, 0, false);
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // the first simdgroup accumulates the results from the other simdgroups
+        if (sgitg == 0) {
+            for (short j = 0; j < Q; ++j) {
+                const float S0 = ss[j*TF +         0];
+                const float S1 = ss[j*TF + sg*SH + 0];
+
+                const float M0 = ss[j*TF +         1];
+                const float M1 = ss[j*TF + sg*SH + 1];
+
+                M = max(M0, M1);
+
+                const float ms0 = exp(M0 - M);
+                const float ms1 = exp(M1 - M);
+
+                S = S0*ms0 + S1*ms1;
+
+                if (tiisg == 0) {
+                    ss[j*TF + 0] = S;
+                    ss[j*TF + 1] = M;
+
+                    ss[j*TF + C + j        ] = ms0;
+                    ss[j*TF + C + j + sg*SH] = ms1;
+                }
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            {
+                simdgroup_half8x8 t;
+                simdgroup_float8x8 ms0;
+                simdgroup_float8x8 ms1;
+
+                simdgroup_load(ms0, ss + C,         TF, 0, false);
+                simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
+
+                for (short i = 0; i < D8; ++i) {
+                    simdgroup_load    (t, sq + i*8, T, 0, false);
+                    simdgroup_multiply(t, ms1, t);
+
+                    simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
+                }
+            }
+        }
+    }
+
+    // store result to shared memory (reuse sq)
+    if (sgitg == 0) {
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_store(lo[i], sq + i*8, T, 0, false);
+        }
+    }
+
+    device float4 * dst4 = (device float4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
+            const float S = ss[j*TF + 0];
+
+            for (short i = tiisg; i < D4; i += NW) {
+                dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
+            }
+        }
+    }
+}
+
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
+
+template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec_f16(
+        device const  char * q,
+        device const  char * k,
+        device const  char * v,
+        device const  char * mask,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant   int64_t & ne31,
+        constant  uint64_t & nb31,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant     float & scale,
+        threadgroup   half * shared [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        uint3    ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const short iq3 = tgpig[2];
+    const short iq2 = tgpig[1];
+    const short iq1 = tgpig[0];
+
+    const short D4 = D/4;
+    const short NW = N_SIMDWIDTH;
+    const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+    const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
+
+  //threadgroup half   * sq  = (threadgroup half   *) (shared +              0*D); // holds the query data
+    threadgroup half4  * sq4 = (threadgroup half4  *) (shared +              0*D); // same as above but in half4
+    threadgroup float  * ss  = (threadgroup float  *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+    threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
+    threadgroup half4  * sr4 = (threadgroup half4  *) (shared +   sgitg*D  + 1*T); // scratch buffer for the results
+
+    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    half4 lo[D4/NW];
+
+    // load heads from Q to shared memory
+    device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
+
+    for (short i = tiisg; i < D4; i += NW) {
+        if (iq1 < ne01) {
+            sq4[i] = (half4) q4[i];
+        } else {
+            sq4[i] = 0.0h;
+        }
+    }
+
+    // zero out lo
+    for (short i = tiisg; i < D4; i += NW) {
+        lo[i/NW] = 0.0h;
+    }
+
+    // zero out shared memory SH
+    for (short i = tiisg; i < SH/4; i += NW) {
+        ss4[i] = 0.0h;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        float S = { 0.0h };
+        float M = { -FLT_MAX/2 };
+
+        // assume K and V are same shape
+        const short ne22 = ne12;
+        const short ne23 = ne13;
+
+        const uint nb21 = nb11;
+        const uint nb22 = nb12;
+        const uint nb23 = nb13;
+
+        // broadcast
+        const short rk2 = ne02/ne12;
+        const short rk3 = ne03/ne13;
+
+        const short rv2 = ne02/ne22;
+        const short rv3 = ne03/ne23;
+
+        // k indices
+        const short ik2 = iq2 / rk2;
+        const short ik3 = iq3 / rk3;
+
+        // v indices
+        const short iv2 = iq2 / rv2;
+        const short iv3 = iq3 / rv3;
+
+        // load the queries from shared memory into local memory
+        half4 mq[D4];
+
+        for (short ii = 0; ii < D4; ii += NW) {
+            short i = ii + tiisg;
+            mq[i] = sq4[i];
+        }
+
+        // pointer to the mask
+        device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= ne11) {
+                break;
+            }
+
+            // Q*K^T
+            {
+#pragma unroll
+                for (short cc = 0; cc < C/4; ++cc) {
+                    float4 mqk = { 0.0h };
+
+                    device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+#pragma unroll
+                    for (short ii = 0; ii < D4; ii += NW) {
+                        const short i = ii + tiisg;
+
+                        half4x4 mk;
+                        mk[0] = pk4[i + 0*(nb11/8)];
+                        mk[1] = pk4[i + 1*(nb11/8)];
+                        mk[2] = pk4[i + 2*(nb11/8)];
+                        mk[3] = pk4[i + 3*(nb11/8)];
+
+                        mqk += (float4) (mq[i] * mk);
+                    }
+
+                    // reduce the results from the threads in the simdgroup
+                    mqk += simd_shuffle_down(mqk, 16);
+                    mqk += simd_shuffle_down(mqk,  8);
+                    mqk += simd_shuffle_down(mqk,  4);
+                    mqk += simd_shuffle_down(mqk,  2);
+                    mqk += simd_shuffle_down(mqk,  1);
+
+                    // mqk = mqk*scale + mask
+                    if (tiisg == 0) {
+                        float4 mm = (float4) mp4[ic/4 + cc];
+                        mqk = mqk*scale + mm;
+
+                        ss4[cc] = mqk;
+                    }
+                }
+            }
+
+            // online softmax
+            {
+                const short p = tiisg;
+
+                const float m = M;
+                const float s = ss[p];
+
+                M = simd_max(max(M, s));
+
+                const float ms = exp(m - M);
+                const float vs = exp(s - M);
+
+                S = S*ms + simd_sum(vs);
+
+                // the P matrix from the paper (Q rows, C columns)
+                ss[p] = vs;
+
+                // O = diag(ms)*O
+#pragma unroll
+                for (short ii = 0; ii < D4; ii += NW) {
+                    const short i = ii + tiisg;
+                    lo[i/NW] *= ms;
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+#pragma unroll
+                for (short cc = 0; cc < C/4; ++cc) {
+                    device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+#pragma unroll
+                    for (short ii = 0; ii < D4; ii += NW) {
+                        const short i = ii + tiisg;
+
+                        lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
+                        lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
+                        lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
+                        lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
+                    }
+                }
+            }
+
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        if (tiisg == 0) {
+            ss[0] = S;
+            ss[1] = M;
+        }
+    }
+
+    // store results to shared memory
+    for (short ii = 0; ii < D4; ii += NW) {
+        short i = ii + tiisg;
+        sr4[i] = lo[ii/NW];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // parallel reduce
+    for (short r = nsg/2; r > 0; r >>= 1) {
+        if (sgitg < r) {
+            const float S0 = ss[       0];
+            const float S1 = ss[r*SH + 0];
+
+            const float M0 = ss[       1];
+            const float M1 = ss[r*SH + 1];
+
+            const float M = max(M0, M1);
+
+            const float ms0 = exp(M0 - M);
+            const float ms1 = exp(M1 - M);
+
+            const float S = S0*ms0 + S1*ms1;
+
+            if (tiisg == 0) {
+                ss[0] = S;
+                ss[1] = M;
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            for (short ii = 0; ii < D4; ii += NW) {
+                short i = ii + tiisg;
+                sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    device float4 * dst4 = (device float4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        const float S = ss[0];
+
+        for (short ii = 0; ii < D4; ii += NW) {
+            short i = ii + tiisg;
+            dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
+        }
+    }
+}
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+
 kernel void kernel_cpy_f16_f16(
         device  const half * src0,
         device        half * dst,
index 2b76b3ebd64f749726de80f6d76a81bb0ec520a8..57fe4ea3d4ac25d2cf65511bffcc03b6c6f58901 100644 (file)
@@ -14744,7 +14744,12 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
+    const ggml_tensor * src2 = dst->src[2];
+
+#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
+#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
 
     const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
@@ -14760,7 +14765,6 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
     float * src2_dd = nullptr;
     sycl_pool_alloc<float> src2_f;
 
-    ggml_tensor * src2 = dst->src[2];
     const bool use_src2 = src2 != nullptr;
 
     if (use_src2) {
index 1736ab7361c273443d1c6af4435b23cd2d2a678a..f712cdd5a900eb454625d4c1a3a6a1c27cdb89e1 100644 (file)
@@ -3178,6 +3178,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
         return nullptr;
     case GGML_OP_SOFT_MAX:
+#pragma message("TODO: add ggml_vk_soft_max() F16 src1 and src2 support")
+#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
+        GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32);
+        GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
+
         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_soft_max_f32;
         }
diff --git a/ggml.c b/ggml.c
index cb273061c5c535a4756c065dbdf75e9a8573dc36..74ecd592791679c68939b0439c6b3be53d966d9a 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -951,7 +951,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
     #define GGML_F16_VEC_ZERO           GGML_F16x8_ZERO
     #define GGML_F16_VEC_SET1           GGML_F16x8_SET1
     #define GGML_F16_VEC_LOAD(p, i)     GGML_F16x8_LOAD(p)
-    #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
+    #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), r[i])
     #define GGML_F16_VEC_FMA            GGML_F16x8_FMA
     #define GGML_F16_VEC_ADD            GGML_F16x8_ADD
     #define GGML_F16_VEC_MUL            GGML_F16x8_MUL
@@ -977,7 +977,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
     #define GGML_F16_VEC_ZERO           GGML_F32Cx4_ZERO
     #define GGML_F16_VEC_SET1           GGML_F32Cx4_SET1
     #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx4_LOAD(p)
-    #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
+    #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
     #define GGML_F16_VEC_FMA            GGML_F32Cx4_FMA
     #define GGML_F16_VEC_ADD            GGML_F32Cx4_ADD
     #define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL
@@ -1046,7 +1046,7 @@ do {                                                                  \
 
 // unlike  _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
 // so F16C guard isn't required
-#define GGML_F32Cx16_LOAD(x)     _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x)))
+#define GGML_F32Cx16_LOAD(x)     _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
 #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))
 
 #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
@@ -1144,7 +1144,7 @@ do {                                                              \
 
 #if defined(__F16C__)
 // the  _mm256_cvt intrinsics require F16C
-#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
+#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
 #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
 #else
 static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
@@ -1662,6 +1662,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
 #endif
 }
 
+inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#endif
+}
+
 // xs and vs are byte strides of x and v
 inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
 
@@ -1746,6 +1777,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
 #endif
 }
 
+inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
+
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+    }
+#endif
+}
+
 inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s);   }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
 inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
@@ -2000,6 +2060,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "LEAKY_RELU",
 
     "FLASH_ATTN",
+    "FLASH_ATTN_EXT",
     "FLASH_FF",
     "FLASH_ATTN_BACK",
     "SSM_CONV",
@@ -2026,7 +2087,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2090,6 +2151,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "leaky_relu(x)",
 
     "flash_attn(x)",
+    "flash_attn_ext(x)",
     "flash_ff(x)",
     "flash_attn_back(x)",
     "ssm_conv(x)",
@@ -2116,7 +2178,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4559,6 +4621,8 @@ struct ggml_tensor * ggml_mul_mat(
 void ggml_mul_mat_set_prec(
         struct ggml_tensor * a,
         enum ggml_prec       prec) {
+    GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
+
     const int32_t prec_i32 = (int32_t) prec;
 
     ggml_set_op_params_i32(a, 0, prec_i32);
@@ -5397,17 +5461,23 @@ static struct ggml_tensor * ggml_soft_max_impl(
     GGML_ASSERT(ggml_is_contiguous(a));
 
     if (mask) {
+        GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
         GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(ggml_is_matrix(mask));
-        GGML_ASSERT(ggml_can_repeat_rows(mask, a));
+        GGML_ASSERT(mask->ne[0] == a->ne[0]);
+        GGML_ASSERT(mask->ne[1] >= a->ne[1]);
     }
 
     if (pos) {
         GGML_ASSERT(ggml_is_vector(pos));
-        GGML_ASSERT(pos->type == GGML_TYPE_F32);
+        GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
         GGML_ASSERT(pos->ne[0] == a->ne[0]);
     }
 
+    if (pos && mask) {
+        GGML_ASSERT(pos->type == mask->type);
+    }
+
     if (max_bias > 0.0f) {
         GGML_ASSERT(pos);
     }
@@ -6216,6 +6286,59 @@ struct ggml_tensor * ggml_flash_attn(
     return result;
 }
 
+// ggml_flash_attn_ext
+
+struct ggml_tensor * ggml_flash_attn_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * mask,
+        float                 scale) {
+    GGML_ASSERT(ggml_can_mul_mat(k, q));
+    // TODO: check if vT can be multiplied by (k*qT)
+    if (mask) {
+        GGML_ASSERT(ggml_is_contiguous(mask));
+        GGML_ASSERT(mask->ne[2] == 1);
+        GGML_ASSERT(mask->ne[3] == 1);
+        GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
+                "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
+        //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
+    }
+
+    bool is_node = false;
+
+    if (q->grad || k->grad || v->grad) {
+        is_node = true;
+    }
+
+    // permute(0, 2, 1, 3)
+    int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    float params[] = { scale };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op   = GGML_OP_FLASH_ATTN_EXT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = q;
+    result->src[1] = k;
+    result->src[2] = v;
+    result->src[3] = mask;
+
+    return result;
+}
+
+void ggml_flash_attn_ext_set_prec(
+        struct ggml_tensor * a,
+        enum ggml_prec       prec) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int32_t prec_i32 = (int32_t) prec;
+
+    ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
+}
+
 // ggml_flash_ff
 
 struct ggml_tensor * ggml_flash_ff(
@@ -12255,7 +12378,7 @@ static void ggml_compute_forward_soft_max_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    const int64_t ne11 = src1 ? src1->ne[1] : 1;
+    //const int64_t ne11 = src1 ? src1->ne[1] : 1;
 
     // TODO: is this supposed to be ceil instead of floor?
     //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -12278,19 +12401,31 @@ static void ggml_compute_forward_soft_max_f32(
     float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
 
     // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
-    float * pos = src2 ? (float *) src2->data : src0->data;
+    ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
+    float       * pos_f32 = src2 ? (float       *) src2->data : src0->data;
+
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
         float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
         float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
 
         // broadcast the mask across rows
-        float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
+        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
 
         ggml_vec_cpy_f32  (nc, wp, sp);
         ggml_vec_scale_f32(nc, wp, scale);
-        if (mp) {
-            ggml_vec_acc_f32(nc, wp, mp);
+        if (mp_f32) {
+            if (use_f16) {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
+                }
+            } else {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += mp_f32[i];
+                }
+            }
         }
 
         // ALiBi bias
@@ -12298,8 +12433,14 @@ static void ggml_compute_forward_soft_max_f32(
             const uint32_t h  = (i1/ne01)%ne02; // head
             const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
 
-            for (int i = 0; i < nc; i++) {
-                wp[i] = wp[i] + slope*pos[i];
+            if (use_f16) {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
+                }
+            } else {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*pos_f32[i];
+                }
             }
         }
 
@@ -14569,6 +14710,198 @@ static void ggml_compute_forward_flash_attn(
     }
 }
 
+// ggml_compute_forward_flash_attn_ext
+
+static void ggml_compute_forward_flash_attn_ext_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * mask,
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+
+    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne2 == N);
+
+    GGML_ASSERT(nbq0 == sizeof(float));
+    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev0 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nev0 == D);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // broadcast factors
+    const int64_t rk2 = neq2/nek2;
+    const int64_t rk3 = neq3/nek3;
+
+    const int64_t rv2 = neq2/nev2;
+    const int64_t rv3 = neq3/nev3;
+
+    if (params->type == GGML_TASK_TYPE_INIT) {
+        return;
+    }
+
+    if (params->type == GGML_TASK_TYPE_FINALIZE) {
+        return;
+    }
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float scale = 1.0f;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+
+    // loop over n_batch and n_head
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2*neq1);
+        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+        float S = 0.0f;
+        float M = -INFINITY;
+
+        float       * V32 = (float       *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
+        ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
+        ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
+
+        memset(V16, 0, D*sizeof(ggml_fp16_t));
+
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+
+        // k indices
+        const int ik3 = iq3 / rk3;
+        const int ik2 = iq2 / rk2;
+
+        // v indices
+        const int iv3 = iq3 / rv3;
+        const int iv2 = iq2 / rv2;
+
+        // online softmax / attention
+        // loop over n_kv and n_head_kv
+        // ref: https://arxiv.org/pdf/2112.05682.pdf
+        for (int64_t ic = 0; ic < nek1; ++ic) {
+            const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            if (mv == -INFINITY) {
+                continue;
+            }
+
+            float s;
+
+            // convert Q to F16 in V32
+            {
+                const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+
+                for (int64_t d = 0; d < D; ++d) {
+                    Q16[d] = GGML_FP32_TO_FP16(pq[d]);
+                }
+            }
+
+            ggml_vec_dot_f16(D,
+                    &s, 0,
+                    (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+                    Q16, 0, 1);
+
+            s = s*scale + mv;
+
+            const float Mold = M;
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                M = s;
+                ms = expf(Mold - M);
+
+                // V = V*expf(Mold - M)
+                ggml_vec_scale_f16(D, V16, ms);
+            } else {
+                vs = expf(s - M);
+            }
+
+            const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+
+            // V += v*expf(s - M)
+            ggml_vec_mad_f16(D, V16, v16, vs);
+
+            S = S*ms + vs;
+        }
+
+        // V /= S
+        for (int64_t d = 0; d < D; ++d) {
+            V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
+        }
+
+        // dst indices
+        const int i1 = iq1;
+        const int i2 = iq2;
+        const int i3 = iq3;
+
+        // original
+        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+
+        // permute(0, 2, 1, 3)
+        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
+    }
+}
+
+static void ggml_compute_forward_flash_attn_ext(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * mask,
+        struct ggml_tensor * dst) {
+    switch (dst->op_params[1]) {
+        case GGML_PREC_DEFAULT:
+        case GGML_PREC_F32:
+            {
+                // uses F32 accumulators
+                ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_flash_ff
 
 static void ggml_compute_forward_flash_ff_f16(
@@ -16376,6 +16709,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 const bool masked = t != 0;
                 ggml_compute_forward_flash_attn(params, masked, tensor);
             } break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+            } break;
         case GGML_OP_FLASH_FF:
             {
                 ggml_compute_forward_flash_ff(params, tensor);
@@ -17388,6 +17725,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_FLASH_ATTN:
+        case GGML_OP_FLASH_ATTN_EXT:
             {
                 struct ggml_tensor * flash_grad = NULL;
                 if (src0->grad || src1->grad || tensor->src[2]->grad) {
@@ -18160,6 +18498,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
                 n_tasks = n_threads;
             } break;
         case GGML_OP_FLASH_ATTN:
+        case GGML_OP_FLASH_ATTN_EXT:
             {
                 n_tasks = n_threads;
             } break;
@@ -18563,6 +18902,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                         cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
                     }
                 } break;
+            case GGML_OP_FLASH_ATTN_EXT:
+                {
+                    const int64_t ne00 = node->src[0]->ne[0]; // D
+
+                    cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
+                } break;
             case GGML_OP_FLASH_FF:
                 {
                     if (node->src[1]->type == GGML_TYPE_F32) {
diff --git a/ggml.h b/ggml.h
index 86e5a8dc57ad198237b275efb2213b7b1c085f82..a11795973ca3f7e54a2c407fe9d6d8fea450f645 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -475,6 +475,7 @@ extern "C" {
         GGML_OP_LEAKY_RELU,
 
         GGML_OP_FLASH_ATTN,
+        GGML_OP_FLASH_ATTN_EXT,
         GGML_OP_FLASH_FF,
         GGML_OP_FLASH_ATTN_BACK,
         GGML_OP_SSM_CONV,
@@ -1722,6 +1723,25 @@ extern "C" {
             struct ggml_tensor  * v,
             bool                  masked);
 
+#define GGML_KQ_MASK_PAD 32
+
+    // q:    [n_embd, n_batch,     n_head,    1]
+    // k:    [n_embd, n_kv,        n_head_kv, 1]
+    // v:    [n_embd, n_kv,        n_head_kv, 1] !! not transposed !!
+    // mask: [n_kv,   n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd, n_head,      n_batch,   1] !! permuted !!
+    GGML_API struct ggml_tensor * ggml_flash_attn_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * mask,
+            float                 scale);
+
+    GGML_API void ggml_flash_attn_ext_set_prec(
+            struct ggml_tensor * a,
+            enum ggml_prec       prec);
+
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
            struct ggml_context * ctx,
            struct ggml_tensor  * q,
index 72c10ffc202fcc97a1bb2be0bc4d43ecd86db302..18d6297ce1dfd33a998d4e8668165665b6aff51f 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
 #define LLAMA_MAX_NODES   8192
 #define LLAMA_MAX_EXPERTS 60
 
-
 //
 // logging
 //
@@ -1846,7 +1845,7 @@ struct llama_hparams {
     float f_logit_scale    = 0.0f;
 
     bool causal_attn = true;
-    bool need_kq_pos = false;
+    bool use_alibi   = false; // currently, we need KQ_pos data for ALiBi-based models
 
     enum llama_pooling_type      pooling_type            = LLAMA_POOLING_TYPE_NONE;
     enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
@@ -1936,6 +1935,7 @@ struct llama_cparams {
     bool embeddings;
     bool causal_attn;
     bool offload_kqv;
+    bool flash_attn;
 
     enum llama_pooling_type pooling_type;
 
@@ -2039,8 +2039,8 @@ struct llama_kv_cache {
     bool has_shift = false;
     bool do_defrag = false;
     bool do_copy   = false;
-    // with recurrent state models, a cell can hold the state for more than one past token
-    bool recurrent = false;
+    bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
+    bool v_trans   = true;  // the value tensor is transposed
 
     // Note: The value of head isn't only used to optimize searching
     // for a free KV slot. llama_decode_internal also uses it, so it
@@ -2339,11 +2339,14 @@ struct llama_context {
 
 static bool llama_kv_cache_init(
              struct llama_kv_cache & cache,
-                 const llama_model & model,
+               const llama_context * ctx,
                          ggml_type   type_k,
                          ggml_type   type_v,
                           uint32_t   kv_size,
                               bool   offload) {
+    const llama_model & model = ctx->model;
+    const llama_cparams & cparams = ctx->cparams;
+
     const struct llama_hparams & hparams = model.hparams;
 
     const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
@@ -2354,6 +2357,7 @@ static bool llama_kv_cache_init(
 
     // TODO: find a nicer way to add other recurrent model architectures
     cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+    cache.v_trans   = !cparams.flash_attn;
 
     // TODO: support mixed reccurent Transformer architectues
     // NOTE: (!a || b) is a logical implication (a -> b)
@@ -2566,6 +2570,10 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
     }
     cache.head = 0;
     cache.used = 0;
+
+    for (auto & buf : cache.bufs) {
+        ggml_backend_buffer_clear(buf, 0);
+    }
 }
 
 static bool llama_kv_cache_seq_rm(
@@ -4194,7 +4202,7 @@ static void llm_load_hparams(
     model.ftype = ml.ftype;
 
     if (hparams.f_max_alibi_bias > 0.0f) {
-        hparams.need_kq_pos = true;
+        hparams.use_alibi = true;
     }
 
     hparams.rope_type = llama_rope_type(&model);
@@ -6203,37 +6211,47 @@ static struct ggml_tensor * llm_build_inp_embd(
 static void llm_build_kv_store(
         struct ggml_context * ctx,
         const llama_hparams & hparams,
+        const llama_cparams & cparams,
        const llama_kv_cache & kv,
          struct ggml_cgraph * graph,
          struct ggml_tensor * k_cur,
          struct ggml_tensor * v_cur,
-                    int64_t   n_ctx,
                     int32_t   n_tokens,
                     int32_t   kv_head,
          const llm_build_cb & cb,
                     int64_t   il) {
+    const int64_t n_ctx = cparams.n_ctx;
+
     const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
     const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
 
     GGML_ASSERT(kv.size == n_ctx);
 
-    // compute the transposed [n_tokens, n_embd] V matrix
-    assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
-    struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur);
-    cb(v_cur_t, "v_cur_t", il);
-
     struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
             (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
     cb(k_cache_view, "k_cache_view", il);
 
-    struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
-            (  n_ctx)*ggml_element_size(kv.v_l[il]),
-            (kv_head)*ggml_element_size(kv.v_l[il]));
+    // note: storing RoPE-ed version of K in the KV cache
+    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
+
+    assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
+
+    struct ggml_tensor * v_cache_view = nullptr;
+
+    if (cparams.flash_attn) {
+        v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa,
+                (kv_head)*ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa));
+    } else {
+        // note: the V cache is transposed when not using flash attention
+        v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
+                (  n_ctx)*ggml_element_size(kv.v_l[il]),
+                (kv_head)*ggml_element_size(kv.v_l[il]));
+
+        v_cur = ggml_transpose(ctx, v_cur);
+    }
     cb(v_cache_view, "v_cache_view", il);
 
-    // important: storing RoPE-ed version of K in the KV cache!
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur,   k_cache_view));
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view));
+    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
 }
 
 static struct ggml_tensor * llm_build_norm(
@@ -6453,11 +6471,11 @@ static struct ggml_tensor * llm_build_moe_ffn(
     return moe_out;
 }
 
-// if max_alibi_bias > 0 then apply ALiBi
 static struct ggml_tensor * llm_build_kqv(
         struct ggml_context * ctx,
           const llama_model & model,
         const llama_hparams & hparams,
+        const llama_cparams & cparams,
        const llama_kv_cache & kv,
          struct ggml_cgraph * graph,
          struct ggml_tensor * wo,
@@ -6465,12 +6483,12 @@ static struct ggml_tensor * llm_build_kqv(
          struct ggml_tensor * q_cur,
          struct ggml_tensor * kq_mask,
          struct ggml_tensor * kq_pos,
-                    int64_t   n_ctx,
                     int32_t   n_tokens,
                     int32_t   n_kv,
                     float     kq_scale,
          const llm_build_cb & cb,
                     int       il) {
+    const int64_t n_ctx         = cparams.n_ctx;
     const int64_t n_head        = hparams.n_head;
     const int64_t n_head_kv     = hparams.n_head_kv;
     const int64_t n_embd_head_k = hparams.n_embd_head_k;
@@ -6488,71 +6506,99 @@ static struct ggml_tensor * llm_build_kqv(
                 0);
     cb(k, "k", il);
 
-    struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
-    cb(kq, "kq", il);
+    struct ggml_tensor * cur;
 
-    if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) {
-        // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
-        // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
-        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-    }
+    if (cparams.flash_attn) {
+        GGML_UNUSED(model);
+        GGML_UNUSED(n_ctx);
 
-    if (model.arch == LLM_ARCH_GROK) {
-        // need to do the following:
-        // multiply by attn_output_multiplyer of 0.08838834764831845
-        // and then :
-        // kq = 30 * tanh(kq / 30)
-        // before the softmax below
+        // note: if this assert triggers, then some check has failed earlier
+        //       the idea is to detect during context creation that ALiBi would be used and disable Flash Attention
+        GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention");
 
-        //try from phi2
-        //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+        // split cached v into n_head heads (not transposed)
+        struct ggml_tensor * v =
+            ggml_view_3d(ctx, kv.v_l[il],
+                    n_embd_head_v, n_kv, n_head_kv,
+                    ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa),
+                    ggml_row_size(kv.v_l[il]->type, n_embd_head_k),
+                    0);
+        cb(v, "v", il);
 
-        kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
-        kq = ggml_scale(ctx, kq, 30);
-    }
+        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
+
+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) {
+            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+        }
+
+        cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens);
+    } else {
+        struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
+        cb(kq, "kq", il);
+
+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) {
+            // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
+            // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
+            ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+        }
+
+        if (model.arch == LLM_ARCH_GROK) {
+            // need to do the following:
+            // multiply by attn_output_multiplyer of 0.08838834764831845
+            // and then :
+            // kq = 30 * tanh(kq / 30)
+            // before the softmax below
+
+            //try from phi2
+            //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+
+            kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
+            kq = ggml_scale(ctx, kq, 30);
+        }
 
 #if defined(GGML_USE_KOMPUTE)
 #pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute")
 #pragma message("      Falling back to ggml_alibi(). Will become an error in Mar 2024")
 #pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5488")
-    if (hparams.f_max_alibi_bias > 0.0f) {
-        kq = ggml_scale(ctx, kq, kq_scale);
-        cb(kq, "kq_scaled", il);
+        if (hparams.use_alibi) {
+            kq = ggml_scale(ctx, kq, kq_scale);
+            cb(kq, "kq_scaled", il);
 
-        kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias);
-        cb(kq, "kq_scaled_alibi", il);
+            kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias);
+            cb(kq, "kq_scaled_alibi", il);
 
-        kq = ggml_add(ctx, kq, kq_mask);
-        cb(kq, "kq_masked", il);
+            kq = ggml_add(ctx, kq, kq_mask);
+            cb(kq, "kq_masked", il);
 
-        kq = ggml_soft_max(ctx, kq);
-        cb(kq, "kq_soft_max", il);
-    } else
+            kq = ggml_soft_max(ctx, kq);
+            cb(kq, "kq_soft_max", il);
+        } else
 #endif
-    {
-        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias);
-        cb(kq, "kq_soft_max_ext", il);
-    }
+        {
+            kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias);
+            cb(kq, "kq_soft_max_ext", il);
+        }
 
-    GGML_ASSERT(kv.size == n_ctx);
+        GGML_ASSERT(kv.size == n_ctx);
 
-    // split cached v into n_head heads
-    struct ggml_tensor * v =
-        ggml_view_3d(ctx, kv.v_l[il],
-                n_kv, n_embd_head_v, n_head_kv,
-                ggml_element_size(kv.v_l[il])*n_ctx,
-                ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
-                0);
-    cb(v, "v", il);
+        // split cached v into n_head heads
+        struct ggml_tensor * v =
+            ggml_view_3d(ctx, kv.v_l[il],
+                    n_kv, n_embd_head_v, n_head_kv,
+                    ggml_element_size(kv.v_l[il])*n_ctx,
+                    ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
+                    0);
+        cb(v, "v", il);
 
-    struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
-    cb(kqv, "kqv", il);
+        struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
+        cb(kqv, "kqv", il);
 
-    struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
-    cb(kqv_merged, "kqv_merged", il);
+        struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
+        cb(kqv_merged, "kqv_merged", il);
 
-    struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
-    cb(cur, "kqv_merged_cont", il);
+        cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
+        cb(cur, "kqv_merged_cont", il);
+    }
 
     ggml_build_forward_expand(graph, cur);
 
@@ -6572,6 +6618,7 @@ static struct ggml_tensor * llm_build_kv(
         struct ggml_context * ctx,
           const llama_model & model,
         const llama_hparams & hparams,
+        const llama_cparams & cparams,
        const llama_kv_cache & kv,
          struct ggml_cgraph * graph,
          struct ggml_tensor * wo,
@@ -6581,7 +6628,6 @@ static struct ggml_tensor * llm_build_kv(
          struct ggml_tensor * q_cur,
          struct ggml_tensor * kq_mask,
          struct ggml_tensor * kq_pos,
-                    int64_t   n_ctx,
                     int32_t   n_tokens,
                     int32_t   kv_head,
                     int32_t   n_kv,
@@ -6595,12 +6641,12 @@ static struct ggml_tensor * llm_build_kv(
     ggml_build_forward_expand(graph, k_cur);
     ggml_build_forward_expand(graph, v_cur);
 
-    llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il);
+    llm_build_kv_store(ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il);
 
     struct ggml_tensor * cur;
 
-    cur  = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b,
-            q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il);
+    cur  = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b,
+            q_cur, kq_mask, kq_pos, n_tokens, n_kv, kq_scale, cb, il);
     cb(cur, "kqv_out", il);
 
     return cur;
@@ -6642,6 +6688,8 @@ struct llm_build_context {
     const int32_t kv_head;  // index of where we store new KV data in the cache
     const int32_t n_orig_ctx;
 
+    const bool flash_attn;
+
     const enum llama_pooling_type pooling_type;
     const enum llama_rope_type    rope_type;
 
@@ -6688,6 +6736,7 @@ struct llm_build_context {
         n_outputs        (worst_case ? n_tokens : lctx.n_outputs),
         kv_head          (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
         n_orig_ctx       (cparams.n_yarn_orig_ctx),
+        flash_attn       (cparams.flash_attn),
         pooling_type     (cparams.pooling_type),
         rope_type        (hparams.rope_type),
         cb               (cb),
@@ -6802,15 +6851,31 @@ struct llm_build_context {
                         ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
                         ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
 
-                ggml_tensor * view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
-                        nm, n_embd_v_gqa,
-                        ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
-                        ggml_row_size(kv_self.v_l[il]->type, i));
+                ggml_tensor * view_v_src;
+                ggml_tensor * view_v_dst;
 
-                ggml_tensor * view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
-                        nm, n_embd_v_gqa,
-                        ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
-                        ggml_row_size(kv_self.v_l[il]->type, id));
+                if (flash_attn) {
+                    // NOTE: the V cache is not transposed when using flash attention
+                    view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
+                            n_embd_v_gqa, nm,
+                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
+                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
+
+                    view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
+                            n_embd_v_gqa, nm,
+                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
+                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
+                } else {
+                    view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
+                            nm, n_embd_v_gqa,
+                            ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
+                            ggml_row_size(kv_self.v_l[il]->type, i));
+
+                    view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
+                            nm, n_embd_v_gqa,
+                            ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
+                            ggml_row_size(kv_self.v_l[il]->type, id));
+                }
 
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
@@ -6840,20 +6905,26 @@ struct llm_build_context {
 
     struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
         if (causal) {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens);
+            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         } else {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
+            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         }
         cb(lctx.inp_KQ_mask, "KQ_mask", -1);
         ggml_set_input(lctx.inp_KQ_mask);
-        return lctx.inp_KQ_mask;
+        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
     }
 
-    struct ggml_tensor * build_inp_KQ_pos() {
-        lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv);
+    struct ggml_tensor * build_inp_KQ_pos(bool causal = true) {
+        if (causal) {
+            lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv);
+        } else {
+            // TODO: this will be needed for ALiBi-based BERT models
+            //       https://github.com/ggerganov/llama.cpp/pull/6826
+            lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_tokens);
+        }
         cb(lctx.inp_KQ_pos, "KQ_pos", -1);
         ggml_set_input(lctx.inp_KQ_pos);
-        return lctx.inp_KQ_pos;
+        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16) : lctx.inp_KQ_pos;
     }
 
     struct ggml_tensor * build_inp_mean() {
@@ -6959,9 +7030,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -7099,9 +7170,9 @@ struct llm_build_context {
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -7206,9 +7277,9 @@ struct llm_build_context {
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -7326,9 +7397,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -7451,9 +7522,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -7603,9 +7674,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
-                                   model.layers[il].wo, NULL,
-                                   Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -7715,9 +7786,9 @@ struct llm_build_context {
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -7919,9 +7990,9 @@ struct llm_build_context {
                         );
                 cb(Vcur, "Vcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Q, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -8015,9 +8086,9 @@ struct llm_build_context {
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 cb(Qcur, "Qcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -8308,9 +8379,9 @@ struct llm_build_context {
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -8439,14 +8510,15 @@ struct llm_build_context {
                     Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
-                    cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                    cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                            model.layers[il].wo, model.layers[il].bo,
+                            Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 } else {
                     Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                    cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+
+                    cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                             model.layers[il].wo, model.layers[il].bo,
-                            Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                            Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 }
             }
 
@@ -8588,9 +8660,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -8706,9 +8778,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -8819,9 +8891,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -8933,9 +9005,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -9088,9 +9160,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -9205,9 +9277,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
-                    model.layers[il].wo, NULL,
-                    Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -9318,9 +9390,9 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
             struct ggml_tensor * sa_out = cur;
 
@@ -9421,9 +9493,9 @@ struct llm_build_context {
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -9528,9 +9600,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -9644,9 +9716,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -9761,9 +9833,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -9891,9 +9963,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -10012,9 +10084,9 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -10131,9 +10203,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -10421,9 +10493,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -10552,9 +10624,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, nullptr,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -10981,7 +11053,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (hparams.need_kq_pos) {
+    // ALiBi requires the KQ_pos tensor to provide the sequence position of each token in the batch
+    // this allows to process multiple sequences in parallel with ALiBi-based models
+    if (hparams.use_alibi) {
         const int64_t n_kv = kv_self.n;
 
         GGML_ASSERT(lctx.inp_KQ_pos);
@@ -11363,7 +11437,7 @@ static int llama_decode_internal(
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
                 // if we start defragmenting the cache, the benefit from this will be more important
-                kv_self.n = std::min(kv_self.size, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+                kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
                 //kv_self.n = llama_kv_cache_cell_max(kv_self);
             }
         }
@@ -11560,7 +11634,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     // each move requires 6*n_layer tensors (see build_defrag)
     //   - source view, destination view, copy operation
     //   - x2 for keys and values
-    const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer);
+    //const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer);
+    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
+    const uint32_t max_moves = (LLAMA_MAX_NODES - 2*n_layer)/(6*n_layer);
 
     // determine which KV cells to move where
     //
@@ -15167,6 +15243,7 @@ struct llama_context_params llama_context_default_params() {
         /*.logits_all                  =*/ false,
         /*.embeddings                  =*/ false,
         /*.offload_kqv                 =*/ true,
+        /*.flash_attn                  =*/ false,
         /*.abort_callback              =*/ nullptr,
         /*.abort_callback_data         =*/ nullptr,
     };
@@ -15333,6 +15410,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.defrag_thold     = params.defrag_thold;
     cparams.embeddings       = params.embeddings;
     cparams.offload_kqv      = params.offload_kqv;
+    cparams.flash_attn       = params.flash_attn;
     cparams.pooling_type     = params.pooling_type;
 
     cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
@@ -15340,12 +15418,20 @@ struct llama_context * llama_new_context_with_model(
     cparams.rope_freq_scale  = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
     // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32);
+    cparams.n_ctx            = GGML_PAD(cparams.n_ctx, 256);
 
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch          = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
-    cparams.n_ubatch         = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
+    // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
+    // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
+    // ref: https://github.com/ggerganov/llama.cpp/pull/5021
+    if (cparams.n_batch < GGML_KQ_MASK_PAD) {
+        LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
+        cparams.n_batch = GGML_KQ_MASK_PAD;
+    }
+
+    cparams.n_ubatch         = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
     cparams.n_yarn_orig_ctx  = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    :
                                hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
@@ -15377,6 +15463,23 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 
+    if (cparams.flash_attn && hparams.use_alibi) {
+        LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with ALiBi - forcing off\n", __func__);
+        cparams.flash_attn = false;
+    }
+
+    if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
+        cparams.flash_attn = false;
+    }
+
+#ifdef GGML_USE_HIPBLAS
+    if (cparams.flash_attn) {
+        LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with HIPBLAS builds - forcing off\n", __func__);
+        cparams.flash_attn = false;
+    }
+#endif
+
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }
@@ -15384,6 +15487,7 @@ struct llama_context * llama_new_context_with_model(
     LLAMA_LOG_INFO("%s: n_ctx      = %u\n",     __func__, cparams.n_ctx);
     LLAMA_LOG_INFO("%s: n_batch    = %u\n",     __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch   = %u\n",     __func__, cparams.n_ubatch);
+    LLAMA_LOG_INFO("%s: flash_attn = %d\n",     __func__, cparams.flash_attn);
     LLAMA_LOG_INFO("%s: freq_base  = %.1f\n",   __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n",     __func__, cparams.rope_freq_scale);
 
@@ -15512,7 +15616,7 @@ struct llama_context * llama_new_context_with_model(
         }
         ctx->backends.push_back(ctx->backend_cpu);
 
-        if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) {
+        if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
             LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
             llama_free(ctx);
             return nullptr;
@@ -16111,6 +16215,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) {
     const size_t s_kv_head         = sizeof(uint32_t);
     const size_t s_kv_size         = sizeof(uint32_t);
     const size_t s_kv_used         = sizeof(uint32_t);
+    const size_t s_v_trans         = sizeof(uint32_t);
     const size_t s_kv              = ctx->kv_self.total_size();
     const size_t s_kv_cell         = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id);
     const size_t s_kv_cells        = ctx->kv_self.size * s_kv_cell;
@@ -16128,10 +16233,14 @@ size_t llama_state_get_size(const struct llama_context * ctx) {
         + s_kv_head
         + s_kv_size
         + s_kv_used
+        + s_v_trans
         + s_kv
         + s_kv_cells
     );
 
+    // on session change it is very likely that the state size has changed - so we need to update this function
+    static_assert(LLAMA_SESSION_VERSION == 6, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?");
+
     return s_total;
 }
 
@@ -16277,11 +16386,13 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
         const uint32_t kv_size     = kv_self.size;
         const size_t   kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head;
         const uint32_t kv_used     = kv_self.used;
+        const uint32_t v_trans     = kv_self.v_trans ? 1 : 0;
 
         data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
         data_ctx->write(&kv_head,     sizeof(kv_head));
         data_ctx->write(&kv_size,     sizeof(kv_size));
         data_ctx->write(&kv_used,     sizeof(kv_used));
+        data_ctx->write(&v_trans,     sizeof(v_trans));
 
         if (kv_buf_size) {
             const size_t pre_kv_buf_size = data_ctx->get_size_written();
@@ -16294,7 +16405,7 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
                 ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size());
                 data_ctx->write(tmp_buf.data(), tmp_buf.size());
 
-                if (kv_self.recurrent) {
+                if (kv_self.recurrent || !kv_self.v_trans) {
                     // v is contiguous for recurrent models
                     // TODO: use other tensors for state models than k and v
                     const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
@@ -16427,11 +16538,15 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
         uint32_t kv_head;
         uint32_t kv_size;
         uint32_t kv_used;
+        uint32_t v_trans;
 
         memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
         memcpy(&kv_head,     inp, sizeof(kv_head));     inp += sizeof(kv_head);
         memcpy(&kv_size,     inp, sizeof(kv_size));     inp += sizeof(kv_size);
         memcpy(&kv_used,     inp, sizeof(kv_used));     inp += sizeof(kv_used);
+        memcpy(&v_trans,     inp, sizeof(v_trans));     inp += sizeof(v_trans);
+
+        GGML_ASSERT(kv_self.v_trans == (bool) v_trans); // incompatible V transposition
 
         if (kv_self.size != kv_size) {
             // the KV cache needs to be big enough to load all the KV cells from the saved state
@@ -16441,6 +16556,8 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
                 __func__, kv_head, kv_size, kv_self.size);
         }
 
+        llama_kv_cache_clear(ctx);
+
         if (kv_buf_size) {
             const size_t pre_kv_buf_size = inp - src;
 
@@ -16452,7 +16569,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
                 ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
                 inp += k_size;
 
-                if (kv_self.recurrent) {
+                if (kv_self.recurrent || !kv_self.v_trans) {
                     // v is contiguous for recurrent models
                     // TODO: use other tensors for state models than k and v
                     const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
@@ -16474,8 +16591,6 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
             GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size);
         }
 
-        llama_kv_cache_clear(ctx);
-
         ctx->kv_self.head = kv_head;
         ctx->kv_self.used = kv_used;
 
@@ -16735,28 +16850,49 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
         }
     }
 
-    // For the values, they are transposed, so we also need the element size and get the element ranges from each row
-    const uint32_t kv_size = kv_self.size;
-    for (int il = 0; il < (int)n_layer; ++il) {
-        // Write value type
-        const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-        data_ctx.write(&v_type_i, sizeof(v_type_i));
+    // TODO: simplify, reduce copy-paste
+    if (!kv_self.v_trans) {
+        for (int il = 0; il < (int)n_layer; ++il) {
+            // Write value type
+            const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+            data_ctx.write(&v_type_i, sizeof(v_type_i));
 
-        // Write element size
-        const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
-        data_ctx.write(&v_size_el, sizeof(v_size_el));
+            // Write row size of value
+            const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+            data_ctx.write(&v_size_row, sizeof(v_size_row));
 
-        // For each row, we get the element values of each cell
-        for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-            // Read each range of cells of v_size_el length each into tmp_buf and write out
+            // Read each range of cells of v_size length each into tmp_buf and write out
             for (const auto & range : cell_ranges) {
                 const size_t range_size = range.second - range.first;
-                const size_t src_offset = (range.first + j * kv_size) * v_size_el;
-                tmp_buf.resize(range_size * v_size_el);
-                ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
+                tmp_buf.resize(range_size * v_size_row);
+                ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row);
                 data_ctx.write(tmp_buf.data(), tmp_buf.size());
             }
         }
+    } else {
+        // For the values, they are transposed, so we also need the element size and get the element ranges from each row
+        const uint32_t kv_size = kv_self.size;
+        for (int il = 0; il < (int)n_layer; ++il) {
+            // Write value type
+            const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+            data_ctx.write(&v_type_i, sizeof(v_type_i));
+
+            // Write element size
+            const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
+            data_ctx.write(&v_size_el, sizeof(v_size_el));
+
+            // For each row, we get the element values of each cell
+            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                // Read each range of cells of v_size_el length each into tmp_buf and write out
+                for (const auto & range : cell_ranges) {
+                    const size_t range_size = range.second - range.first;
+                    const size_t src_offset = (range.first + j * kv_size) * v_size_el;
+                    tmp_buf.resize(range_size * v_size_el);
+                    ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
+                    data_ctx.write(tmp_buf.data(), tmp_buf.size());
+                }
+            }
+        }
     }
 
     return data_ctx.get_size_written();
@@ -16881,41 +17017,75 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
         }
     }
 
-    // For each layer, read the values for each cell (transposed)
-    for (int il = 0; il < (int)n_layer; ++il) {
-        // Read type of value
-        int32_t v_type_i_ref;
-        memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
-        inp += sizeof(v_type_i_ref);
-        const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-        if (v_type_i != v_type_i_ref) {
-            llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-            LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-            return 0;
-        }
+    // TODO: simplify, reduce copy-paste
+    if (!kv_self.v_trans) {
+        for (int il = 0; il < (int)n_layer; ++il) {
+            // Read type of value
+            int32_t v_type_i_ref;
+            memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
+            inp += sizeof(v_type_i_ref);
+            const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+            if (v_type_i != v_type_i_ref) {
+                llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
+                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                return 0;
+            }
 
-        // Read element size of value
-        size_t v_size_el_ref;
-        memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
-        inp += sizeof(v_size_el_ref);
-        const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
-        if (v_size_el != v_size_el_ref) {
-            llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-            LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il);
-            return 0;
-        }
+            // Read row size of value
+            size_t v_size_row_ref;
+            memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref));
+            inp += sizeof(v_size_row_ref);
+            const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+            if (v_size_row != v_size_row_ref) {
+                llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
+                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il);
+                return 0;
+            }
 
-        if (cell_count) {
-            // For each row in the transposed matrix, read the values for the whole cell range
-            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                const size_t dst_offset = (kv_head + j * kv_size) * v_size_el;
-                ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el);
-                inp += cell_count * v_size_el;
+            if (cell_count) {
+                // Read and set the values for the whole cell range
+                ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row);
+                inp += cell_count * v_size_row;
+            }
+        }
+    } else {
+        // For each layer, read the values for each cell (transposed)
+        for (int il = 0; il < (int)n_layer; ++il) {
+            // Read type of value
+            int32_t v_type_i_ref;
+            memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
+            inp += sizeof(v_type_i_ref);
+            const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+            if (v_type_i != v_type_i_ref) {
+                llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
+                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                return 0;
+            }
+
+            // Read element size of value
+            size_t v_size_el_ref;
+            memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
+            inp += sizeof(v_size_el_ref);
+            const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
+            if (v_size_el != v_size_el_ref) {
+                llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
+                LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il);
+                return 0;
+            }
+
+            if (cell_count) {
+                // For each row in the transposed matrix, read the values for the whole cell range
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    const size_t dst_offset = (kv_head + j * kv_size) * v_size_el;
+                    ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el);
+                    inp += cell_count * v_size_el;
+                }
             }
         }
     }
 
     const size_t nread = inp - src;
+
     return nread;
 }
 
diff --git a/llama.h b/llama.h
index 30835de5f9433e1140f3dae5781406d96b276de3..059d78f115c6dd1731a6b52de57951f35dee102b 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -40,7 +40,7 @@
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 5
+#define LLAMA_SESSION_VERSION 6
 
 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
 #define LLAMA_STATE_SEQ_VERSION 1
@@ -287,6 +287,7 @@ extern "C" {
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
+        bool flash_attn;  // whether to use flash attention
 
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
@@ -542,7 +543,7 @@ extern "C" {
     // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
     LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
 
-    // Clear the KV cache
+    // Clear the KV cache - both cell info is erased and KV data is zeroed
     LLAMA_API void llama_kv_cache_clear(
             struct llama_context * ctx);
 
index 02daad24b030aba5f5f08c7e02329f1dd84c8c4c..b27c1291e4088be86d97f742d3a5aa1ade2d24fa 100644 (file)
@@ -1090,6 +1090,12 @@ struct test_soft_max : public test_case {
         return VARS_TO_STR5(type, ne, mask, scale, max_bias);
     }
 
+    // the 1024 test with bias occasionally fails:
+    // SOFT_MAX(type=f32,ne=[1024,16,1,1],mask=1,scale=1.000000,max_bias=8.000000): [SOFT_MAX] NMSE = 0.000000103 > 0.000000100 FAIL
+    virtual double max_nmse_err() override {
+        return 1e-6;
+    }
+
     test_soft_max(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 10},
             bool mask = false,
@@ -1101,7 +1107,7 @@ struct test_soft_max : public test_case {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_tensor * mask = nullptr;
         if (this->mask) {
-            mask = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]);
+            mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
         }
         ggml_tensor * pos = nullptr;
         if (max_bias > 0.0f) {
@@ -1475,6 +1481,34 @@ struct test_leaky_relu : public test_case {
     }
 };
 
+// GGML_OP_FLASH_ATTN_EXT
+struct test_flash_attn_ext : public test_case {
+    const int64_t hs; // head size
+    const int64_t nh; // num heads
+    const int64_t kv; // kv size
+    const int64_t nb; // batch size
+
+    std::string vars() override {
+        return VARS_TO_STR4(hs, nh, kv, nb);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
+        : hs(hs), nh(nh), kv(kv), nb(nb) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
+        ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
+        ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
+        ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
+        return out;
+    }
+};
+
 enum llm_norm_type {
     LLM_NORM,
     LLM_NORM_RMS,
@@ -1661,7 +1695,7 @@ struct test_llama : public test_llm {
         struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1);
 
         ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
         ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
@@ -1783,7 +1817,7 @@ struct test_falcon : public test_llm {
         struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1);
 
         ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
         ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
@@ -2095,7 +2129,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
             for (float scale : {1.0f, 0.1f}) {
                 for (int64_t ne0 : {16, 1024}) {
                     for (int64_t ne1 : {16, 1024}) {
-                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias));
+                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, scale, max_bias));
                         test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias));
                     }
                 }
@@ -2139,6 +2173,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());
 
+    for (int hs : { 64, 80, 128, 256, }) {
+        for (int nh : { 32, }) {
+            for (int kv : { 512, 1024, }) {
+                for (int nb : { 1, 2, 4, 8, }) {
+                    test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
+                }
+            }
+        }
+    }
+
     // these tests are disabled to save execution time, but they can be handy for debugging
 #if 0
     test_cases.emplace_back(new test_llama(1));