]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : add Flash Attention (llama/5021)
authorGeorgi Gerganov <redacted>
Tue, 30 Apr 2024 09:16:08 +0000 (12:16 +0300)
committerGeorgi Gerganov <redacted>
Mon, 13 May 2024 08:02:26 +0000 (11:02 +0300)
* ggml : add ggml_flash_attn_ext API

* ggml : fix GQA support in ggml_flash_attn_ext

* ggml : online attention (CPU)

* metal : initial implementation

* metal : f16 precision

* metal : reduce branches

* metal : specialize for head size

* wip : 8 rows per simd group

* wip : 4 rows per simd group

* wip : template for rows per warp

* metal : parallelize across KV size

* metal : parallel reduce across heads

* metal : efficient flash_attn_f16 implementation

* metal : avoid redundant loads of the attention

* metal : scale and mask in matrix form

* metal : fix comment

* llama : avoid ggml_cast, use F32 query

* metal : add parallel reduce version (disabled)

* metal : move output into local memory + optimize

- the result from each simdgroup now stays in the registers
- significantly reduced SRAM usage
- more efficient skipping of -INF blocks
- avoid simdgroup barrier in hot loop
- add comments

* metal : add tests, fix scaling, support C > 32

* metal : improve precision

* ggml : fix f16 mad

* metal : minor

* metal : support Q > 8

* tests : add ATTN tests

* metal : disable buffer allocation logs

* tests : more

* metal : faster inner loop for C == 32

* metal : fix array initialization

* tests : ifdef

* ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext

* ggml : fix ggml_soft_max mask requirement

* cuda : fix soft_max to use correct mask size

* cuda : add flash_attn kernel (wip)

* metal : optimize softmax for C > 32

* metal : optimize softmax

* tests : minor fix

* cuda : avoid zeroing fragments

* tests : update dims

* cuda : fix __hisinf() result check

* cuda : avoid warp_reduce for smax

* cuda : use int instead of int64_t

Noticeably improves performance (thanks to Johannes)

* cuda : make loops use the same loop values

Thanks Johannes again for the tip

* cuda : unroll some of the loops

* cuda : avoid __hisinf branches

* cuda : use half2 in softmax

* cuda : switch to 1 warp for bs > 16

* cuda : speed-up reduce part of the kernel

* cuda : unroll Q*K^T loop

* cuda : fix -INF block check

* cuda : simplify softmax

* cuda : fix matrix names

* cuda : minor

* llama : adapt to F16 KQ_pos

* llama : adapt new models to F16 KQ_mask

* ggml : fix F16 store (ARM NEON)

* llama : fix type of KQ_mask and KQ_pos

* ggml : fix CPU soft_max

* tests : add hs=256

* cuda : fix build

* metal : improve perf via smaller int registers

* cuda : adapt soft_max to F16 mask and pos

* CUDA: faster FlashAttention, kernel for bs == 1

* 16 cols for Phi-2

* no vec for hs, no hs==256 ncols==32 for Volta

* adjust kernel selection logic

* 4 warps, 256 stride for all D

* no ncols == 64

* Multiple parallel blocks for batch size 1

* fix compile warnings

* fix excessive KQ_b loads

* fix cmake build

* fix KV cache padding, NaN from INFINITY (llama/6438)

* llama : flash_attn cparam + fix defrag

* server: support flash_attn param

* server: bench: enable flash_attn param

* CUDA: refactor host code, dyn. par. blocks

* fix flash_attn_vec_f16 race condition

* flush softmax exp below threshold to 0

* store temp KQ in registers

* Calculate KQ as FP32 if KQV has GGML_PREC_F32

* Add __hgt2_mask implementation for CUDA 11

* fix KQ FP32 precision fpr parallel_blocks > 1

* llama-bench : add -fa,--flash-attn arg

* metal : add BS=1 kernel for flash attention (llama/6508)

* metal : add BS=1 kernel for flash attention (wip)

* metal : support more than 1 warps

* metal : opts

* metal : opt

* metal : switch to parallel reduce

* metal : reduce registers

* metal : simplify

* metal : initial FA vec kernel

* metal : use F32 attention accumulators

* batched-bench : add fattn arg

* llama : simplify llama_build_kv_store

ggml-ci

* llama : adapt build_olmo to changes

* ggml : fix arm fp16 store on windows

* metal : clean-up

* metal : clean-up kernel code

* metal : minor

* tests : remove benchmarks

ggml-ci

* ggml : fix avx512 const correctness

ggml-ci

* ggml : fix soft_max with bias on CPU

ggml-ci

* common : print --flash-attn in help

* ggml : fix num dimensions in ggml_flash_attn_ext

* llama : force disable flash attention for incompatible models

* ggml : ggml_soft_max support F16/F32 mask/pos

ggml-ci

* cuda : uint -> uint32_t

* cuda : "constexpr dim3" -> "const dim3"

ggml-ci

* cuda : try to fix __hgt2_mask

ggml-ci

* ggml : add TODO's for F16/F32 mask/pos support in other backends

* llama : replace bool need_kq_pos with use_alibi

* llama : prep ALiBi support for BERT models

ggml-ci

* llama : fix n_batch requirements

ggml-ci

* cont

* server : add help for --flash-attn arg

* llama : disable FA for AMD

* tests : remove TMP_ATTN_BENCH

ggml-ci

* llama : support save/load state with FA enabled

ggml-ci

* ci : add CUDA save-load-state tests

ggml-ci

* llama : llama_kv_cache_clear zeroes data + fix save-load seq

ggml-ci

* llama : fix copy-paste errors, add TODO

* llama : disallow incompatible states

* llama : update llama_state_get_size after v_trans field

* metal : remove tmp log

* llama : add static reminder for llama_state_get_size

* metal : fix max nsg

ggml-ci

* ci : fix arg order

ggml-ci

---------

Co-authored-by: Johannes Gäßler <redacted>
Co-authored-by: Pierrick HYMBERT <redacted>
12 files changed:
ggml-cuda.cu
ggml-cuda/common.cuh
ggml-cuda/fattn.cu [new file with mode: 0644]
ggml-cuda/fattn.cuh [new file with mode: 0644]
ggml-cuda/softmax.cu
ggml-kompute.cpp
ggml-metal.m
ggml-metal.metal
ggml-sycl.cpp
ggml-vulkan.cpp
ggml.c
ggml.h

index 07534370c34ff30f45fde5d58837ef9d5f3385be..fa56f9521e494134c2e3b680da741f2baa275eb5 100644 (file)
@@ -14,6 +14,7 @@
 #include "ggml-cuda/cpy.cuh"
 #include "ggml-cuda/diagmask.cuh"
 #include "ggml-cuda/dmmv.cuh"
+#include "ggml-cuda/fattn.cuh"
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
 #include "ggml-cuda/mmq.cuh"
@@ -140,6 +141,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.devices[id].cc = 100*prop.major + 10*prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         info.devices[id].smpb = prop.sharedMemPerBlock;
+        info.devices[id].nsm  = prop.multiProcessorCount;
     }
 
     for (int id = 0; id < info.device_count; ++id) {
@@ -2293,6 +2295,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ARGSORT:
             ggml_cuda_op_argsort(ctx, dst);
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_cuda_flash_attn_ext(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2568,6 +2573,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
+        case GGML_OP_FLASH_ATTN_EXT:
             return true;
         default:
             return false;
index 481065b2a3484b350f258b981a5b4efae291477a..156eba6d1ef74d780ccaa74e72eea9428ccac7e8 100644 (file)
 #define CC_PASCAL     600
 #define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA      700
+#define CC_AMPERE     800
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA1      (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2      (CC_OFFSET_AMD + 1030)
@@ -271,7 +272,6 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
     return a;
 }
 
-#ifdef GGML_CUDA_F16
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 #pragma unroll
@@ -284,7 +284,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
    NO_DEVICE_CODE;
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 }
-#endif // GGML_CUDA_F16
 
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
@@ -294,19 +293,26 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }
 
-//static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-//#pragma unroll
-//    for (int mask = 16; mask > 0; mask >>= 1) {
-//        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-//    }
-//    return x;
-//#else
-//    GGML_UNUSED(x);
-//    NO_DEVICE_CODE;
-//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-//}
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+#pragma unroll
+   for (int mask = 16; mask > 0; mask >>= 1) {
+       x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+   }
+   return x;
+#else
+   GGML_UNUSED(x);
+   NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+}
 
+#if CUDART_VERSION < 12000
+static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
+    const uint32_t mask_low  = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
+    const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
+    return mask_low | mask_high;
+}
+#endif // CUDART_VERSION < 12000
 
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
@@ -391,6 +397,11 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 }
 #endif // defined(GGML_USE_HIPBLAS)
 
+#define FP16_AVAILABLE     defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
+    defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
+
+#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+
 // TODO: move to ggml-common.h
 static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
@@ -404,6 +415,7 @@ struct ggml_cuda_device_info {
 
     struct cuda_device_info {
         int     cc;                 // compute capability
+        int     nsm;                // number of streaming multiprocessors
         size_t  smpb;               // max. shared memory per block
         bool    vmm;                // virtual memory support
         size_t  vmm_granularity;    // granularity of virtual memory
diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
new file mode 100644 (file)
index 0000000..df1e800
--- /dev/null
@@ -0,0 +1,944 @@
+#include "common.cuh"
+#include "fattn.cuh"
+
+#include <cstdint>
+
+#if FP16_MMA_AVAILABLE
+#include <mma.h>
+#endif
+
+#define FATTN_KQ_STRIDE       256
+#define HALF_MAX_HALF         __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
+#define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
+
+template<int D, int parallel_blocks> // D == head size
+__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
+static __global__ void flash_attn_vec_ext_f16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+#if FP16_AVAILABLE
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    const int ic = blockIdx.x / parallel_blocks; // Index of the Q/QKV column to work on.
+    const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.y              + nb01*ic);
+    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.y / gqa_ratio));
+    const half   * V_h   = (const half   *) (V    + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const half   * maskh = (const half   *)  mask + ne11*ic;
+
+    const int stride_KV  = nb11 / sizeof(half);
+    const int stride_KV2 = nb11 / sizeof(half2);
+
+    constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
+    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    __builtin_assume(tid < nwarps*WARP_SIZE);
+
+    __shared__ half KQ[nwarps*WARP_SIZE];
+    KQ[tid] = -INFINITY;
+    half2 * KQ2 = (half2 *) KQ;
+
+    half kqmax = -HALF_MAX_HALF;
+    half kqsum = 0.0f;
+
+    __shared__ half kqmax_shared[WARP_SIZE];
+    __shared__ half kqsum_shared[WARP_SIZE];
+    if (threadIdx.y == 0) {
+        kqmax_shared[threadIdx.x] = -HALF_MAX_HALF;
+        kqsum_shared[threadIdx.x] = 0.0f;
+    }
+    __syncthreads();
+
+    // Convert Q to half2 and store in registers:
+    half2 Q_h2[(D/2 + WARP_SIZE - 1) / WARP_SIZE];
+#pragma unroll
+    for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+        const int i = i0 + threadIdx.x;
+        if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+            break;
+        }
+
+        Q_h2[i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(Q_f2[i].x, Q_f2[i].y);
+    }
+
+    half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value.
+
+    const int k_start  = parallel_blocks == 1 ? 0 : ip*D;
+    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+        // Calculate KQ tile and keep track of new maximum KQ values:
+        half kqmax_new = kqmax;
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
+            const int i_KQ = i_KQ_0 + threadIdx.y;
+
+            if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
+                break;
+            }
+
+            half2 sum2 = make_half2(0.0f, 0.0f);
+#pragma unroll
+            for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+                const int k_KQ = k_KQ_0 + threadIdx.x;
+                if (k_KQ_0 + WARP_SIZE > D/2 && k_KQ >= D/2) {
+                    break;
+                }
+
+                const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
+                sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
+            }
+
+            sum2 = warp_reduce_sum(sum2);
+            half sum = __low2half(sum2) + __high2half(sum2);
+            sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f);
+            kqmax_new = __hmax(kqmax_new, sum);
+            if (threadIdx.x == 0) {
+                KQ[i_KQ] = sum;
+            }
+        }
+
+        kqmax_new = warp_reduce_max(kqmax_new);
+        if (threadIdx.x == 0) {
+            kqmax_shared[threadIdx.y] = kqmax_new;
+        }
+        __syncthreads();
+        kqmax_new = kqmax_shared[threadIdx.x];
+        kqmax_new = warp_reduce_max(kqmax_new);
+
+        const half KQ_max_scale = hexp(kqmax - kqmax_new);
+        kqmax = kqmax_new;
+
+        const half val = hexp(KQ[tid] - kqmax);
+        kqsum = kqsum*KQ_max_scale + val;
+        KQ[tid] = val;
+
+        VKQ *= __half2half2(KQ_max_scale);
+
+        __syncthreads();
+
+        if (tid < D) {
+#pragma unroll
+            for (int k0 = 0; k0 < D; k0 += 2) {
+                if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
+                    break;
+                }
+
+                half2 V_k;
+                reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
+                reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
+                VKQ += V_k*KQ2[k0/2];
+            }
+        }
+
+        __syncthreads();
+    }
+
+    if (tid >= D) {
+        kqsum = 0.0f;
+    }
+
+    kqsum = warp_reduce_sum(kqsum);
+    if (threadIdx.x == 0) {
+        kqsum_shared[threadIdx.y] = kqsum;
+    }
+    __syncthreads();
+    kqsum = kqsum_shared[threadIdx.x];
+    kqsum = warp_reduce_sum(kqsum);
+
+    if (tid >= D) {
+        return;
+    }
+
+    half dst_val = (__low2half(VKQ) + __high2half(VKQ));
+    if (parallel_blocks == 1) {
+        dst_val /= kqsum;
+    }
+    dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val;
+
+    if (parallel_blocks == 1 || tid != 0) {
+        return;
+    }
+    dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum);
+#else
+   NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+}
+
+// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
+__launch_bounds__(nwarps*WARP_SIZE, 1)
+static __global__ void flash_attn_ext_f16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+#if FP16_MMA_AVAILABLE
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
+    const int ip  =        blockIdx.x % parallel_blocks;  // Index in group of blocks running for the same column in parallel.
+
+    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
+    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
+    constexpr int frag_m = ncols == 8 ? 32 : 16;
+    constexpr int frag_n = ncols == 8 ?  8 : 16;
+    static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
+
+    constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel.
+    constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
+    static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
+
+    // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
+    constexpr int D_padded = D + 8;
+    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
+    constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y              + nb01*ic0);
+    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
+    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0;
+    const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);
+
+    const int stride_Q  = nb01 / sizeof(float);
+    const int stride_KV = nb11 / sizeof(half);
+
+    frag_b Q_b[D/16][ncols/frag_n];
+
+    // A single buffer for temporarily holding tiles of KQ and VKQ parts:
+    constexpr int mem_KQ = ncols*kqs_padded*kqar;
+    constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
+    __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
+    float * KQ_f = (float *) KQ;
+    half2 * KQ2 = (half2 *) KQ;
+
+    float    KQ_rowsum_f[ncols/nwarps] = {0.0f};
+    float       KQ_max_f[ncols/nwarps];
+    float KQ_max_scale_f[ncols/nwarps] = {0.0f};
+
+#pragma unroll
+    for (int j = 0; j < ncols/nwarps; ++j) {
+        KQ_max_f[j] = -FLT_MAX/2.0f;
+    }
+
+    half2    KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+    half2       KQ_max_h2[ncols/nwarps];
+    half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+
+#pragma unroll
+    for (int j = 0; j < ncols/nwarps; ++j) {
+        KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
+    }
+
+    __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
+    half2 * VKQ2 = (half2 *) VKQ;
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+#pragma unroll
+        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                break;
+            }
+            VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
+        }
+    }
+
+    // Convert Q to half and apply scale, temporarily store in KQ:
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D && i >= D) {
+                break;
+            }
+            KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+        }
+    }
+
+    __syncthreads();
+
+    // Load Q into tensor core fragments/registers since it will be used frequently:
+#pragma unroll
+    for (int i0 = 0; i0 < D; i0 += 16) {
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+        }
+    }
+
+    __syncthreads();
+
+    // Iterate over ne11 == previous tokens:
+    for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
+        // Calculate tile of KQ:
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
+            frag_c_KQ KQ_c[ncols/frag_n];
+#pragma unroll
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+            }
+#pragma unroll
+            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
+                frag_a_K K_a;
+                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                }
+            }
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+            }
+        }
+
+        __syncthreads();
+
+        // Calculate softmax for each KQ column using the current max. value.
+        // The divisor is stored in KQ_rowsum and will be applied at the end.
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (std::is_same<KQ_acc_t, float>::value) {
+                float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+                }
+
+                float KQ_max_new = KQ_max_f[j0/nwarps];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
+                }
+                KQ_max_new = warp_reduce_max(KQ_max_new);
+
+                const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
+                KQ_max_scale_f[j0/nwarps] = expf(diff);
+                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                    KQ_max_scale_f[j0/nwarps] = 0.0f;
+                }
+                KQ_max_f[j0/nwarps] = KQ_max_new;
+
+                float KQ_rowsum_add = 0.0f;
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
+                    KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
+                    if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                        KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
+                    }
+                    KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
+                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
+                }
+                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+                KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
+            } else {
+                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+                }
+
+                half2 KQ_max_new = KQ_max_h2[j0/nwarps];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                    KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
+                }
+                KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
+                const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
+                KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
+                const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+                *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
+                KQ_max_h2[j0/nwarps] = KQ_max_new;
+
+                half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
+                    KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
+                    const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+                    *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
+                    KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
+                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
+                }
+                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+                KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
+            }
+        }
+
+        __syncthreads();
+
+        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+#pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+                nvcuda::wmma::load_matrix_sync(
+                    KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
+                    KQ + j0*(kqar*kqs_padded) + k,
+                    kqar*kqs_padded);
+            }
+        }
+
+        frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
+#pragma unroll
+        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
+#pragma unroll
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+            }
+
+#pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+
+                frag_a_V v_a;
+                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                }
+            }
+        }
+
+        __syncthreads();
+
+        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync(
+                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
+                    VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
+                    D_padded, nvcuda::wmma::mem_col_major);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            half2 VKQ_scale;
+            if (std::is_same<KQ_acc_t, float>::value) {
+                VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
+            } else {
+                VKQ_scale = KQ_max_scale_h2[j0/nwarps];
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+                if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                    break;
+                }
+
+                half2 VKQ_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+                for (int l = 0; l < VKQ_ratio; ++l) {
+                    VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
+                }
+                VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
+            }
+        }
+
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j_VKQ = j0 + threadIdx.y;
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+
+        float KQ_rowsum_j;
+        if (std::is_same<KQ_acc_t, float>::value) {
+            KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
+        } else {
+            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D && i >= D) {
+                break;
+            }
+            float dst_val = VKQ[j_VKQ*D_padded + i];
+            if (parallel_blocks == 1) {
+                dst_val /= KQ_rowsum_j;
+            }
+            dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
+        }
+
+        if (parallel_blocks == 1 || threadIdx.x != 0) {
+            continue;
+        }
+
+        float2 dst_meta_val;
+        if (std::is_same<KQ_acc_t, float>::value) {
+            dst_meta_val.x = KQ_max_f[j0/nwarps];
+        } else {
+            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
+        }
+        dst_meta_val.y = KQ_rowsum_j;
+        dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
+    }
+#else
+   NO_DEVICE_CODE;
+#endif // FP16_MMA_AVAILABLE
+}
+
+template<int D, int parallel_blocks> // D == head size
+__launch_bounds__(D, 1)
+static __global__ void flash_attn_combine_results(
+        const float  * __restrict__ VKQ_parts,
+        const float2 * __restrict__ VKQ_meta,
+        float * __restrict__ dst) {
+#if FP16_AVAILABLE
+    VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
+    VKQ_meta  += parallel_blocks   * gridDim.y*blockIdx.x;
+    dst       +=                 D * gridDim.y*blockIdx.x;
+
+    const int tid = threadIdx.x;
+    __builtin_assume(tid < D);
+
+    __shared__ float2 meta[parallel_blocks];
+    if (tid < 2*parallel_blocks) {
+        ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
+    }
+
+    __syncthreads();
+
+    float kqmax = meta[0].x;
+#pragma unroll
+    for (int l = 1; l < parallel_blocks; ++l) {
+        kqmax = max(kqmax, meta[l].x);
+    }
+
+    float VKQ_numerator   = 0.0f;
+    float VKQ_denominator = 0.0f;
+#pragma unroll
+    for (int l = 0; l < parallel_blocks; ++l) {
+        const float diff = meta[l].x - kqmax;
+        const float KQ_max_scale = expf(diff);
+        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
+        *((uint32_t *) &KQ_max_scale) &= ftz_mask;
+
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
+        VKQ_denominator += KQ_max_scale * meta[l].y;
+    }
+
+    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
+#else
+   NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+}
+
+constexpr int get_max_power_of_2(int x) {
+    return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
+}
+
+static_assert(get_max_power_of_2(1) == 1, "Test failed.");
+static_assert(get_max_power_of_2(2) == 2, "Test failed.");
+static_assert(get_max_power_of_2(4) == 4, "Test failed.");
+static_assert(get_max_power_of_2(6) == 2, "Test failed.");
+
+// Number of VKQ rows calculated in parallel:
+constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
+    return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
+}
+
+static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed.");
+static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
+static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed.");
+static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
+
+template <int D, int parallel_blocks> void launch_fattn_vec_f16(
+        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
+        ggml_cuda_pool & pool, cudaStream_t main_stream
+) {
+    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+
+    if (parallel_blocks > 1) {
+        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+    }
+
+    constexpr int  nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
+    const     dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const     dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
+    const     int  shmem = 0;
+
+    float scale;
+    memcpy(&scale, KQV->op_params, sizeof(float));
+
+    flash_attn_vec_ext_f16<D, parallel_blocks>
+        <<<blocks_num, block_dim, shmem, main_stream>>> (
+                (const char *) Q->data,
+                (const char *) K->data,
+                (const char *) V->data,
+                mask ? ((const char *) mask->data) : nullptr,
+                parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+                scale,
+                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
+                Q->nb[1], Q->nb[2], Q->nb[3],
+                K->nb[1], K->nb[2], K->nb[3],
+                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                );
+    CUDA_CHECK(cudaGetLastError());
+
+    if (parallel_blocks == 1) {
+        return;
+    }
+
+    const dim3 block_dim_combine(D, 1, 1);
+    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+    const int  shmem_combine = 0;
+
+    flash_attn_combine_results<D, parallel_blocks>
+        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    CUDA_CHECK(cudaGetLastError());
+}
+
+template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename KQ_acc_t> void launch_fattn_f16_impl(
+        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
+        ggml_cuda_pool & pool, cudaStream_t main_stream
+) {
+    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+
+    if (parallel_blocks > 1) {
+        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+    }
+
+    constexpr int  frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16;
+    const     dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const     dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
+    const     int  shmem = 0;
+
+    float scale;
+    memcpy(&scale, KQV->op_params, sizeof(float));
+
+    flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
+        <<<blocks_num, block_dim, shmem, main_stream>>> (
+                (const char *) Q->data,
+                (const char *) K->data,
+                (const char *) V->data,
+                mask ? ((const char *) mask->data) : nullptr,
+                (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+                scale,
+                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
+                Q->nb[1], Q->nb[2], Q->nb[3],
+                K->nb[1], K->nb[2], K->nb[3],
+                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                );
+    CUDA_CHECK(cudaGetLastError());
+
+    if ((parallel_blocks) == 1) {
+        return;
+    }
+
+    const dim3 block_dim_combine(D, 1, 1);
+    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+    const int  shmem_combine = 0;
+
+    flash_attn_combine_results<D, parallel_blocks>
+        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    CUDA_CHECK(cudaGetLastError());
+}
+
+template <int D, int cols_per_block, int nwarps, typename KQ_acc_t> void launch_fattn_f16(
+        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
+        const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream
+) {
+    const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
+
+    if (4*blocks_num_pb1 < 2*nsm) {
+        launch_fattn_f16_impl<D, cols_per_block, nwarps, 4, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+        return;
+    }
+    if (2*blocks_num_pb1 < 2*nsm) {
+        launch_fattn_f16_impl<D, cols_per_block, nwarps, 2, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+        return;
+    }
+    launch_fattn_f16_impl<D, cols_per_block, nwarps, 1, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+}
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+
+    const ggml_tensor * mask = dst->src[3];
+
+    ggml_tensor * KQV = dst;
+
+    GGML_ASSERT(Q->type == GGML_TYPE_F32);
+    GGML_ASSERT(K->type == GGML_TYPE_F16);
+    GGML_ASSERT(V->type == GGML_TYPE_F16);
+    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+                                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+
+    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+
+    ggml_cuda_set_device(ctx.device);
+
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+
+    const int32_t precision = KQV->op_params[1];
+
+    if (precision != GGML_PREC_DEFAULT) {
+        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
+            constexpr int cols_per_block = 16;
+            constexpr int nwarps         =  4;
+            switch (Q->ne[0]) {
+                case 64:
+                    launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 80:
+                    launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 96:
+                    launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 112:
+                    launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 128:
+                    launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 256:
+                    launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+        } else {
+            constexpr int cols_per_block = 32;
+            constexpr int nwarps         =  4;
+            switch (Q->ne[0]) {
+                case 64:
+                    launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 80:
+                    launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 96:
+                    launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 112:
+                    launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 128:
+                    launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                // case 256:
+                //     launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                //     break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+        }
+        return;
+    }
+
+    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
+        constexpr int parallel_blocks = 4;
+        switch (Q->ne[0]) {
+            case 64:
+                launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                break;
+            case 128:
+                launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                break;
+            case 256:
+                launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                break;
+            default:
+                GGML_ASSERT(false);
+                break;
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
+        constexpr int cols_per_block = 8;
+        constexpr int nwarps         = 4;
+        switch (Q->ne[0]) {
+            case 64:
+                launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            case 96:
+                launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            case 128:
+                launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            case 256:
+                launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            default:
+                GGML_ASSERT(false);
+                break;
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 32) {
+        constexpr int cols_per_block = 16;
+        constexpr int nwarps         =  4;
+        switch (Q->ne[0]) {
+            case 64:
+                launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            case 80:
+                launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            case 96:
+                launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            case 112:
+                launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            case 128:
+                launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            case 256:
+                launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                break;
+            default:
+                GGML_ASSERT(false);
+                break;
+        }
+        return;
+    }
+
+    constexpr int cols_per_block = 32;
+    constexpr int nwarps         =  4;
+    switch (Q->ne[0]) {
+        case 64:
+            launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        case 80:
+            launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        case 96:
+            launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        case 112:
+            launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        case 128:
+            launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        case 256:
+            launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+    return;
+}
diff --git a/ggml-cuda/fattn.cuh b/ggml-cuda/fattn.cuh
new file mode 100644 (file)
index 0000000..ad3ca7a
--- /dev/null
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index fa8f987cf7c1d4e3c26fe60bca485556f9d85e02..6ed225999bddfc914118a6fda3536222047bf6aa 100644 (file)
@@ -1,7 +1,17 @@
 #include "softmax.cuh"
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+template <typename T>
+static __device__ __forceinline__ float t2f32(T val) {
+    return (float) val;
+}
+
+template <>
+__device__ float __forceinline__ t2f32<half>(half val) {
+    return __half2float(val);
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
 
     const int tid  = threadIdx.x;
@@ -43,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
         const int64_t ix = (int64_t)rowx*ncols + col;
         const int64_t iy = (int64_t)rowy*ncols + col;
 
-        const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
+        const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);
@@ -114,7 +124,8 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
     }
 }
 
-static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+template<typename T>
+static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth,     1, 1);
@@ -167,15 +178,19 @@ static void soft_max_f32_cuda(const float * x, const float * mask, const float *
 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
     const float * src0_d = (const float *)src0->data;
-    const float * src1_d = src1 ? (const float *)src1->data : nullptr;
+    const void  * src1_d = src1 ? (const void *)src1->data : nullptr;
+
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
 
     const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
@@ -188,14 +203,25 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
     // positions tensor
-    float * src2_dd = nullptr;
+    void * src2_d = nullptr;
 
-    ggml_tensor * src2 = dst->src[2];
     const bool use_src2 = src2 != nullptr;
 
     if (use_src2) {
-        src2_dd = (float *)src2->data;
+        src2_d = (void *)src2->data;
     }
 
-    soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+
+    if (use_f16) {
+        const half * src1_dd = (const half *)src1_d;
+        const half * src2_dd = (const half *)src2_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    } else {
+        const float * src1_dd = (const float *)src1_d;
+        const float * src2_dd = (const float *)src2_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    }
 }
index 407062e6fd47625d6cb2f78e17649f64387d2866..9a469821d804214afb3d99bb83a7cfa31e093349 100644 (file)
@@ -1427,6 +1427,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
         for (int i = node_start; i < node_end; ++i) {
             struct ggml_tensor * src0 = gf->nodes[i]->src[0];
             struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+            struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2);
             struct ggml_tensor * dst = gf->nodes[i];
             GGML_ASSERT(dst->data != nullptr);
 
@@ -1559,6 +1560,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                     {
                         float scale;
                         memcpy(&scale, dst->op_params, sizeof(float));
+
+#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
+#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
+                        GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
+                        GGML_ASSERT(src2 == nullptr);
+
                         ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
                     } break;
                 case GGML_OP_DIAG_MASK_INF:
index fdba0de85bcdbb5d686ce97362901f55135725f0..71b8a099b7e14b60c44f89879c76e20b06f307d2 100644 (file)
@@ -47,8 +47,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
     GGML_METAL_KERNEL_TYPE_SILU,
     GGML_METAL_KERNEL_TYPE_SILU_4,
-    GGML_METAL_KERNEL_TYPE_SOFT_MAX,
-    GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
     GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
     GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
@@ -178,6 +180,14 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
     GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -444,7 +454,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         }
 
         /*
-            GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+            GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
                     (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
                     (int) kernel->pipeline.threadExecutionWidth); \
         */
@@ -460,173 +470,183 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
                 return NULL; \
             } \
         } else { \
-            GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \
+            GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
         }
 
         // simd_sum and simd_max requires MTLGPUFamilyApple7
 
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                       add,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,                   add_row,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,                       mul,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,                   mul_row,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                       div,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,                   div_row,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE,                     scale,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,                   scale_4,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP,                     clamp,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH,                      tanh,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,                      relu,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID,                   sigmoid,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,                      gelu,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4,                    gelu_4,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK,                gelu_quick,             true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,              gelu_quick_4,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                      silu,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,                    silu_4,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX,                  soft_max,               ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,                soft_max_4,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,             diag_mask_inf,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,           diag_mask_inf_8,        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,              get_rows_f32,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,              get_rows_f16,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,             get_rows_q4_0,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,             get_rows_q4_1,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,             get_rows_q5_0,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,             get_rows_q5_1,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,             get_rows_q8_0,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,             get_rows_q2_K,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,             get_rows_q3_K,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,             get_rows_q4_K,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,             get_rows_q5_K,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,             get_rows_q6_K,          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,          get_rows_iq2_xxs,       true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,           get_rows_iq2_xs,        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,          get_rows_iq3_xxs,       true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,            get_rows_iq3_s,         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,            get_rows_iq2_s,         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,            get_rows_iq1_s,         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,            get_rows_iq1_m,         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,           get_rows_iq4_nl,        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,           get_rows_iq4_xs,        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,              get_rows_i32,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                  rms_norm,               ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                group_norm,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                      norm,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,            mul_mv_f32_f32,         ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,            mul_mv_f16_f16,         ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,            mul_mv_f16_f32,         ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,       mul_mv_f16_f32_1row,    ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,         mul_mv_f16_f32_l4,      ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,           mul_mv_q4_0_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,           mul_mv_q4_1_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,           mul_mv_q5_0_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,           mul_mv_q5_1_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,           mul_mv_q8_0_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,           mul_mv_q2_K_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,           mul_mv_q3_K_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,           mul_mv_q4_K_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,           mul_mv_q5_K_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,           mul_mv_q6_K_f32,        ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,        mul_mv_iq2_xxs_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,         mul_mv_iq2_xs_f32,      ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,        mul_mv_iq3_xxs_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,          mul_mv_iq3_s_f32,       ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,          mul_mv_iq2_s_f32,       ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,          mul_mv_iq1_s_f32,       ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,          mul_mv_iq1_m_f32,       ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,         mul_mv_iq4_nl_f32,      ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,         mul_mv_iq4_xs_f32,      ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,         mul_mv_id_f32_f32,      ctx->support_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,         mul_mv_id_f16_f16,      ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,         mul_mv_id_f16_f32,      ctx->support_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,    mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,      mul_mv_id_f16_f32_l4,   ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,        mul_mv_id_q4_0_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,        mul_mv_id_q4_1_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,        mul_mv_id_q5_0_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,        mul_mv_id_q5_1_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,        mul_mv_id_q8_0_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,        mul_mv_id_q2_K_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,        mul_mv_id_q3_K_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,        mul_mv_id_q4_K_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,        mul_mv_id_q5_K_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,        mul_mv_id_q6_K_f32,     ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,     mul_mv_id_iq2_xxs_f32,  ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,      mul_mv_id_iq2_xs_f32,   ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,     mul_mv_id_iq3_xxs_f32,  ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,       mul_mv_id_iq3_s_f32,    ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,       mul_mv_id_iq2_s_f32,    ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,       mul_mv_id_iq1_s_f32,    ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,       mul_mv_id_iq1_m_f32,    ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,      mul_mv_id_iq4_nl_f32,   ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,      mul_mv_id_iq4_xs_f32,   ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,            mul_mm_f32_f32,         ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,            mul_mm_f16_f32,         ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,           mul_mm_q4_0_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,           mul_mm_q4_1_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,           mul_mm_q5_0_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,           mul_mm_q5_1_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,           mul_mm_q8_0_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,           mul_mm_q2_K_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,           mul_mm_q3_K_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,           mul_mm_q4_K_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,           mul_mm_q5_K_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,           mul_mm_q6_K_f32,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,        mul_mm_iq2_xxs_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,         mul_mm_iq2_xs_f32,      ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,        mul_mm_iq3_xxs_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,          mul_mm_iq3_s_f32,       ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,          mul_mm_iq2_s_f32,       ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,          mul_mm_iq1_s_f32,       ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,          mul_mm_iq1_m_f32,       ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,         mul_mm_iq4_nl_f32,      ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,         mul_mm_iq4_xs_f32,      ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,         mul_mm_id_f32_f32,      ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,         mul_mm_id_f16_f32,      ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,        mul_mm_id_q4_0_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,        mul_mm_id_q4_1_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,        mul_mm_id_q5_0_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,        mul_mm_id_q5_1_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,        mul_mm_id_q8_0_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,        mul_mm_id_q2_K_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,        mul_mm_id_q3_K_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,        mul_mm_id_q4_K_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,        mul_mm_id_q5_K_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,        mul_mm_id_q6_K_f32,     ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,     mul_mm_id_iq2_xxs_f32,  ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,      mul_mm_id_iq2_xs_f32,   ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,     mul_mm_id_iq3_xxs_f32,  ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,       mul_mm_id_iq3_s_f32,    ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,       mul_mm_id_iq2_s_f32,    ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,       mul_mm_id_iq1_s_f32,    ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,       mul_mm_id_iq1_m_f32,    ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,      mul_mm_id_iq4_nl_f32,   ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,      mul_mm_id_iq4_xs_f32,   ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32,                  rope_f32,               true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16,                  rope_f16,               true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32,                 alibi_f32,              true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                im2col_f16,             true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                im2col_f32,             true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,               upscale_f32,            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                   pad_f32,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,    timestep_embedding_f32, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                arange_f32,             true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,       argsort_f32_i32_asc,    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,      argsort_f32_i32_desc,   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,            leaky_relu_f32,         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,               cpy_f32_f16,            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,               cpy_f32_f32,            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,              cpy_f32_q8_0,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,              cpy_f32_q4_0,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,              cpy_f32_q4_1,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,              cpy_f32_q5_0,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,              cpy_f32_q5_1,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,            cpy_f32_iq4_nl,         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,               cpy_f16_f16,            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,               cpy_f16_f32,            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                    concat,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                       sqr,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                  sum_rows,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                           add,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,                       add_row,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,                           mul,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,                       mul_row,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                           div,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,                       div_row,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE,                         scale,                          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,                       scale_4,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP,                         clamp,                          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH,                          tanh,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,                          relu,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID,                       sigmoid,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,                          gelu,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4,                        gelu_4,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK,                    gelu_quick,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,                  gelu_quick_4,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                          silu,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,                        silu_4,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,                  soft_max_f16,                   ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,                soft_max_f16_4,                 ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,                  soft_max_f32,                   ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,                soft_max_f32_4,                 ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,                 diag_mask_inf,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,               diag_mask_inf_8,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,                  get_rows_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,                  get_rows_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,                 get_rows_q4_0,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,                 get_rows_q4_1,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,                 get_rows_q5_0,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,                 get_rows_q5_1,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,                 get_rows_q8_0,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,                 get_rows_q2_K,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,                 get_rows_q3_K,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,                 get_rows_q4_K,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,                 get_rows_q5_K,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,                 get_rows_q6_K,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,              get_rows_iq2_xxs,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,               get_rows_iq2_xs,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,              get_rows_iq3_xxs,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,                get_rows_iq3_s,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,                get_rows_iq2_s,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,                get_rows_iq1_s,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,                get_rows_iq1_m,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,               get_rows_iq4_nl,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,               get_rows_iq4_xs,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,                  get_rows_i32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                      rms_norm,                       ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                    group_norm,                     ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                          norm,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                mul_mv_f32_f32,                 ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,                mul_mv_f16_f16,                 ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,                mul_mv_f16_f32,                 ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,           mul_mv_f16_f32_1row,            ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,             mul_mv_f16_f32_l4,              ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,               mul_mv_q4_0_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,               mul_mv_q4_1_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,               mul_mv_q5_0_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,               mul_mv_q5_1_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,               mul_mv_q8_0_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,               mul_mv_q2_K_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,               mul_mv_q3_K_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,               mul_mv_q4_K_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,               mul_mv_q5_K_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,               mul_mv_q6_K_f32,                ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,            mul_mv_iq2_xxs_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,             mul_mv_iq2_xs_f32,              ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,            mul_mv_iq3_xxs_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,              mul_mv_iq3_s_f32,               ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,              mul_mv_iq2_s_f32,               ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,              mul_mv_iq1_s_f32,               ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,              mul_mv_iq1_m_f32,               ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,             mul_mv_iq4_nl_f32,              ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,             mul_mv_iq4_xs_f32,              ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,             mul_mv_id_f32_f32,              ctx->support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,             mul_mv_id_f16_f16,              ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,             mul_mv_id_f16_f32,              ctx->support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,        mul_mv_id_f16_f32_1row,         ctx->support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,          mul_mv_id_f16_f32_l4,           ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,            mul_mv_id_q4_0_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,            mul_mv_id_q4_1_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,            mul_mv_id_q5_0_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,            mul_mv_id_q5_1_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,            mul_mv_id_q8_0_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,            mul_mv_id_q2_K_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,            mul_mv_id_q3_K_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,            mul_mv_id_q4_K_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,            mul_mv_id_q5_K_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,            mul_mv_id_q6_K_f32,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,         mul_mv_id_iq2_xxs_f32,          ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,          mul_mv_id_iq2_xs_f32,           ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,         mul_mv_id_iq3_xxs_f32,          ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,           mul_mv_id_iq3_s_f32,            ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,           mul_mv_id_iq2_s_f32,            ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,           mul_mv_id_iq1_s_f32,            ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,           mul_mv_id_iq1_m_f32,            ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,          mul_mv_id_iq4_nl_f32,           ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,          mul_mv_id_iq4_xs_f32,           ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,                mul_mm_f32_f32,                 ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,                mul_mm_f16_f32,                 ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,               mul_mm_q4_0_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,               mul_mm_q4_1_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,               mul_mm_q5_0_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,               mul_mm_q5_1_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,               mul_mm_q8_0_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,               mul_mm_q2_K_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,               mul_mm_q3_K_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,               mul_mm_q4_K_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,               mul_mm_q5_K_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,               mul_mm_q6_K_f32,                ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,            mul_mm_iq2_xxs_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,             mul_mm_iq2_xs_f32,              ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,            mul_mm_iq3_xxs_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,              mul_mm_iq3_s_f32,               ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,              mul_mm_iq2_s_f32,               ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,              mul_mm_iq1_s_f32,               ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,              mul_mm_iq1_m_f32,               ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,             mul_mm_iq4_nl_f32,              ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,             mul_mm_iq4_xs_f32,              ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,             mul_mm_id_f32_f32,              ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,             mul_mm_id_f16_f32,              ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,            mul_mm_id_q4_0_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,            mul_mm_id_q4_1_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,            mul_mm_id_q5_0_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,            mul_mm_id_q5_1_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,            mul_mm_id_q8_0_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,            mul_mm_id_q2_K_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,            mul_mm_id_q3_K_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,            mul_mm_id_q4_K_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,            mul_mm_id_q5_K_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,            mul_mm_id_q6_K_f32,             ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,         mul_mm_id_iq2_xxs_f32,          ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,          mul_mm_id_iq2_xs_f32,           ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,         mul_mm_id_iq3_xxs_f32,          ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,           mul_mm_id_iq3_s_f32,            ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,           mul_mm_id_iq2_s_f32,            ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,           mul_mm_id_iq1_s_f32,            ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,           mul_mm_id_iq1_m_f32,            ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,          mul_mm_id_iq4_nl_f32,           ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,          mul_mm_id_iq4_xs_f32,           ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32,                      rope_f32,                       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16,                      rope_f16,                       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32,                     alibi_f32,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                    im2col_f16,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                    im2col_f32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                       pad_f32,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,        timestep_embedding_f32,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                    arange_f32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,           argsort_f32_i32_asc,            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,          argsort_f32_i32_desc,           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,                leaky_relu_f32,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,        flash_attn_ext_f16_h64,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,        flash_attn_ext_f16_h80,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,        flash_attn_ext_f16_h96,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                   cpy_f32_f16,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                   cpy_f32_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,                  cpy_f32_q8_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,                  cpy_f32_q4_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,                  cpy_f32_q4_1,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,                  cpy_f32_q5_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,                  cpy_f32_q5_1,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,                cpy_f32_iq4_nl,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                   cpy_f16_f16,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,                   cpy_f16_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                        concat,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                           sqr,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                      sum_rows,                       true);
     }
 
     [metal_library release];
@@ -746,6 +766,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
         case GGML_OP_LEAKY_RELU:
+        case GGML_OP_FLASH_ATTN_EXT:
             return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
@@ -1341,20 +1362,33 @@ static enum ggml_status ggml_metal_graph_compute(
                     } break;
                 case GGML_OP_SOFT_MAX:
                     {
+                        GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
+                        GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
+
                         int nth = 32; // SIMD width
 
                         id<MTLComputePipelineState> pipeline = nil;
 
+                        const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+
                         if (ne00%4 == 0) {
                             while (nth < ne00/4 && nth < 256) {
                                 nth *= 2;
                             }
-                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
+                            if (use_f16) {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
+                            } else {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
+                            }
                         } else {
                             while (nth < ne00 && nth < 1024) {
                                 nth *= 2;
                             }
-                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
+                            if (use_f16) {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
+                            } else {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
+                            }
                         }
 
                         float scale;
@@ -2518,6 +2552,161 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
+                case GGML_OP_FLASH_ATTN_EXT:
+                    {
+                        GGML_ASSERT(ne00 % 4 == 0);
+                        GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                        struct ggml_tensor * src3 = gf->nodes[i]->src[3];
+
+                        GGML_ASSERT(ggml_are_same_shape(src1, src2));
+                        GGML_ASSERT(src3);
+
+                        size_t offs_src3 = 0;
+
+                        id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+
+                        GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
+                        GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
+                                "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
+
+                        const int64_t  ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
+                        const int64_t  ne31 = src3 ? src3->ne[1] : 0;
+                        const int64_t  ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
+                        const int64_t  ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
+
+                        const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
+                        const uint64_t nb31 = src3 ? src3->nb[1] : 0;
+                        const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
+                        const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
+
+                        const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+
+                        float scale;
+                        memcpy(&scale, dst->op_params, sizeof(float));
+
+                        id<MTLComputePipelineState> pipeline = nil;
+
+                        bool use_vec_kernel = false;
+
+                        if (ne01 >= 4 || (ne00%128 != 0)) {
+                            switch (ne00) {
+                                case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
+                                case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
+                                case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
+                                case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
+                                case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
+                                case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                                default:
+                                          {
+                                              GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                              GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                                              GGML_ASSERT(false && "add template specialization for this size");
+                                          }
+                            }
+                        } else {
+                            use_vec_kernel = true;
+
+                            switch (ne00) {
+                                case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+                                case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                                default:
+                                          {
+                                              GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                              GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                                              GGML_ASSERT(false && "add template specialization for this size");
+                                          }
+                            }
+                        }
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                        [encoder setBuffer:id_src1 offset:offs_src1        atIndex:1];
+                        [encoder setBuffer:id_src2 offset:offs_src2        atIndex:2];
+                        [encoder setBuffer:id_src3 offset:offs_src3        atIndex:3];
+                        [encoder setBuffer:id_dst  offset:offs_dst         atIndex:4];
+                        [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:5];
+                        [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:6];
+                        [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:7];
+                        [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:8];
+                        [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:9];
+                        [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:10];
+                        [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:11];
+                        [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:12];
+                        [encoder setBytes:&ne10    length:sizeof( int64_t) atIndex:13];
+                        [encoder setBytes:&ne11    length:sizeof( int64_t) atIndex:14];
+                        [encoder setBytes:&ne12    length:sizeof( int64_t) atIndex:15];
+                        [encoder setBytes:&ne13    length:sizeof( int64_t) atIndex:16];
+                        [encoder setBytes:&nb10    length:sizeof(uint64_t) atIndex:17];
+                        [encoder setBytes:&nb11    length:sizeof(uint64_t) atIndex:18];
+                        [encoder setBytes:&nb12    length:sizeof(uint64_t) atIndex:19];
+                        [encoder setBytes:&nb13    length:sizeof(uint64_t) atIndex:20];
+                        [encoder setBytes:&ne31    length:sizeof( int64_t) atIndex:21];
+                        [encoder setBytes:&nb31    length:sizeof(uint64_t) atIndex:22];
+                        [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:23];
+                        [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:24];
+                        [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:25];
+                        [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:26];
+                        [encoder setBytes:&scale   length:sizeof(   float) atIndex:27];
+
+                        if (!use_vec_kernel) {
+                            // half8x8 kernel
+                            const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
+                            const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+                            GGML_ASSERT(nqptg <= 32);
+                            GGML_ASSERT(nqptg  % 8  == 0);
+                            GGML_ASSERT(ncpsg  % 32 == 0);
+
+                            int64_t nsgmax = 2;
+
+                            while (true) {
+                                const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
+                                if (smem > ctx->device.maxThreadgroupMemoryLength) {
+                                    break;
+                                }
+                                nsgmax *= 2;
+                            }
+                            nsgmax /= 2;
+
+                            // simdgroups per threadgroup (a.k.a. warps)
+                            const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
+
+                            const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+
+                            //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+                            GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                        } else {
+                            // half1x4 kernel
+                            const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
+                            const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+                            GGML_ASSERT(nqptg <= 32);
+                            GGML_ASSERT(nqptg  % 1  == 0);
+                            GGML_ASSERT(ncpsg  % 32 == 0);
+
+                            // simdgroups per threadgroup (a.k.a. warps)
+                            const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+
+                            int64_t nsg = 1;
+                            while (nsg <= nsgt) {
+                                nsg *= 2;
+                            }
+                            nsg /= 2;
+
+                            const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
+
+                            //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+                            GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                        }
+                    } break;
                 case GGML_OP_DUP:
                 case GGML_OP_CPY:
                 case GGML_OP_CONT:
@@ -2721,10 +2910,13 @@ GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backe
     UNUSED(buft);
 }
 
-static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
+static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
+#ifndef GGML_METAL_NDEBUG
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
     if (@available(macOS 10.12, iOS 16.0, *)) {
-        GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
+        GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)",
+                __func__,
+                size_aligned / 1024.0 / 1024.0,
                 device.currentAllocatedSize / 1024.0 / 1024.0,
                 device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
 
@@ -2734,10 +2926,15 @@ static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
             GGML_METAL_LOG_INFO("\n");
         }
     } else {
-        GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
+        GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
+                __func__,
+                size_aligned / 1024.0 / 1024.0,
+                device.currentAllocatedSize / 1024.0 / 1024.0);
     }
+#endif
 #endif
     UNUSED(device);
+    UNUSED(size_aligned);
 }
 
 GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -2771,8 +2968,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff
         return NULL;
     }
 
-    GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
-    ggml_backend_metal_log_allocated_size(device);
+    //ggml_backend_metal_log_allocated_size(device, size_aligned);
 
     return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
 }
@@ -2859,7 +3055,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
             return false;
         }
 
-        GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
+        ggml_backend_metal_log_allocated_size(device, size_aligned);
 
         ++ctx->n_buffers;
     } else {
@@ -2882,7 +3078,8 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
                 return false;
             }
 
-            GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i);
+            ggml_backend_metal_log_allocated_size(device, size_step_aligned);
+
             if (i + size_step < size) {
                 GGML_METAL_LOG_INFO("\n");
             }
@@ -2891,8 +3088,6 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
         }
     }
 
-    ggml_backend_metal_log_allocated_size(device);
-
     return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
 }
 
index 7f37c17d668a82a614e3af4fb213c12b4a9d41b6..4d710b04fa270ed3284055554d30cbbd4d03e89b 100644 (file)
@@ -359,11 +359,12 @@ kernel void kernel_sum_rows(
     dst_row[0] = row_sum;
 }
 
+template<typename T>
 kernel void kernel_soft_max(
-        device const float * src0,
-        device const float * src1,
-        device const float * src2,
-        device       float * dst,
+        device const  char * src0,
+        device const  char * src1,
+        device const  char * src2,
+        device        char * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
@@ -382,10 +383,10 @@ kernel void kernel_soft_max(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float * psrc0 =         src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-    device const float * pmask = src1 != src0 ? src1                               + i01*ne00 : nullptr;
-    device const float * ppos  = src2 != src0 ? src2                                          : nullptr;
-    device       float * pdst  =         dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const     T * pmask = src1 != src0 ? (device const    T *) src1         + i01*ne00 : nullptr;
+    device const     T * ppos  = src2 != src0 ? (device const    T *) src2                    : nullptr;
+    device       float * pdst  = (device       float *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     float slope = 0.0f;
 
@@ -463,11 +464,12 @@ kernel void kernel_soft_max(
     }
 }
 
+template<typename T>
 kernel void kernel_soft_max_4(
-        device const float * src0,
-        device const float * src1,
-        device const float * src2,
-        device       float * dst,
+        device const  char * src0,
+        device const  char * src1,
+        device const  char * src2,
+        device        char * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
@@ -486,10 +488,10 @@ kernel void kernel_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float4 * psrc4 =                (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 +                                      i01*ne00) : nullptr;
-    device const float4 * ppos  = src2 != src0 ? (device const float4 *)(src2)                                                 : nullptr;
-    device       float4 * pdst4 =                (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+    device const      T * pmask = src1 != src0 ? (device const     T *) src1         + i01*ne00/4 : nullptr;
+    device const      T * ppos  = src2 != src0 ? (device const     T *) src2                      : nullptr;
+    device       float4 * pdst4 = (device       float4 *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
 
     float slope = 0.0f;
 
@@ -506,7 +508,7 @@ kernel void kernel_soft_max_4(
     float4 lmax4 = -INFINITY;
 
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
     }
 
     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -532,7 +534,7 @@ kernel void kernel_soft_max_4(
     // parallel sum
     float4 lsum4 = 0.0f;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
@@ -569,6 +571,14 @@ kernel void kernel_soft_max_4(
     }
 }
 
+typedef decltype(kernel_soft_max<float>)    kernel_soft_max_t;
+typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
+
+template [[host_name("kernel_soft_max_f16")]]   kernel kernel_soft_max_t   kernel_soft_max<half>;
+template [[host_name("kernel_soft_max_f32")]]   kernel kernel_soft_max_t   kernel_soft_max<float>;
+template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
+template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
+
 kernel void kernel_diag_mask_inf(
         device const float * src0,
         device       float * dst,
@@ -2091,6 +2101,632 @@ kernel void kernel_leaky_relu_f32(
     dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
 }
 
+typedef void (flash_attn_ext_f16_t)(
+        device const  char * q,
+        device const  char * k,
+        device const  char * v,
+        device const  char * mask,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant   int64_t & ne31,
+        constant  uint64_t & nb31,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant     float & scale,
+        threadgroup   half * shared,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        uint3    ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]);
+
+// ref: https://arxiv.org/pdf/2307.08691.pdf
+template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_f16(
+        device const  char * q,
+        device const  char * k,
+        device const  char * v,
+        device const  char * mask,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant   int64_t & ne31,
+        constant  uint64_t & nb31,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant     float & scale,
+        threadgroup   half * shared [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        uint3    ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const short iq3 = tgpig[2];
+    const short iq2 = tgpig[1];
+    const short iq1 = tgpig[0]*Q;
+
+    const short D4 = D/4;
+    const short D8 = D/8;
+    const short Q8 = Q/8;
+    const short NW = N_SIMDWIDTH;
+    const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+    const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
+    const short TF = T/2;        // shared memory size per query in (float)
+    const short T4 = T/4;        // shared memory size per query in (half4)
+
+    threadgroup half  * sq  = (threadgroup half  *) (shared +              0*D); // holds the query data
+    threadgroup half4 * sq4 = (threadgroup half4 *) (shared +              0*D); // same as above but in half4
+    threadgroup float * ss  = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+
+    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    simdgroup_half8x8 lo[D8];
+
+    // load heads from Q to shared memory
+    for (short j = sgitg; j < Q; j += nsg) {
+        device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+
+        for (short i = tiisg; i < D4; i += NW) {
+            if (iq1 + j < ne01) {
+                sq4[j*T4 + i] = (half4) q4[i];
+            } else {
+                sq4[j*T4 + i] = 0.0h;
+            }
+        }
+    }
+
+    // zero out lo
+    for (short i = 0; i < D8; ++i) {
+        lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
+    }
+
+    // zero out shared memory SH
+    for (short j = 0; j < Q; ++j) {
+        for (short i = tiisg; i < SH; i += NW) {
+            ss[j*TF + i] = 0.0f;
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        float S[Q] = { [0 ... Q-1] = 0.0h };
+        float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
+
+        // assume K and V are same shape
+        const short ne22 = ne12;
+        const short ne23 = ne13;
+
+        const uint nb21 = nb11;
+        const uint nb22 = nb12;
+        const uint nb23 = nb13;
+
+        // broadcast
+        const short rk2 = ne02/ne12;
+        const short rk3 = ne03/ne13;
+
+        const short rv2 = ne02/ne22;
+        const short rv3 = ne03/ne23;
+
+        // k indices
+        const short ik2 = iq2/rk2;
+        const short ik3 = iq3/rk3;
+
+        // v indices
+        const short iv2 = iq2/rv2;
+        const short iv3 = iq3/rv3;
+
+        // load the queries from shared memory into local memory
+        simdgroup_half8x8 mq[D8];
+
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_load(mq[i], sq + i*8, T);
+        }
+
+        // pointer to the mask
+        device const half * mp = (device const half *) (mask + iq1*nb31);
+
+        // prepare diagonal scale matrix
+        simdgroup_float8x8 mscale(scale);
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= ne11) {
+                break;
+            }
+
+            // Q*K^T
+            {
+                for (short cc = 0; cc < C/8; ++cc) {
+                    simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
+
+                    device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+                    for (short i = 0; i < D8; ++i) {
+                        simdgroup_half8x8 mk;
+                        simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+
+                        simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+                    }
+
+                    // mqk = mqk*scale + mask
+                    simdgroup_half8x8 mm;
+                    simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
+                    simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
+
+                    simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
+                }
+            }
+
+            // used to detect blocks full of -INF
+            float smax = -INFINITY;
+
+            // online softmax
+            {
+                float ms[Q];
+
+                for (short j = 0; j < Q; ++j) {
+                    const short p = tiisg;
+
+                    const float m = M[j];
+                    const float s = ss[j*TF + p];
+
+                    smax = simd_max(max(smax, s));
+                    M[j] = simd_max(max(M[j], s));
+
+                                ms[j] = exp(m - M[j]);
+                    const float vs    = exp(s - M[j]);
+
+                    S[j] = S[j]*ms[j] + simd_sum(vs);
+
+                    // the P matrix from the paper (Q rows, C columns)
+                    ss[j*TF + p] = vs;
+                }
+
+                // create a QxQ diagonal matrix for rescaling the output
+                if (tiisg < Q) {
+                    ss[tiisg*TF + C + tiisg] = ms[tiisg];
+                }
+            }
+
+            // skip -INF blocks
+            if (smax == -INFINITY) {
+                continue;
+            }
+
+            // O = diag(ms)*O
+            {
+                simdgroup_float8x8 mm;
+                simdgroup_load(mm, ss + C, TF, 0, false);
+
+                for (short i = 0; i < D8; ++i) {
+                    simdgroup_multiply(lo[i], mm, lo[i]);
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+                for (short cc = 0; cc < C/8; ++cc) {
+                    device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+                    for (short i = 0; i < D8; ++i) {
+                        simdgroup_half8x8 mk;
+                        simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
+
+                        simdgroup_float8x8 mv;
+                        simdgroup_load(mv, ss + 8*cc, TF, 0, false);
+
+                        simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
+                    }
+                }
+            }
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        for (short j = 0; j < Q; ++j) {
+            if (tiisg == 0) {
+                ss[j*TF + 0] = S[j];
+                ss[j*TF + 1] = M[j];
+            }
+        }
+    }
+
+    // reduce the warps sequentially
+    for (short sg = 1; sg < nsg; ++sg) {
+        float S = { 0.0h };
+        float M = { -FLT_MAX/2 };
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // each simdgroup stores its output to shared memory, reusing sq
+        if (sgitg == sg) {
+            for (short i = 0; i < D8; ++i) {
+                simdgroup_store(lo[i], sq + i*8, T, 0, false);
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // the first simdgroup accumulates the results from the other simdgroups
+        if (sgitg == 0) {
+            for (short j = 0; j < Q; ++j) {
+                const float S0 = ss[j*TF +         0];
+                const float S1 = ss[j*TF + sg*SH + 0];
+
+                const float M0 = ss[j*TF +         1];
+                const float M1 = ss[j*TF + sg*SH + 1];
+
+                M = max(M0, M1);
+
+                const float ms0 = exp(M0 - M);
+                const float ms1 = exp(M1 - M);
+
+                S = S0*ms0 + S1*ms1;
+
+                if (tiisg == 0) {
+                    ss[j*TF + 0] = S;
+                    ss[j*TF + 1] = M;
+
+                    ss[j*TF + C + j        ] = ms0;
+                    ss[j*TF + C + j + sg*SH] = ms1;
+                }
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            {
+                simdgroup_half8x8 t;
+                simdgroup_float8x8 ms0;
+                simdgroup_float8x8 ms1;
+
+                simdgroup_load(ms0, ss + C,         TF, 0, false);
+                simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
+
+                for (short i = 0; i < D8; ++i) {
+                    simdgroup_load    (t, sq + i*8, T, 0, false);
+                    simdgroup_multiply(t, ms1, t);
+
+                    simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
+                }
+            }
+        }
+    }
+
+    // store result to shared memory (reuse sq)
+    if (sgitg == 0) {
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_store(lo[i], sq + i*8, T, 0, false);
+        }
+    }
+
+    device float4 * dst4 = (device float4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
+            const float S = ss[j*TF + 0];
+
+            for (short i = tiisg; i < D4; i += NW) {
+                dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
+            }
+        }
+    }
+}
+
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
+
+template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec_f16(
+        device const  char * q,
+        device const  char * k,
+        device const  char * v,
+        device const  char * mask,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant   int64_t & ne31,
+        constant  uint64_t & nb31,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant     float & scale,
+        threadgroup   half * shared [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        uint3    ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const short iq3 = tgpig[2];
+    const short iq2 = tgpig[1];
+    const short iq1 = tgpig[0];
+
+    const short D4 = D/4;
+    const short NW = N_SIMDWIDTH;
+    const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+    const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
+
+  //threadgroup half   * sq  = (threadgroup half   *) (shared +              0*D); // holds the query data
+    threadgroup half4  * sq4 = (threadgroup half4  *) (shared +              0*D); // same as above but in half4
+    threadgroup float  * ss  = (threadgroup float  *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+    threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
+    threadgroup half4  * sr4 = (threadgroup half4  *) (shared +   sgitg*D  + 1*T); // scratch buffer for the results
+
+    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    half4 lo[D4/NW];
+
+    // load heads from Q to shared memory
+    device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
+
+    for (short i = tiisg; i < D4; i += NW) {
+        if (iq1 < ne01) {
+            sq4[i] = (half4) q4[i];
+        } else {
+            sq4[i] = 0.0h;
+        }
+    }
+
+    // zero out lo
+    for (short i = tiisg; i < D4; i += NW) {
+        lo[i/NW] = 0.0h;
+    }
+
+    // zero out shared memory SH
+    for (short i = tiisg; i < SH/4; i += NW) {
+        ss4[i] = 0.0h;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        float S = { 0.0h };
+        float M = { -FLT_MAX/2 };
+
+        // assume K and V are same shape
+        const short ne22 = ne12;
+        const short ne23 = ne13;
+
+        const uint nb21 = nb11;
+        const uint nb22 = nb12;
+        const uint nb23 = nb13;
+
+        // broadcast
+        const short rk2 = ne02/ne12;
+        const short rk3 = ne03/ne13;
+
+        const short rv2 = ne02/ne22;
+        const short rv3 = ne03/ne23;
+
+        // k indices
+        const short ik2 = iq2 / rk2;
+        const short ik3 = iq3 / rk3;
+
+        // v indices
+        const short iv2 = iq2 / rv2;
+        const short iv3 = iq3 / rv3;
+
+        // load the queries from shared memory into local memory
+        half4 mq[D4];
+
+        for (short ii = 0; ii < D4; ii += NW) {
+            short i = ii + tiisg;
+            mq[i] = sq4[i];
+        }
+
+        // pointer to the mask
+        device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= ne11) {
+                break;
+            }
+
+            // Q*K^T
+            {
+#pragma unroll
+                for (short cc = 0; cc < C/4; ++cc) {
+                    float4 mqk = { 0.0h };
+
+                    device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+#pragma unroll
+                    for (short ii = 0; ii < D4; ii += NW) {
+                        const short i = ii + tiisg;
+
+                        half4x4 mk;
+                        mk[0] = pk4[i + 0*(nb11/8)];
+                        mk[1] = pk4[i + 1*(nb11/8)];
+                        mk[2] = pk4[i + 2*(nb11/8)];
+                        mk[3] = pk4[i + 3*(nb11/8)];
+
+                        mqk += (float4) (mq[i] * mk);
+                    }
+
+                    // reduce the results from the threads in the simdgroup
+                    mqk += simd_shuffle_down(mqk, 16);
+                    mqk += simd_shuffle_down(mqk,  8);
+                    mqk += simd_shuffle_down(mqk,  4);
+                    mqk += simd_shuffle_down(mqk,  2);
+                    mqk += simd_shuffle_down(mqk,  1);
+
+                    // mqk = mqk*scale + mask
+                    if (tiisg == 0) {
+                        float4 mm = (float4) mp4[ic/4 + cc];
+                        mqk = mqk*scale + mm;
+
+                        ss4[cc] = mqk;
+                    }
+                }
+            }
+
+            // online softmax
+            {
+                const short p = tiisg;
+
+                const float m = M;
+                const float s = ss[p];
+
+                M = simd_max(max(M, s));
+
+                const float ms = exp(m - M);
+                const float vs = exp(s - M);
+
+                S = S*ms + simd_sum(vs);
+
+                // the P matrix from the paper (Q rows, C columns)
+                ss[p] = vs;
+
+                // O = diag(ms)*O
+#pragma unroll
+                for (short ii = 0; ii < D4; ii += NW) {
+                    const short i = ii + tiisg;
+                    lo[i/NW] *= ms;
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+#pragma unroll
+                for (short cc = 0; cc < C/4; ++cc) {
+                    device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+#pragma unroll
+                    for (short ii = 0; ii < D4; ii += NW) {
+                        const short i = ii + tiisg;
+
+                        lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
+                        lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
+                        lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
+                        lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
+                    }
+                }
+            }
+
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        if (tiisg == 0) {
+            ss[0] = S;
+            ss[1] = M;
+        }
+    }
+
+    // store results to shared memory
+    for (short ii = 0; ii < D4; ii += NW) {
+        short i = ii + tiisg;
+        sr4[i] = lo[ii/NW];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // parallel reduce
+    for (short r = nsg/2; r > 0; r >>= 1) {
+        if (sgitg < r) {
+            const float S0 = ss[       0];
+            const float S1 = ss[r*SH + 0];
+
+            const float M0 = ss[       1];
+            const float M1 = ss[r*SH + 1];
+
+            const float M = max(M0, M1);
+
+            const float ms0 = exp(M0 - M);
+            const float ms1 = exp(M1 - M);
+
+            const float S = S0*ms0 + S1*ms1;
+
+            if (tiisg == 0) {
+                ss[0] = S;
+                ss[1] = M;
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            for (short ii = 0; ii < D4; ii += NW) {
+                short i = ii + tiisg;
+                sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    device float4 * dst4 = (device float4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        const float S = ss[0];
+
+        for (short ii = 0; ii < D4; ii += NW) {
+            short i = ii + tiisg;
+            dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
+        }
+    }
+}
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+
 kernel void kernel_cpy_f16_f16(
         device  const half * src0,
         device        half * dst,
index 2b76b3ebd64f749726de80f6d76a81bb0ec520a8..57fe4ea3d4ac25d2cf65511bffcc03b6c6f58901 100644 (file)
@@ -14744,7 +14744,12 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
+    const ggml_tensor * src2 = dst->src[2];
+
+#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
+#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
 
     const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
@@ -14760,7 +14765,6 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
     float * src2_dd = nullptr;
     sycl_pool_alloc<float> src2_f;
 
-    ggml_tensor * src2 = dst->src[2];
     const bool use_src2 = src2 != nullptr;
 
     if (use_src2) {
index 1736ab7361c273443d1c6af4435b23cd2d2a678a..f712cdd5a900eb454625d4c1a3a6a1c27cdb89e1 100644 (file)
@@ -3178,6 +3178,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
         return nullptr;
     case GGML_OP_SOFT_MAX:
+#pragma message("TODO: add ggml_vk_soft_max() F16 src1 and src2 support")
+#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
+        GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32);
+        GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
+
         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_soft_max_f32;
         }
diff --git a/ggml.c b/ggml.c
index 3bddcdbf28a90c2f367c4b26abe364dffb0a3aa7..00f3e170a16b95d6fa603f900f9344174ffa8373 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -951,7 +951,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
     #define GGML_F16_VEC_ZERO           GGML_F16x8_ZERO
     #define GGML_F16_VEC_SET1           GGML_F16x8_SET1
     #define GGML_F16_VEC_LOAD(p, i)     GGML_F16x8_LOAD(p)
-    #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
+    #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), r[i])
     #define GGML_F16_VEC_FMA            GGML_F16x8_FMA
     #define GGML_F16_VEC_ADD            GGML_F16x8_ADD
     #define GGML_F16_VEC_MUL            GGML_F16x8_MUL
@@ -977,7 +977,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
     #define GGML_F16_VEC_ZERO           GGML_F32Cx4_ZERO
     #define GGML_F16_VEC_SET1           GGML_F32Cx4_SET1
     #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx4_LOAD(p)
-    #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
+    #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
     #define GGML_F16_VEC_FMA            GGML_F32Cx4_FMA
     #define GGML_F16_VEC_ADD            GGML_F32Cx4_ADD
     #define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL
@@ -1046,7 +1046,7 @@ do {                                                                  \
 
 // unlike  _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
 // so F16C guard isn't required
-#define GGML_F32Cx16_LOAD(x)     _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x)))
+#define GGML_F32Cx16_LOAD(x)     _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
 #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))
 
 #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
@@ -1144,7 +1144,7 @@ do {                                                              \
 
 #if defined(__F16C__)
 // the  _mm256_cvt intrinsics require F16C
-#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
+#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
 #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
 #else
 static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
@@ -1662,6 +1662,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
 #endif
 }
 
+inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#endif
+}
+
 // xs and vs are byte strides of x and v
 inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
 
@@ -1746,6 +1777,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
 #endif
 }
 
+inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
+
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+    }
+#endif
+}
+
 inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s);   }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
 inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
@@ -2001,6 +2061,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "LEAKY_RELU",
 
     "FLASH_ATTN",
+    "FLASH_ATTN_EXT",
     "FLASH_FF",
     "FLASH_ATTN_BACK",
     "SSM_CONV",
@@ -2027,7 +2088,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2091,6 +2152,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "leaky_relu(x)",
 
     "flash_attn(x)",
+    "flash_attn_ext(x)",
     "flash_ff(x)",
     "flash_attn_back(x)",
     "ssm_conv(x)",
@@ -2117,7 +2179,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4575,6 +4637,8 @@ struct ggml_tensor * ggml_mul_mat(
 void ggml_mul_mat_set_prec(
         struct ggml_tensor * a,
         enum ggml_prec       prec) {
+    GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
+
     const int32_t prec_i32 = (int32_t) prec;
 
     ggml_set_op_params_i32(a, 0, prec_i32);
@@ -5413,17 +5477,23 @@ static struct ggml_tensor * ggml_soft_max_impl(
     GGML_ASSERT(ggml_is_contiguous(a));
 
     if (mask) {
+        GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
         GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(ggml_is_matrix(mask));
-        GGML_ASSERT(ggml_can_repeat_rows(mask, a));
+        GGML_ASSERT(mask->ne[0] == a->ne[0]);
+        GGML_ASSERT(mask->ne[1] >= a->ne[1]);
     }
 
     if (pos) {
         GGML_ASSERT(ggml_is_vector(pos));
-        GGML_ASSERT(pos->type == GGML_TYPE_F32);
+        GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
         GGML_ASSERT(pos->ne[0] == a->ne[0]);
     }
 
+    if (pos && mask) {
+        GGML_ASSERT(pos->type == mask->type);
+    }
+
     if (max_bias > 0.0f) {
         GGML_ASSERT(pos);
     }
@@ -6232,6 +6302,59 @@ struct ggml_tensor * ggml_flash_attn(
     return result;
 }
 
+// ggml_flash_attn_ext
+
+struct ggml_tensor * ggml_flash_attn_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * mask,
+        float                 scale) {
+    GGML_ASSERT(ggml_can_mul_mat(k, q));
+    // TODO: check if vT can be multiplied by (k*qT)
+    if (mask) {
+        GGML_ASSERT(ggml_is_contiguous(mask));
+        GGML_ASSERT(mask->ne[2] == 1);
+        GGML_ASSERT(mask->ne[3] == 1);
+        GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
+                "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
+        //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
+    }
+
+    bool is_node = false;
+
+    if (q->grad || k->grad || v->grad) {
+        is_node = true;
+    }
+
+    // permute(0, 2, 1, 3)
+    int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    float params[] = { scale };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op   = GGML_OP_FLASH_ATTN_EXT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = q;
+    result->src[1] = k;
+    result->src[2] = v;
+    result->src[3] = mask;
+
+    return result;
+}
+
+void ggml_flash_attn_ext_set_prec(
+        struct ggml_tensor * a,
+        enum ggml_prec       prec) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int32_t prec_i32 = (int32_t) prec;
+
+    ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
+}
+
 // ggml_flash_ff
 
 struct ggml_tensor * ggml_flash_ff(
@@ -12317,7 +12440,7 @@ static void ggml_compute_forward_soft_max_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    const int64_t ne11 = src1 ? src1->ne[1] : 1;
+    //const int64_t ne11 = src1 ? src1->ne[1] : 1;
 
     // TODO: is this supposed to be ceil instead of floor?
     //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -12340,19 +12463,31 @@ static void ggml_compute_forward_soft_max_f32(
     float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
 
     // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
-    float * pos = src2 ? (float *) src2->data : src0->data;
+    ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
+    float       * pos_f32 = src2 ? (float       *) src2->data : src0->data;
+
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
         float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
         float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
 
         // broadcast the mask across rows
-        float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
+        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
 
         ggml_vec_cpy_f32  (nc, wp, sp);
         ggml_vec_scale_f32(nc, wp, scale);
-        if (mp) {
-            ggml_vec_acc_f32(nc, wp, mp);
+        if (mp_f32) {
+            if (use_f16) {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
+                }
+            } else {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += mp_f32[i];
+                }
+            }
         }
 
         // ALiBi bias
@@ -12360,8 +12495,14 @@ static void ggml_compute_forward_soft_max_f32(
             const uint32_t h  = (i1/ne01)%ne02; // head
             const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
 
-            for (int i = 0; i < nc; i++) {
-                wp[i] = wp[i] + slope*pos[i];
+            if (use_f16) {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
+                }
+            } else {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*pos_f32[i];
+                }
             }
         }
 
@@ -14631,6 +14772,198 @@ static void ggml_compute_forward_flash_attn(
     }
 }
 
+// ggml_compute_forward_flash_attn_ext
+
+static void ggml_compute_forward_flash_attn_ext_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * mask,
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+
+    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne2 == N);
+
+    GGML_ASSERT(nbq0 == sizeof(float));
+    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev0 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nev0 == D);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // broadcast factors
+    const int64_t rk2 = neq2/nek2;
+    const int64_t rk3 = neq3/nek3;
+
+    const int64_t rv2 = neq2/nev2;
+    const int64_t rv3 = neq3/nev3;
+
+    if (params->type == GGML_TASK_TYPE_INIT) {
+        return;
+    }
+
+    if (params->type == GGML_TASK_TYPE_FINALIZE) {
+        return;
+    }
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float scale = 1.0f;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+
+    // loop over n_batch and n_head
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2*neq1);
+        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+        float S = 0.0f;
+        float M = -INFINITY;
+
+        float       * V32 = (float       *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
+        ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
+        ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
+
+        memset(V16, 0, D*sizeof(ggml_fp16_t));
+
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+
+        // k indices
+        const int ik3 = iq3 / rk3;
+        const int ik2 = iq2 / rk2;
+
+        // v indices
+        const int iv3 = iq3 / rv3;
+        const int iv2 = iq2 / rv2;
+
+        // online softmax / attention
+        // loop over n_kv and n_head_kv
+        // ref: https://arxiv.org/pdf/2112.05682.pdf
+        for (int64_t ic = 0; ic < nek1; ++ic) {
+            const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            if (mv == -INFINITY) {
+                continue;
+            }
+
+            float s;
+
+            // convert Q to F16 in V32
+            {
+                const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+
+                for (int64_t d = 0; d < D; ++d) {
+                    Q16[d] = GGML_FP32_TO_FP16(pq[d]);
+                }
+            }
+
+            ggml_vec_dot_f16(D,
+                    &s, 0,
+                    (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+                    Q16, 0, 1);
+
+            s = s*scale + mv;
+
+            const float Mold = M;
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                M = s;
+                ms = expf(Mold - M);
+
+                // V = V*expf(Mold - M)
+                ggml_vec_scale_f16(D, V16, ms);
+            } else {
+                vs = expf(s - M);
+            }
+
+            const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+
+            // V += v*expf(s - M)
+            ggml_vec_mad_f16(D, V16, v16, vs);
+
+            S = S*ms + vs;
+        }
+
+        // V /= S
+        for (int64_t d = 0; d < D; ++d) {
+            V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
+        }
+
+        // dst indices
+        const int i1 = iq1;
+        const int i2 = iq2;
+        const int i3 = iq3;
+
+        // original
+        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+
+        // permute(0, 2, 1, 3)
+        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
+    }
+}
+
+static void ggml_compute_forward_flash_attn_ext(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * mask,
+        struct ggml_tensor * dst) {
+    switch (dst->op_params[1]) {
+        case GGML_PREC_DEFAULT:
+        case GGML_PREC_F32:
+            {
+                // uses F32 accumulators
+                ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_flash_ff
 
 static void ggml_compute_forward_flash_ff_f16(
@@ -16442,6 +16775,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 const bool masked = t != 0;
                 ggml_compute_forward_flash_attn(params, masked, tensor);
             } break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+            } break;
         case GGML_OP_FLASH_FF:
             {
                 ggml_compute_forward_flash_ff(params, tensor);
@@ -17454,6 +17791,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_FLASH_ATTN:
+        case GGML_OP_FLASH_ATTN_EXT:
             {
                 struct ggml_tensor * flash_grad = NULL;
                 if (src0->grad || src1->grad || tensor->src[2]->grad) {
@@ -18231,6 +18569,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
                 n_tasks = n_threads;
             } break;
         case GGML_OP_FLASH_ATTN:
+        case GGML_OP_FLASH_ATTN_EXT:
             {
                 n_tasks = n_threads;
             } break;
@@ -18634,6 +18973,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                         cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
                     }
                 } break;
+            case GGML_OP_FLASH_ATTN_EXT:
+                {
+                    const int64_t ne00 = node->src[0]->ne[0]; // D
+
+                    cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
+                } break;
             case GGML_OP_FLASH_FF:
                 {
                     if (node->src[1]->type == GGML_TYPE_F32) {
diff --git a/ggml.h b/ggml.h
index 06cafbd78ba1ac6dc7b337834d732f7c0433b533..d90ba8ed66445262d5c8ec5b349d7b25267c4d12 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -475,6 +475,7 @@ extern "C" {
         GGML_OP_LEAKY_RELU,
 
         GGML_OP_FLASH_ATTN,
+        GGML_OP_FLASH_ATTN_EXT,
         GGML_OP_FLASH_FF,
         GGML_OP_FLASH_ATTN_BACK,
         GGML_OP_SSM_CONV,
@@ -1731,6 +1732,25 @@ extern "C" {
             struct ggml_tensor  * v,
             bool                  masked);
 
+#define GGML_KQ_MASK_PAD 32
+
+    // q:    [n_embd, n_batch,     n_head,    1]
+    // k:    [n_embd, n_kv,        n_head_kv, 1]
+    // v:    [n_embd, n_kv,        n_head_kv, 1] !! not transposed !!
+    // mask: [n_kv,   n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd, n_head,      n_batch,   1] !! permuted !!
+    GGML_API struct ggml_tensor * ggml_flash_attn_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * mask,
+            float                 scale);
+
+    GGML_API void ggml_flash_attn_ext_set_prec(
+            struct ggml_tensor * a,
+            enum ggml_prec       prec);
+
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
            struct ggml_context * ctx,
            struct ggml_tensor  * q,