#include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/diagmask.cuh"
#include "ggml-cuda/dmmv.cuh"
+#include "ggml-cuda/fattn.cuh"
#include "ggml-cuda/getrows.cuh"
#include "ggml-cuda/im2col.cuh"
#include "ggml-cuda/mmq.cuh"
info.devices[id].cc = 100*prop.major + 10*prop.minor;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
info.devices[id].smpb = prop.sharedMemPerBlock;
+ info.devices[id].nsm = prop.multiProcessorCount;
}
for (int id = 0; id < info.device_count; ++id) {
case GGML_OP_ARGSORT:
ggml_cuda_op_argsort(ctx, dst);
break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ ggml_cuda_flash_attn_ext(ctx, dst);
+ break;
default:
return false;
}
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
+ case GGML_OP_FLASH_ATTN_EXT:
return true;
default:
return false;
#define CC_PASCAL 600
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define CC_VOLTA 700
+#define CC_AMPERE 800
#define CC_OFFSET_AMD 1000000
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
return a;
}
-#ifdef GGML_CUDA_F16
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#pragma unroll
NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
-#endif // GGML_CUDA_F16
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
return x;
}
-//static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-//#pragma unroll
-// for (int mask = 16; mask > 0; mask >>= 1) {
-// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-// }
-// return x;
-//#else
-// GGML_UNUSED(x);
-// NO_DEVICE_CODE;
-//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-//}
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+ }
+ return x;
+#else
+ GGML_UNUSED(x);
+ NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+}
+#if CUDART_VERSION < 12000
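+// Fallback for CUDA runtimes where the __hgt2_mask intrinsic is not available:
+// returns 0xFFFF in each 16-bit lane where the corresponding half of a is greater than that of b.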
+static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
+ const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
+ const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
+ return mask_low | mask_high;
+}
+#endif // CUDART_VERSION < 12000
#if defined(GGML_USE_HIPBLAS)
#define __CUDA_ARCH__ 1300
}
#endif // defined(GGML_USE_HIPBLAS)
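+// FP16_AVAILABLE gates device code that uses half precision arithmetic (RDNA on AMD, Pascal or newer on NVIDIA).
+// FP16_MMA_AVAILABLE gates code that uses tensor core (wmma) fragments and is NVIDIA-only (Volta or newer).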
+#define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
+ defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
+
+#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+
// TODO: move to ggml-common.h
static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
struct cuda_device_info {
int cc; // compute capability
+ int nsm; // number of streaming multiprocessors
size_t smpb; // max. shared memory per block
bool vmm; // virtual memory support
size_t vmm_granularity; // granularity of virtual memory
--- /dev/null
+#include "common.cuh"
+#include "fattn.cuh"
+
+#include <cfloat>
+#include <cstdint>
+#include <type_traits>
+
+#if FP16_MMA_AVAILABLE
+#include <mma.h>
+#endif
+
+#define FATTN_KQ_STRIDE 256
+#define HALF_MAX_HALF         __float2half(65504.0f/2) // Use the negative of this instead of -INFINITY to initialize KQ max values, to avoid NaNs upon subtraction.
+#define SOFTMAX_FTZ_THRESHOLD -20.0f                    // Softmax exponentials of values smaller than this are flushed to zero to avoid NaNs.
+
+template<int D, int parallel_blocks> // D == head size
+__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
+static __global__ void flash_attn_vec_ext_f16(
+ const char * __restrict__ Q,
+ const char * __restrict__ K,
+ const char * __restrict__ V,
+ const char * __restrict__ mask,
+ float * __restrict__ dst,
+ float2 * __restrict__ dst_meta,
+ const float scale,
+ const int ne00,
+ const int ne01,
+ const int ne02,
+ const int ne03,
+ const int ne10,
+ const int ne11,
+ const int ne12,
+ const int ne13,
+ const int ne31,
+ const int nb31,
+ const int nb01,
+ const int nb02,
+ const int nb03,
+ const int nb11,
+ const int nb12,
+ const int nb13,
+ const int ne0,
+ const int ne1,
+ const int ne2,
+ const int ne3) {
+#if FP16_AVAILABLE
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+ const int ic = blockIdx.x / parallel_blocks; // Index of the Q/QKV column to work on.
+ const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
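+    // When parallel_blocks > 1 the KV sequence is split across parallel_blocks blocks per column;
+    // the partial results are merged afterwards by flash_attn_combine_results.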
+
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+ const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic);
+ const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
+ const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+ const half * maskh = (const half *) mask + ne11*ic;
+
+ const int stride_KV = nb11 / sizeof(half);
+ const int stride_KV2 = nb11 / sizeof(half2);
+
+ constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
+ const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+ __builtin_assume(tid < nwarps*WARP_SIZE);
+
+ __shared__ half KQ[nwarps*WARP_SIZE];
+ KQ[tid] = -INFINITY;
+ half2 * KQ2 = (half2 *) KQ;
+
+ half kqmax = -HALF_MAX_HALF;
+ half kqsum = 0.0f;
+
+ __shared__ half kqmax_shared[WARP_SIZE];
+ __shared__ half kqsum_shared[WARP_SIZE];
+ if (threadIdx.y == 0) {
+ kqmax_shared[threadIdx.x] = -HALF_MAX_HALF;
+ kqsum_shared[threadIdx.x] = 0.0f;
+ }
+ __syncthreads();
+
+ // Convert Q to half2 and store in registers:
+ half2 Q_h2[(D/2 + WARP_SIZE - 1) / WARP_SIZE];
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+ if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+ break;
+ }
+
+ Q_h2[i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(Q_f2[i].x, Q_f2[i].y);
+ }
+
+ half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value.
+
+ const int k_start = parallel_blocks == 1 ? 0 : ip*D;
+ for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+ // Calculate KQ tile and keep track of new maximum KQ values:
+ half kqmax_new = kqmax;
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
+ const int i_KQ = i_KQ_0 + threadIdx.y;
+
+ if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
+ break;
+ }
+
+ half2 sum2 = make_half2(0.0f, 0.0f);
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+ const int k_KQ = k_KQ_0 + threadIdx.x;
+ if (k_KQ_0 + WARP_SIZE > D/2 && k_KQ >= D/2) {
+ break;
+ }
+
+ const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
+ sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
+ }
+
+ sum2 = warp_reduce_sum(sum2);
+ half sum = __low2half(sum2) + __high2half(sum2);
+ sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f);
+ kqmax_new = __hmax(kqmax_new, sum);
+ if (threadIdx.x == 0) {
+ KQ[i_KQ] = sum;
+ }
+ }
+
+ kqmax_new = warp_reduce_max(kqmax_new);
+ if (threadIdx.x == 0) {
+ kqmax_shared[threadIdx.y] = kqmax_new;
+ }
+ __syncthreads();
+ kqmax_new = kqmax_shared[threadIdx.x];
+ kqmax_new = warp_reduce_max(kqmax_new);
+
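+        // Online softmax: rescale the running sum and the VKQ accumulator by exp(old_max - new_max)
+        // so that all terms are expressed relative to the new maximum.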
+ const half KQ_max_scale = hexp(kqmax - kqmax_new);
+ kqmax = kqmax_new;
+
+ const half val = hexp(KQ[tid] - kqmax);
+ kqsum = kqsum*KQ_max_scale + val;
+ KQ[tid] = val;
+
+ VKQ *= __half2half2(KQ_max_scale);
+
+ __syncthreads();
+
+ if (tid < D) {
+#pragma unroll
+ for (int k0 = 0; k0 < D; k0 += 2) {
+ if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
+ break;
+ }
+
+ half2 V_k;
+ reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
+ reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
+ VKQ += V_k*KQ2[k0/2];
+ }
+ }
+
+ __syncthreads();
+ }
+
+ if (tid >= D) {
+ kqsum = 0.0f;
+ }
+
+ kqsum = warp_reduce_sum(kqsum);
+ if (threadIdx.x == 0) {
+ kqsum_shared[threadIdx.y] = kqsum;
+ }
+ __syncthreads();
+ kqsum = kqsum_shared[threadIdx.x];
+ kqsum = warp_reduce_sum(kqsum);
+
+ if (tid >= D) {
+ return;
+ }
+
+ half dst_val = (__low2half(VKQ) + __high2half(VKQ));
+ if (parallel_blocks == 1) {
+ dst_val /= kqsum;
+ }
+ dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val;
+
+ if (parallel_blocks == 1 || tid != 0) {
+ return;
+ }
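+    // Store the per-block maximum and sum so that flash_attn_combine_results can merge the partial results.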
+ dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum);
+#else
+ NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+}
+
+// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
+__launch_bounds__(nwarps*WARP_SIZE, 1)
+static __global__ void flash_attn_ext_f16(
+ const char * __restrict__ Q,
+ const char * __restrict__ K,
+ const char * __restrict__ V,
+ const char * __restrict__ mask,
+ float * __restrict__ dst,
+ float2 * __restrict__ dst_meta,
+ const float scale,
+ const int ne00,
+ const int ne01,
+ const int ne02,
+ const int ne03,
+ const int ne10,
+ const int ne11,
+ const int ne12,
+ const int ne13,
+ const int ne31,
+ const int nb31,
+ const int nb01,
+ const int nb02,
+ const int nb03,
+ const int nb11,
+ const int nb12,
+ const int nb13,
+ const int ne0,
+ const int ne1,
+ const int ne2,
+ const int ne3) {
+#if FP16_MMA_AVAILABLE
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+ const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
+ const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+ static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
+ static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
+ constexpr int frag_m = ncols == 8 ? 32 : 16;
+ constexpr int frag_n = ncols == 8 ? 8 : 16;
+ static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
+
+ constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
+ constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
+ static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
+
+ // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
+ constexpr int D_padded = D + 8;
+ constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
+ constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
+
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+ const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
+ const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
+ const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+ const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
+ const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
+
+ const int stride_Q = nb01 / sizeof(float);
+ const int stride_KV = nb11 / sizeof(half);
+
+ frag_b Q_b[D/16][ncols/frag_n];
+
+ // A single buffer for temporarily holding tiles of KQ and VKQ parts:
+ constexpr int mem_KQ = ncols*kqs_padded*kqar;
+ constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
+ __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
+ float * KQ_f = (float *) KQ;
+ half2 * KQ2 = (half2 *) KQ;
+
+ float KQ_rowsum_f[ncols/nwarps] = {0.0f};
+ float KQ_max_f[ncols/nwarps];
+ float KQ_max_scale_f[ncols/nwarps] = {0.0f};
+
+#pragma unroll
+ for (int j = 0; j < ncols/nwarps; ++j) {
+ KQ_max_f[j] = -FLT_MAX/2.0f;
+ }
+
+ half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+ half2 KQ_max_h2[ncols/nwarps];
+ half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+
+#pragma unroll
+ for (int j = 0; j < ncols/nwarps; ++j) {
+ KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
+ }
+
+ __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
+ half2 * VKQ2 = (half2 *) VKQ;
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+ if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+ break;
+ }
+ VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
+ }
+ }
+
+ // Convert Q to half and apply scale, temporarily store in KQ:
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+#pragma unroll
+ for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+ if (i0 + WARP_SIZE > D && i >= D) {
+ break;
+ }
+ KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+ }
+ }
+
+ __syncthreads();
+
+ // Load Q into tensor core fragments/registers since it will be used frequently:
+#pragma unroll
+ for (int i0 = 0; i0 < D; i0 += 16) {
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+ nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+ }
+ }
+
+ __syncthreads();
+
+ // Iterate over ne11 == previous tokens:
+ for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
+ // Calculate tile of KQ:
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
+ frag_c_KQ KQ_c[ncols/frag_n];
+#pragma unroll
+ for (int j = 0; j < ncols/frag_n; ++j) {
+ nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+ }
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
+ frag_a_K K_a;
+ nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+#pragma unroll
+ for (int j = 0; j < ncols/frag_n; ++j) {
+ nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+ }
+ }
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+ nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+ }
+ }
+
+ __syncthreads();
+
+ // Calculate softmax for each KQ column using the current max. value.
+ // The divisor is stored in KQ_rowsum and will be applied at the end.
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (std::is_same<KQ_acc_t, float>::value) {
+ float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+ const int k = k0 + threadIdx.x;
+
+ KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+ }
+
+ float KQ_max_new = KQ_max_f[j0/nwarps];
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+ const int k = k0 + threadIdx.x;
+
+ KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+ KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
+ }
+ KQ_max_new = warp_reduce_max(KQ_max_new);
+
+ const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
+ KQ_max_scale_f[j0/nwarps] = expf(diff);
+ if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+ KQ_max_scale_f[j0/nwarps] = 0.0f;
+ }
+ KQ_max_f[j0/nwarps] = KQ_max_new;
+
+ float KQ_rowsum_add = 0.0f;
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+ const int k = k0 + threadIdx.x;
+
+ const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
+ KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
+ if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+ KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
+ }
+ KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
+ KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
+ }
+ KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+ // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+ KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
+ } else {
+ half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+ const int k = k0 + threadIdx.x;
+
+ KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+ }
+
+ half2 KQ_max_new = KQ_max_h2[j0/nwarps];
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+ const int k = k0 + threadIdx.x;
+
+ KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+ KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
+ }
+ KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
+ const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
+ KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
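+                // Flush the scale to zero in half lanes where the difference is at or below the threshold
+                // (the mask holds 0xFFFF per lane where diff > threshold).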
+ const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+ *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
+ KQ_max_h2[j0/nwarps] = KQ_max_new;
+
+ half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+ const int k = k0 + threadIdx.x;
+
+ const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
+ KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
+ const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+ *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
+ KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
+ KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
+ }
+ KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+ // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+ KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
+ }
+ }
+
+ __syncthreads();
+
+ frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+ const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+ nvcuda::wmma::load_matrix_sync(
+ KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
+ KQ + j0*(kqar*kqs_padded) + k,
+ kqar*kqs_padded);
+ }
+ }
+
+ frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
+#pragma unroll
+ for (int j = 0; j < ncols/frag_n; ++j) {
+ nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+ }
+
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+ const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+
+ frag_a_V v_a;
+ nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+#pragma unroll
+ for (int j = 0; j < ncols/frag_n; ++j) {
+ nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+ }
+ }
+ }
+
+ __syncthreads();
+
+ const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+ nvcuda::wmma::store_matrix_sync(
+ KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
+ VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
+ D_padded, nvcuda::wmma::mem_col_major);
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ half2 VKQ_scale;
+ if (std::is_same<KQ_acc_t, float>::value) {
+ VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
+ } else {
+ VKQ_scale = KQ_max_scale_h2[j0/nwarps];
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+ if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+ break;
+ }
+
+ half2 VKQ_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+ for (int l = 0; l < VKQ_ratio; ++l) {
+ VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
+ }
+ VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
+ }
+ }
+
+ __syncthreads();
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j_VKQ = j0 + threadIdx.y;
+ if (ic0 + j_VKQ >= ne01) {
+ return;
+ }
+ const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+
+ float KQ_rowsum_j;
+ if (std::is_same<KQ_acc_t, float>::value) {
+ KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
+ } else {
+ KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+ if (i0 + WARP_SIZE > D && i >= D) {
+ break;
+ }
+ float dst_val = VKQ[j_VKQ*D_padded + i];
+ if (parallel_blocks == 1) {
+ dst_val /= KQ_rowsum_j;
+ }
+ dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
+ }
+
+ if (parallel_blocks == 1 || threadIdx.x != 0) {
+ continue;
+ }
+
+ float2 dst_meta_val;
+ if (std::is_same<KQ_acc_t, float>::value) {
+ dst_meta_val.x = KQ_max_f[j0/nwarps];
+ } else {
+ dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
+ }
+ dst_meta_val.y = KQ_rowsum_j;
+ dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
+ }
+#else
+ NO_DEVICE_CODE;
+#endif // FP16_MMA_AVAILABLE
+}
+
+template<int D, int parallel_blocks> // D == head size
+__launch_bounds__(D, 1)
+static __global__ void flash_attn_combine_results(
+ const float * __restrict__ VKQ_parts,
+ const float2 * __restrict__ VKQ_meta,
+ float * __restrict__ dst) {
+#if FP16_AVAILABLE
+ VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
+ VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
+ dst += D * gridDim.y*blockIdx.x;
+
+ const int tid = threadIdx.x;
+ __builtin_assume(tid < D);
+
+ __shared__ float2 meta[parallel_blocks];
+ if (tid < 2*parallel_blocks) {
+        ((float *) meta)[threadIdx.x] = ((const float *) VKQ_meta)[blockIdx.y*(2*parallel_blocks) + tid];
+ }
+
+ __syncthreads();
+
+ float kqmax = meta[0].x;
+#pragma unroll
+ for (int l = 1; l < parallel_blocks; ++l) {
+ kqmax = max(kqmax, meta[l].x);
+ }
+
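+    // Merge the partial results: each block's numerator and denominator are rescaled by
+    // exp(block max - global max) before taking the ratio.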
+ float VKQ_numerator = 0.0f;
+ float VKQ_denominator = 0.0f;
+#pragma unroll
+ for (int l = 0; l < parallel_blocks; ++l) {
+ const float diff = meta[l].x - kqmax;
+        float KQ_max_scale = expf(diff);
+ const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
+ *((uint32_t *) &KQ_max_scale) &= ftz_mask;
+
+ VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
+ VKQ_denominator += KQ_max_scale * meta[l].y;
+ }
+
+ dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
+#else
+ NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+}
+
+constexpr int get_max_power_of_2(int x) {
+ return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
+}
+
+static_assert(get_max_power_of_2(1) == 1, "Test failed.");
+static_assert(get_max_power_of_2(2) == 2, "Test failed.");
+static_assert(get_max_power_of_2(4) == 4, "Test failed.");
+static_assert(get_max_power_of_2(6) == 2, "Test failed.");
+
+// Number of VKQ rows calculated in parallel:
+constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
+ return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
+}
+
+static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed.");
+static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed.");
+static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
+static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed.");
+static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed.");
+static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed.");
+static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
+
+template <int D, int parallel_blocks> void launch_fattn_vec_f16(
+ const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
+ ggml_cuda_pool & pool, cudaStream_t main_stream
+) {
+ ggml_cuda_pool_alloc<float> dst_tmp(pool);
+ ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+
+ if (parallel_blocks > 1) {
+ dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+ dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+ }
+
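+    // One block per (Q column, parallel block) pair; nwarps is chosen so that each block has
+    // at least D threads, one per element of the head.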
+ constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
+ const dim3 block_dim(WARP_SIZE, nwarps, 1);
+ const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
+ const int shmem = 0;
+
+ float scale;
+ memcpy(&scale, KQV->op_params, sizeof(float));
+
+ flash_attn_vec_ext_f16<D, parallel_blocks>
+ <<<blocks_num, block_dim, shmem, main_stream>>> (
+ (const char *) Q->data,
+ (const char *) K->data,
+ (const char *) V->data,
+ mask ? ((const char *) mask->data) : nullptr,
+ parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+ scale,
+ Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+ K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+ mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+ Q->nb[1], Q->nb[2], Q->nb[3],
+ K->nb[1], K->nb[2], K->nb[3],
+ KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+ );
+ CUDA_CHECK(cudaGetLastError());
+
+ if (parallel_blocks == 1) {
+ return;
+ }
+
+ const dim3 block_dim_combine(D, 1, 1);
+ const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+ const int shmem_combine = 0;
+
+ flash_attn_combine_results<D, parallel_blocks>
+ <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+ (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+ CUDA_CHECK(cudaGetLastError());
+}
+
+template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename KQ_acc_t> void launch_fattn_f16_impl(
+ const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
+ ggml_cuda_pool & pool, cudaStream_t main_stream
+) {
+ ggml_cuda_pool_alloc<float> dst_tmp(pool);
+ ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+
+ if (parallel_blocks > 1) {
+ dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+ dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+ }
+
+ constexpr int frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16;
+ const dim3 block_dim(WARP_SIZE, nwarps, 1);
+ const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
+ const int shmem = 0;
+
+ float scale;
+ memcpy(&scale, KQV->op_params, sizeof(float));
+
+ flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
+ <<<blocks_num, block_dim, shmem, main_stream>>> (
+ (const char *) Q->data,
+ (const char *) K->data,
+ (const char *) V->data,
+ mask ? ((const char *) mask->data) : nullptr,
+ (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+ scale,
+ Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+ K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+ mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+ Q->nb[1], Q->nb[2], Q->nb[3],
+ K->nb[1], K->nb[2], K->nb[3],
+ KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+ );
+ CUDA_CHECK(cudaGetLastError());
+
+ if ((parallel_blocks) == 1) {
+ return;
+ }
+
+ const dim3 block_dim_combine(D, 1, 1);
+ const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+ const int shmem_combine = 0;
+
+ flash_attn_combine_results<D, parallel_blocks>
+ <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+ (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+ CUDA_CHECK(cudaGetLastError());
+}
+
+template <int D, int cols_per_block, int nwarps, typename KQ_acc_t> void launch_fattn_f16(
+ const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
+ const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream
+) {
+ const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
+
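+    // Heuristic: if a single block per column would leave the SMs underoccupied,
+    // split the work for each column across 4 or 2 parallel blocks.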
+ if (4*blocks_num_pb1 < 2*nsm) {
+ launch_fattn_f16_impl<D, cols_per_block, nwarps, 4, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+ return;
+ }
+ if (2*blocks_num_pb1 < 2*nsm) {
+ launch_fattn_f16_impl<D, cols_per_block, nwarps, 2, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+ return;
+ }
+ launch_fattn_f16_impl<D, cols_per_block, nwarps, 1, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+}
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
+
+ const ggml_tensor * mask = dst->src[3];
+
+ ggml_tensor * KQV = dst;
+
+ GGML_ASSERT(Q->type == GGML_TYPE_F32);
+ GGML_ASSERT(K->type == GGML_TYPE_F16);
+ GGML_ASSERT(V->type == GGML_TYPE_F16);
+ GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+ GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+ "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+
+ GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+
+ ggml_cuda_set_device(ctx.device);
+
+ const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+
+ const int32_t precision = KQV->op_params[1];
+
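+    // With non-default (higher) precision requested, use float accumulators for KQ;
+    // otherwise half precision accumulators are used.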
+ if (precision != GGML_PREC_DEFAULT) {
+ if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
+ constexpr int cols_per_block = 16;
+ constexpr int nwarps = 4;
+ switch (Q->ne[0]) {
+ case 64:
+ launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 80:
+ launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 96:
+ launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 112:
+ launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 128:
+ launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 256:
+ launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+ } else {
+ constexpr int cols_per_block = 32;
+ constexpr int nwarps = 4;
+ switch (Q->ne[0]) {
+ case 64:
+ launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 80:
+ launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 96:
+ launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 112:
+ launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 128:
+ launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ // case 256:
+ // launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ // break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+ }
+ return;
+ }
+
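+    // For a single Q column (batch size 1) use the vector kernel, which only needs FP16 arithmetic
+    // (no tensor cores) and requires the head size to be a multiple of 2*WARP_SIZE.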
+ if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
+ constexpr int parallel_blocks = 4;
+ switch (Q->ne[0]) {
+ case 64:
+ launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+ break;
+ case 128:
+ launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+ break;
+ case 256:
+ launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+ return;
+ }
+
+ if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
+ constexpr int cols_per_block = 8;
+ constexpr int nwarps = 4;
+ switch (Q->ne[0]) {
+ case 64:
+ launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 96:
+ launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 128:
+ launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 256:
+ launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+ return;
+ }
+
+ if (Q->ne[1] <= 32) {
+ constexpr int cols_per_block = 16;
+ constexpr int nwarps = 4;
+ switch (Q->ne[0]) {
+ case 64:
+ launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 80:
+ launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 96:
+ launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 112:
+ launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 128:
+ launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 256:
+ launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+ return;
+ }
+
+ constexpr int cols_per_block = 32;
+ constexpr int nwarps = 4;
+ switch (Q->ne[0]) {
+ case 64:
+ launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 80:
+ launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 96:
+ launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 112:
+ launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 128:
+ launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ case 256:
+ launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+ return;
+}
--- /dev/null
+#include "common.cuh"
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
#include "softmax.cuh"
-template <bool vals_smem, int ncols_template, int block_size_template>
-static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+template <typename T>
+static __device__ __forceinline__ float t2f32(T val) {
+ return (float) val;
+}
+
+template <>
+__device__ float __forceinline__ t2f32<half>(half val) {
+ return __half2float(val);
+}
+
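+// The mask and pos tensors are optional and can be either F16 or F32; their values are converted on the fly via t2f32.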
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = threadIdx.x;
const int64_t ix = (int64_t)rowx*ncols + col;
const int64_t iy = (int64_t)rowy*ncols + col;
- const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
+ const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f);
vals[col] = val;
max_val = max(max_val, val);
}
}
-static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+template<typename T>
+static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
+
const float * src0_d = (const float *)src0->data;
- const float * src1_d = src1 ? (const float *)src1->data : nullptr;
+ const void * src1_d = src1 ? (const void *)src1->data : nullptr;
+
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
- GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+ GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0);
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
// positions tensor
- float * src2_dd = nullptr;
+ void * src2_d = nullptr;
- ggml_tensor * src2 = dst->src[2];
const bool use_src2 = src2 != nullptr;
if (use_src2) {
- src2_dd = (float *)src2->data;
+ src2_d = (void *)src2->data;
}
- soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+
+ if (use_f16) {
+ const half * src1_dd = (const half *)src1_d;
+ const half * src2_dd = (const half *)src2_d;
+
+ soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+ } else {
+ const float * src1_dd = (const float *)src1_d;
+ const float * src2_dd = (const float *)src2_d;
+
+ soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+ }
}
for (int i = node_start; i < node_end; ++i) {
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2);
struct ggml_tensor * dst = gf->nodes[i];
GGML_ASSERT(dst->data != nullptr);
{
float scale;
memcpy(&scale, dst->op_params, sizeof(float));
+
+#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
+ GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
+ GGML_ASSERT(src2 == nullptr);
+
ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
} break;
case GGML_OP_DIAG_MASK_INF:
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
GGML_METAL_KERNEL_TYPE_SILU,
GGML_METAL_KERNEL_TYPE_SILU_4,
- GGML_METAL_KERNEL_TYPE_SOFT_MAX,
- GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
}
/*
- GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+ GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
(int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
(int) kernel->pipeline.threadExecutionWidth); \
*/
return NULL; \
} \
} else { \
- GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \
+ GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
}
// simd_sum and simd_max requires MTLGPUFamilyApple7
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
}
[metal_library release];
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
case GGML_OP_LEAKY_RELU:
+ case GGML_OP_FLASH_ATTN_EXT:
return true;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
} break;
case GGML_OP_SOFT_MAX:
{
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
+
int nth = 32; // SIMD width
id<MTLComputePipelineState> pipeline = nil;
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+
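+ // descriptive note: the F16 variants are used when either the mask (src1) or the positions (src2) are F16;
+ // the *_4 kernels below process the row as float4 and are selected only when ne00 is a multiple of 4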
if (ne00%4 == 0) {
while (nth < ne00/4 && nth < 256) {
nth *= 2;
}
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
+ if (use_f16) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
+ }
} else {
while (nth < ne00 && nth < 1024) {
nth *= 2;
}
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
+ if (use_f16) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
+ }
}
float scale;
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ struct ggml_tensor * src3 = gf->nodes[i]->src[3];
+
+ GGML_ASSERT(ggml_are_same_shape(src1, src2));
+ GGML_ASSERT(src3);
+
+ size_t offs_src3 = 0;
+
+ id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+
+ GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
+ GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
+ "the Flash-Attention Metal kernel requires the mask to be padded to a multiple of 8 and to have at least n_queries rows");
+
+ const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
+ const int64_t ne31 = src3 ? src3->ne[1] : 0;
+ const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
+ const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
+
+ const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
+ const uint64_t nb31 = src3 ? src3->nb[1] : 0;
+ const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
+ const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
+
+ const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+
+ float scale;
+ memcpy(&scale, dst->op_params, sizeof(float));
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ bool use_vec_kernel = false;
+
+ if (ne01 >= 4 || (ne00%128 != 0)) {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+ GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+ GGML_ASSERT(false && "add template specialization for this size");
+ }
+ }
+ } else {
+ use_vec_kernel = true;
+
+ switch (ne00) {
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+ GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+ GGML_ASSERT(false && "add template specialization for this size");
+ }
+ }
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
+ [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
+ [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
+ [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
+ [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
+ [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
+ [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
+ [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
+ [encoder setBytes:&scale length:sizeof( float) atIndex:27];
+
+ if (!use_vec_kernel) {
+ // half8x8 kernel
+ const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
+ const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+ GGML_ASSERT(nqptg <= 32);
+ GGML_ASSERT(nqptg % 8 == 0);
+ GGML_ASSERT(ncpsg % 32 == 0);
+
+ int64_t nsgmax = 2;
+
+ while (true) {
+ const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
+ if (smem > ctx->device.maxThreadgroupMemoryLength) {
+ break;
+ }
+ nsgmax *= 2;
+ }
+ nsgmax /= 2;
+
+ // simdgroups per threadgroup (a.k.a. warps)
+ const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
+
+ const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
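+ // illustrative sizing (numbers are an assumption, not from the original change): with ne00 = 128,
+ // nqptg = 8, ncpsg = 32 and nsg = 4, smem = 8*(128 + 2*4*(32 + 8))*2 bytes = 7168 bytes, which is
+ // comfortably below the ~32 KiB threadgroup memory typically reported by Apple GPUs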
+
+ //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+ GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+
+ [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ } else {
+ // half1x4 kernel
+ const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
+ const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+ GGML_ASSERT(nqptg <= 32);
+ GGML_ASSERT(nqptg % 1 == 0);
+ GGML_ASSERT(ncpsg % 32 == 0);
+
+ // simdgroups per threadgroup (a.k.a. warps)
+ const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+
+ int64_t nsg = 1;
+ while (nsg <= nsgt) {
+ nsg *= 2;
+ }
+ nsg /= 2;
+
+ const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
+
+ //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+ GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+ [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ }
+ } break;
case GGML_OP_DUP:
case GGML_OP_CPY:
case GGML_OP_CONT:
UNUSED(buft);
}
-static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
+static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
+#ifndef GGML_METAL_NDEBUG
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
if (@available(macOS 10.12, iOS 16.0, *)) {
- GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
+ GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)",
+ __func__,
+ size_aligned / 1024.0 / 1024.0,
device.currentAllocatedSize / 1024.0 / 1024.0,
device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
GGML_METAL_LOG_INFO("\n");
}
} else {
- GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
+ GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
+ __func__,
+ size_aligned / 1024.0 / 1024.0,
+ device.currentAllocatedSize / 1024.0 / 1024.0);
}
+#endif
#endif
UNUSED(device);
+ UNUSED(size_aligned);
}
GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
return NULL;
}
- GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
- ggml_backend_metal_log_allocated_size(device);
+ //ggml_backend_metal_log_allocated_size(device, size_aligned);
return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
}
return false;
}
- GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
+ ggml_backend_metal_log_allocated_size(device, size_aligned);
++ctx->n_buffers;
} else {
return false;
}
- GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i);
+ ggml_backend_metal_log_allocated_size(device, size_step_aligned);
+
if (i + size_step < size) {
GGML_METAL_LOG_INFO("\n");
}
}
}
- ggml_backend_metal_log_allocated_size(device);
-
return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
}
dst_row[0] = row_sum;
}
+template<typename T>
kernel void kernel_soft_max(
- device const float * src0,
- device const float * src1,
- device const float * src2,
- device float * dst,
+ device const char * src0,
+ device const char * src1,
+ device const char * src2,
+ device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
- device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
- device const float * ppos = src2 != src0 ? src2 : nullptr;
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
+ device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
+ device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
float slope = 0.0f;
}
}
+template<typename T>
kernel void kernel_soft_max_4(
- device const float * src0,
- device const float * src1,
- device const float * src2,
- device float * dst,
+ device const char * src0,
+ device const char * src1,
+ device const char * src2,
+ device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
- device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
- device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr;
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
+ device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
+ device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
float slope = 0.0f;
float4 lmax4 = -INFINITY;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
}
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
}
}
+typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
+typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
+
+template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
+template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
+template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
+template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
+
kernel void kernel_diag_mask_inf(
device const float * src0,
device float * dst,
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
}
+typedef void (flash_attn_ext_f16_t)(
+ device const char * q,
+ device const char * k,
+ device const char * v,
+ device const char * mask,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne31,
+ constant uint64_t & nb31,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant float & scale,
+ threadgroup half * shared,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]);
+
+// ref: https://arxiv.org/pdf/2307.08691.pdf
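+// overview: each threadgroup handles Q query rows of one head; every simdgroup streams the KV cache
+// in blocks of C columns while keeping running (max, sum) softmax statistics, so the full attention
+// row is never materialized (the tiling scheme of the FlashAttention-2 paper referenced above)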
+template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_f16(
+ device const char * q,
+ device const char * k,
+ device const char * v,
+ device const char * mask,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne31,
+ constant uint64_t & nb31,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant float & scale,
+ threadgroup half * shared [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ const short nsg = ntg.y; // number of simdgroups
+
+ const short iq3 = tgpig[2];
+ const short iq2 = tgpig[1];
+ const short iq1 = tgpig[0]*Q;
+
+ const short D4 = D/4;
+ const short D8 = D/8;
+ const short Q8 = Q/8;
+ const short NW = N_SIMDWIDTH;
+ const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+ const short T = D + 2*nsg*SH; // shared memory size per query in (half)
+ const short TF = T/2; // shared memory size per query in (float)
+ const short T4 = T/4; // shared memory size per query in (half4)
+
+ threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
+ threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
+ threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
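+ // shared memory layout per query row (T = D + 2*nsg*SH halfs): the first D halfs hold the query
+ // itself (sq); the remaining 2*nsg*SH halfs are per-simdgroup scratch for the attention scores and
+ // the diagonal rescale matrix (ss, addressed as float, hence the TF = T/2 row stride)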
+
+ // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+ simdgroup_half8x8 lo[D8];
+
+ // load heads from Q to shared memory
+ for (short j = sgitg; j < Q; j += nsg) {
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+
+ for (short i = tiisg; i < D4; i += NW) {
+ if (iq1 + j < ne01) {
+ sq4[j*T4 + i] = (half4) q4[i];
+ } else {
+ sq4[j*T4 + i] = 0.0h;
+ }
+ }
+ }
+
+ // zero out lo
+ for (short i = 0; i < D8; ++i) {
+ lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
+ }
+
+ // zero out shared memory SH
+ for (short j = 0; j < Q; ++j) {
+ for (short i = tiisg; i < SH; i += NW) {
+ ss[j*TF + i] = 0.0f;
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ {
+ float S[Q] = { [0 ... Q-1] = 0.0h };
+ float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
+
+ // assume K and V are same shape
+ const short ne22 = ne12;
+ const short ne23 = ne13;
+
+ const uint nb21 = nb11;
+ const uint nb22 = nb12;
+ const uint nb23 = nb13;
+
+ // broadcast
+ const short rk2 = ne02/ne12;
+ const short rk3 = ne03/ne13;
+
+ const short rv2 = ne02/ne22;
+ const short rv3 = ne03/ne23;
+
+ // k indices
+ const short ik2 = iq2/rk2;
+ const short ik3 = iq3/rk3;
+
+ // v indices
+ const short iv2 = iq2/rv2;
+ const short iv3 = iq3/rv3;
+
+ // load the queries from shared memory into local memory
+ simdgroup_half8x8 mq[D8];
+
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_load(mq[i], sq + i*8, T);
+ }
+
+ // pointer to the mask
+ device const half * mp = (device const half *) (mask + iq1*nb31);
+
+ // prepare diagonal scale matrix
+ simdgroup_float8x8 mscale(scale);
+
+ // loop over the KV cache
+ // each simdgroup handles blocks of Q rows and C columns
+ for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+ const int ic = ic0 + C*sgitg;
+ if (ic >= ne11) {
+ break;
+ }
+
+ // Q*K^T
+ {
+ for (short cc = 0; cc < C/8; ++cc) {
+ simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
+
+ device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_half8x8 mk;
+ simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+
+ simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+ }
+
+ // mqk = mqk*scale + mask
+ simdgroup_half8x8 mm;
+ simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
+ simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
+
+ simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
+ }
+ }
+
+ // used to detect blocks full of -INF
+ float smax = -INFINITY;
+
+ // online softmax
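+ // streaming softmax update: with the C new scores s of this block,
+ // M_new = max(M_old, max(s)), S_new = S_old*exp(M_old - M_new) + sum(exp(s - M_new));
+ // the factor ms = exp(M_old - M_new) is stored on the diagonal below and later rescales the accumulated O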
+ {
+ float ms[Q];
+
+ for (short j = 0; j < Q; ++j) {
+ const short p = tiisg;
+
+ const float m = M[j];
+ const float s = ss[j*TF + p];
+
+ smax = simd_max(max(smax, s));
+ M[j] = simd_max(max(M[j], s));
+
+ ms[j] = exp(m - M[j]);
+ const float vs = exp(s - M[j]);
+
+ S[j] = S[j]*ms[j] + simd_sum(vs);
+
+ // the P matrix from the paper (Q rows, C columns)
+ ss[j*TF + p] = vs;
+ }
+
+ // create a QxQ diagonal matrix for rescaling the output
+ if (tiisg < Q) {
+ ss[tiisg*TF + C + tiisg] = ms[tiisg];
+ }
+ }
+
+ // skip -INF blocks
+ if (smax == -INFINITY) {
+ continue;
+ }
+
+ // O = diag(ms)*O
+ {
+ simdgroup_float8x8 mm;
+ simdgroup_load(mm, ss + C, TF, 0, false);
+
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_multiply(lo[i], mm, lo[i]);
+ }
+ }
+
+ // O = O + (Q*K^T)*V
+ {
+ for (short cc = 0; cc < C/8; ++cc) {
+ device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_half8x8 mk;
+ simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
+
+ simdgroup_float8x8 mv;
+ simdgroup_load(mv, ss + 8*cc, TF, 0, false);
+
+ simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
+ }
+ }
+ }
+ }
+
+ // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+ for (short j = 0; j < Q; ++j) {
+ if (tiisg == 0) {
+ ss[j*TF + 0] = S[j];
+ ss[j*TF + 1] = M[j];
+ }
+ }
+ }
+
+ // reduce the warps sequentially
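+ // two partial states (S0, M0, O_0) and (S1, M1, O_1) are merged as:
+ // M = max(M0, M1), S = S0*exp(M0 - M) + S1*exp(M1 - M), O = exp(M0 - M)*O_0 + exp(M1 - M)*O_1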
+ for (short sg = 1; sg < nsg; ++sg) {
+ float S = { 0.0h };
+ float M = { -FLT_MAX/2 };
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // each simdgroup stores its output to shared memory, reusing sq
+ if (sgitg == sg) {
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_store(lo[i], sq + i*8, T, 0, false);
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // the first simdgroup accumulates the results from the other simdgroups
+ if (sgitg == 0) {
+ for (short j = 0; j < Q; ++j) {
+ const float S0 = ss[j*TF + 0];
+ const float S1 = ss[j*TF + sg*SH + 0];
+
+ const float M0 = ss[j*TF + 1];
+ const float M1 = ss[j*TF + sg*SH + 1];
+
+ M = max(M0, M1);
+
+ const float ms0 = exp(M0 - M);
+ const float ms1 = exp(M1 - M);
+
+ S = S0*ms0 + S1*ms1;
+
+ if (tiisg == 0) {
+ ss[j*TF + 0] = S;
+ ss[j*TF + 1] = M;
+
+ ss[j*TF + C + j ] = ms0;
+ ss[j*TF + C + j + sg*SH] = ms1;
+ }
+ }
+
+ // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+ {
+ simdgroup_half8x8 t;
+ simdgroup_float8x8 ms0;
+ simdgroup_float8x8 ms1;
+
+ simdgroup_load(ms0, ss + C, TF, 0, false);
+ simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
+
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_load (t, sq + i*8, T, 0, false);
+ simdgroup_multiply(t, ms1, t);
+
+ simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
+ }
+ }
+ }
+ }
+
+ // store result to shared memory (reuse sq)
+ if (sgitg == 0) {
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_store(lo[i], sq + i*8, T, 0, false);
+ }
+ }
+
+ device float4 * dst4 = (device float4 *) dst;
+
+ // final rescale with 1/S and store to global memory
+ if (sgitg == 0) {
+ for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
+ const float S = ss[j*TF + 0];
+
+ for (short i = tiisg; i < D4; i += NW) {
+ dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
+ }
+ }
+ }
+}
+
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
+
+template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec_f16(
+ device const char * q,
+ device const char * k,
+ device const char * v,
+ device const char * mask,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne31,
+ constant uint64_t & nb31,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant float & scale,
+ threadgroup half * shared [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ const short nsg = ntg.y; // number of simdgroups
+
+ const short iq3 = tgpig[2];
+ const short iq2 = tgpig[1];
+ const short iq1 = tgpig[0];
+
+ const short D4 = D/4;
+ const short NW = N_SIMDWIDTH;
+ const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+ const short T = D + 2*nsg*SH; // shared memory size per query in (half)
+
+ //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
+ threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
+ threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+ threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in float4
+ threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
+
+ // per-thread slice of the result for the single query (the O vector from the paper)
+ half4 lo[D4/NW];
+
+ // load heads from Q to shared memory
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
+
+ for (short i = tiisg; i < D4; i += NW) {
+ if (iq1 < ne01) {
+ sq4[i] = (half4) q4[i];
+ } else {
+ sq4[i] = 0.0h;
+ }
+ }
+
+ // zero out lo
+ for (short i = tiisg; i < D4; i += NW) {
+ lo[i/NW] = 0.0h;
+ }
+
+ // zero out shared memory SH
+ for (short i = tiisg; i < SH/4; i += NW) {
+ ss4[i] = 0.0h;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ {
+ float S = { 0.0h };
+ float M = { -FLT_MAX/2 };
+
+ // assume K and V are same shape
+ const short ne22 = ne12;
+ const short ne23 = ne13;
+
+ const uint nb21 = nb11;
+ const uint nb22 = nb12;
+ const uint nb23 = nb13;
+
+ // broadcast
+ const short rk2 = ne02/ne12;
+ const short rk3 = ne03/ne13;
+
+ const short rv2 = ne02/ne22;
+ const short rv3 = ne03/ne23;
+
+ // k indices
+ const short ik2 = iq2 / rk2;
+ const short ik3 = iq3 / rk3;
+
+ // v indices
+ const short iv2 = iq2 / rv2;
+ const short iv3 = iq3 / rv3;
+
+ // load the queries from shared memory into local memory
+ half4 mq[D4];
+
+ for (short ii = 0; ii < D4; ii += NW) {
+ short i = ii + tiisg;
+ mq[i] = sq4[i];
+ }
+
+ // pointer to the mask
+ device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
+
+ // loop over the KV cache
+ // each simdgroup handles blocks of Q rows and C columns
+ for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+ const int ic = ic0 + C*sgitg;
+ if (ic >= ne11) {
+ break;
+ }
+
+ // Q*K^T
+ {
+#pragma unroll
+ for (short cc = 0; cc < C/4; ++cc) {
+ float4 mqk = { 0.0h };
+
+ device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+#pragma unroll
+ for (short ii = 0; ii < D4; ii += NW) {
+ const short i = ii + tiisg;
+
+ half4x4 mk;
+ mk[0] = pk4[i + 0*(nb11/8)];
+ mk[1] = pk4[i + 1*(nb11/8)];
+ mk[2] = pk4[i + 2*(nb11/8)];
+ mk[3] = pk4[i + 3*(nb11/8)];
+
+ mqk += (float4) (mq[i] * mk);
+ }
+
+ // reduce the results from the threads in the simdgroup
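+ // (butterfly reduction over the 32 lanes with offsets 16, 8, 4, 2, 1; lane 0 ends up with the full dot product)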
+ mqk += simd_shuffle_down(mqk, 16);
+ mqk += simd_shuffle_down(mqk, 8);
+ mqk += simd_shuffle_down(mqk, 4);
+ mqk += simd_shuffle_down(mqk, 2);
+ mqk += simd_shuffle_down(mqk, 1);
+
+ // mqk = mqk*scale + mask
+ if (tiisg == 0) {
+ float4 mm = (float4) mp4[ic/4 + cc];
+ mqk = mqk*scale + mm;
+
+ ss4[cc] = mqk;
+ }
+ }
+ }
+
+ // online softmax
+ {
+ const short p = tiisg;
+
+ const float m = M;
+ const float s = ss[p];
+
+ M = simd_max(max(M, s));
+
+ const float ms = exp(m - M);
+ const float vs = exp(s - M);
+
+ S = S*ms + simd_sum(vs);
+
+ // the P matrix from the paper (Q rows, C columns)
+ ss[p] = vs;
+
+ // O = diag(ms)*O
+#pragma unroll
+ for (short ii = 0; ii < D4; ii += NW) {
+ const short i = ii + tiisg;
+ lo[i/NW] *= ms;
+ }
+ }
+
+ // O = O + (Q*K^T)*V
+ {
+#pragma unroll
+ for (short cc = 0; cc < C/4; ++cc) {
+ device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+#pragma unroll
+ for (short ii = 0; ii < D4; ii += NW) {
+ const short i = ii + tiisg;
+
+ lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
+ lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
+ lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
+ lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
+ }
+ }
+ }
+
+ }
+
+ // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+ if (tiisg == 0) {
+ ss[0] = S;
+ ss[1] = M;
+ }
+ }
+
+ // store results to shared memory
+ for (short ii = 0; ii < D4; ii += NW) {
+ short i = ii + tiisg;
+ sr4[i] = lo[ii/NW];
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // parallel reduce
+ for (short r = nsg/2; r > 0; r >>= 1) {
+ if (sgitg < r) {
+ const float S0 = ss[ 0];
+ const float S1 = ss[r*SH + 0];
+
+ const float M0 = ss[ 1];
+ const float M1 = ss[r*SH + 1];
+
+ const float M = max(M0, M1);
+
+ const float ms0 = exp(M0 - M);
+ const float ms1 = exp(M1 - M);
+
+ const float S = S0*ms0 + S1*ms1;
+
+ if (tiisg == 0) {
+ ss[0] = S;
+ ss[1] = M;
+ }
+
+ // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+ for (short ii = 0; ii < D4; ii += NW) {
+ short i = ii + tiisg;
+ sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+ device float4 * dst4 = (device float4 *) dst;
+
+ // final rescale with 1/S and store to global memory
+ if (sgitg == 0) {
+ const float S = ss[0];
+
+ for (short ii = 0; ii < D4; ii += NW) {
+ short i = ii + tiisg;
+ dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
+ }
+ }
+}
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+
kernel void kernel_cpy_f16_f16(
device const half * src0,
device half * dst,
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ const ggml_tensor * src2 = dst->src[2];
+
+#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+ GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0);
float * src2_dd = nullptr;
sycl_pool_alloc<float> src2_f;
- ggml_tensor * src2 = dst->src[2];
const bool use_src2 = src2 != nullptr;
if (use_src2) {
}
return nullptr;
case GGML_OP_SOFT_MAX:
+#pragma message("TODO: add ggml_vk_soft_max() F16 src1 and src2 support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
+
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_soft_max_f32;
}
#define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
#define GGML_F16_VEC_SET1 GGML_F16x8_SET1
#define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), r[i])
#define GGML_F16_VEC_FMA GGML_F16x8_FMA
#define GGML_F16_VEC_ADD GGML_F16x8_ADD
#define GGML_F16_VEC_MUL GGML_F16x8_MUL
#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
// unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
// so F16C guard isn't required
-#define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x)))
+#define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
#define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))
#define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
#if defined(__F16C__)
// the _mm256_cvt intrinsics require F16C
-#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
+#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
#else
static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
#endif
}
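+// ggml_vec_mad_f16: y[i] += x[i]*v for i in [0, n), with x and y in fp16 and v in fp32
+// (the SIMD path handles GGML_F16_STEP elements per iteration, the scalar loop below handles the tail)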
+inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
+#if defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F16_STEP - 1));
+
+ GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+ GGML_F16_VEC ax[GGML_F16_ARR];
+ GGML_F16_VEC ay[GGML_F16_ARR];
+
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
+ for (int j = 0; j < GGML_F16_ARR; j++) {
+ ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+ ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+
+ GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+ }
+ }
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+ }
+#else
+ // scalar
+ for (int i = 0; i < n; ++i) {
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+ }
+#endif
+}
+
// xs and vs are byte strides of x and v
inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
#endif
}
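+// ggml_vec_scale_f16: y[i] *= v for i in [0, n), with y in fp16 and v in fp32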
+inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
+#if defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F16_STEP - 1));
+
+ GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+ GGML_F16_VEC ay[GGML_F16_ARR];
+
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
+ for (int j = 0; j < GGML_F16_ARR; j++) {
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+ ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
+
+ GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+ }
+ }
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+ }
+#else
+ // scalar
+ for (int i = 0; i < n; ++i) {
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+ }
+#endif
+}
+
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
"LEAKY_RELU",
"FLASH_ATTN",
+ "FLASH_ATTN_EXT",
"FLASH_FF",
"FLASH_ATTN_BACK",
"SSM_CONV",
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
"leaky_relu(x)",
"flash_attn(x)",
+ "flash_attn_ext(x)",
"flash_ff(x)",
"flash_attn_back(x)",
"ssm_conv(x)",
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
void ggml_mul_mat_set_prec(
struct ggml_tensor * a,
enum ggml_prec prec) {
+ GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
+
const int32_t prec_i32 = (int32_t) prec;
ggml_set_op_params_i32(a, 0, prec_i32);
GGML_ASSERT(ggml_is_contiguous(a));
if (mask) {
+ GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(ggml_is_matrix(mask));
- GGML_ASSERT(ggml_can_repeat_rows(mask, a));
+ GGML_ASSERT(mask->ne[0] == a->ne[0]);
+ GGML_ASSERT(mask->ne[1] >= a->ne[1]);
}
if (pos) {
GGML_ASSERT(ggml_is_vector(pos));
- GGML_ASSERT(pos->type == GGML_TYPE_F32);
+ GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
GGML_ASSERT(pos->ne[0] == a->ne[0]);
}
+ if (pos && mask) {
+ GGML_ASSERT(pos->type == mask->type);
+ }
+
if (max_bias > 0.0f) {
GGML_ASSERT(pos);
}
return result;
}
+// ggml_flash_attn_ext
+
+struct ggml_tensor * ggml_flash_attn_ext(
+ struct ggml_context * ctx,
+ struct ggml_tensor * q,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ struct ggml_tensor * mask,
+ float scale) {
+ GGML_ASSERT(ggml_can_mul_mat(k, q));
+ // TODO: check if vT can be multiplied by (k*qT)
+ if (mask) {
+ GGML_ASSERT(ggml_is_contiguous(mask));
+ GGML_ASSERT(mask->ne[2] == 1);
+ GGML_ASSERT(mask->ne[3] == 1);
+ GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
+ "the Flash-Attention kernel requires the mask to be padded to a multiple of GGML_KQ_MASK_PAD and to have at least n_queries rows");
+ //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
+ }
+
+ bool is_node = false;
+
+ if (q->grad || k->grad || v->grad) {
+ is_node = true;
+ }
+
+ // permute(0, 2, 1, 3)
+ int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+ float params[] = { scale };
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_FLASH_ATTN_EXT;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = q;
+ result->src[1] = k;
+ result->src[2] = v;
+ result->src[3] = mask;
+
+ return result;
+}
+
+void ggml_flash_attn_ext_set_prec(
+ struct ggml_tensor * a,
+ enum ggml_prec prec) {
+ GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+ const int32_t prec_i32 = (int32_t) prec;
+
+ ggml_set_op_params_i32(a, 1, prec_i32); // the scale is stored at index 0, so the precision goes at index 1
+}
+
// ggml_flash_ff
struct ggml_tensor * ggml_flash_ff(
GGML_TENSOR_UNARY_OP_LOCALS
- const int64_t ne11 = src1 ? src1->ne[1] : 1;
+ //const int64_t ne11 = src1 ? src1->ne[1] : 1;
// TODO: is this supposed to be ceil instead of floor?
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
// when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
- float * pos = src2 ? (float *) src2->data : src0->data;
+ ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
+ float * pos_f32 = src2 ? (float *) src2->data : src0->data;
+
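+ // mask (src1) and pos (src2) are required to share a type (see ggml_soft_max_ext),
+ // so a single flag selects between the F16 and F32 paths below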
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
for (int i1 = ir0; i1 < ir1; i1++) {
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
// broadcast the mask across rows
- float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
+ ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+ float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
ggml_vec_cpy_f32 (nc, wp, sp);
ggml_vec_scale_f32(nc, wp, scale);
- if (mp) {
- ggml_vec_acc_f32(nc, wp, mp);
+ if (mp_f32) {
+ if (use_f16) {
+ for (int i = 0; i < nc; ++i) {
+ wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
+ }
+ } else {
+ for (int i = 0; i < nc; ++i) {
+ wp[i] += mp_f32[i];
+ }
+ }
}
// ALiBi bias
const uint32_t h = (i1/ne01)%ne02; // head
const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
- for (int i = 0; i < nc; i++) {
- wp[i] = wp[i] + slope*pos[i];
+ if (use_f16) {
+ for (int i = 0; i < nc; ++i) {
+ wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
+ }
+ } else {
+ for (int i = 0; i < nc; ++i) {
+ wp[i] += slope*pos_f32[i];
+ }
}
}
}
}
+// ggml_compute_forward_flash_attn_ext
+
+static void ggml_compute_forward_flash_attn_ext_f16(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * q,
+ const struct ggml_tensor * k,
+ const struct ggml_tensor * v,
+ const struct ggml_tensor * mask,
+ struct ggml_tensor * dst) {
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t D = neq0;
+ const int64_t N = neq1;
+
+ GGML_ASSERT(ne0 == D);
+ GGML_ASSERT(ne2 == N);
+
+ GGML_ASSERT(nbq0 == sizeof(float));
+ GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
+
+ GGML_ASSERT(neq0 == D);
+ GGML_ASSERT(nek0 == D);
+ GGML_ASSERT(nev0 == D);
+
+ GGML_ASSERT(neq1 == N);
+ GGML_ASSERT(nev0 == D);
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb0 <= nb1);
+ GGML_ASSERT(nb1 <= nb2);
+ GGML_ASSERT(nb2 <= nb3);
+
+ // broadcast factors
+ const int64_t rk2 = neq2/nek2;
+ const int64_t rk3 = neq3/nek3;
+
+ const int64_t rv2 = neq2/nev2;
+ const int64_t rv3 = neq3/nev3;
+
+ if (params->type == GGML_TASK_TYPE_INIT) {
+ return;
+ }
+
+ if (params->type == GGML_TASK_TYPE_FINALIZE) {
+ return;
+ }
+
+ // parallelize by q rows using ggml_vec_dot_f16
+
+ // total rows in q
+ const int nr = neq1*neq2*neq3;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ float scale = 1.0f;
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+
+ // loop over n_batch and n_head
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // q indices
+ const int iq3 = ir/(neq2*neq1);
+ const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+ const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+ float S = 0.0f;
+ float M = -INFINITY;
+
+ float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
+ ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
+ ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
+
+ memset(V16, 0, D*sizeof(ggml_fp16_t));
+
+ const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+
+ // k indices
+ const int ik3 = iq3 / rk3;
+ const int ik2 = iq2 / rk2;
+
+ // v indices
+ const int iv3 = iq3 / rv3;
+ const int iv2 = iq2 / rv2;
+
+ // online softmax / attention
+ // loop over n_kv and n_head_kv
+ // ref: https://arxiv.org/pdf/2112.05682.pdf
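+ // running state for the streaming softmax: M is the maximum logit seen so
+ // far, S is the sum of exp(s - M) over the processed logits, and V16 holds
+ // the exp-weighted sum of value rows; when a new maximum arrives, both S and
+ // V16 are rescaled by expf(Mold - M) so earlier contributions stay consistent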
+ for (int64_t ic = 0; ic < nek1; ++ic) {
+ const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+ if (mv == -INFINITY) {
+ continue;
+ }
+
+ float s;
+
+ // convert Q to F16 into Q16 (which reuses the V32 buffer)
+ {
+ const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+
+ for (int64_t d = 0; d < D; ++d) {
+ Q16[d] = GGML_FP32_TO_FP16(pq[d]);
+ }
+ }
+
+ ggml_vec_dot_f16(D,
+ &s, 0,
+ (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+ Q16, 0, 1);
+
+ s = s*scale + mv;
+
+ const float Mold = M;
+
+ float ms = 1.0f;
+ float vs = 1.0f;
+
+ if (s > M) {
+ M = s;
+ ms = expf(Mold - M);
+
+ // V = V*expf(Mold - M)
+ ggml_vec_scale_f16(D, V16, ms);
+ } else {
+ vs = expf(s - M);
+ }
+
+ const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+
+ // V += v*expf(s - M)
+ ggml_vec_mad_f16(D, V16, v16, vs);
+
+ S = S*ms + vs;
+ }
+
+ // V /= S
+ for (int64_t d = 0; d < D; ++d) {
+ V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
+ }
+
+ // dst indices
+ const int i1 = iq1;
+ const int i2 = iq2;
+ const int i3 = iq3;
+
+ // original
+ //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+
+ // permute(0, 2, 1, 3)
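+ // here ne1 == n_head and ne2 == n_batch, so row (i1*ne1 + i2) of batch i3
+ // lands at dst[:, iq2, iq1, iq3]: for every query the head rows are stored
+ // contiguously, matching the permuted [n_embd, n_head, n_batch, 1] result layout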
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
+ }
+}
+
+static void ggml_compute_forward_flash_attn_ext(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * q,
+ const struct ggml_tensor * k,
+ const struct ggml_tensor * v,
+ const struct ggml_tensor * mask,
+ struct ggml_tensor * dst) {
+ switch (dst->op_params[1]) {
+ case GGML_PREC_DEFAULT:
+ case GGML_PREC_F32:
+ {
+ // for now both precision settings run the same kernel (F32 softmax state, F16 V accumulation)
+ ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
// ggml_compute_forward_flash_ff
static void ggml_compute_forward_flash_ff_f16(
const bool masked = t != 0;
ggml_compute_forward_flash_attn(params, masked, tensor);
} break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ {
+ ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+ } break;
case GGML_OP_FLASH_FF:
{
ggml_compute_forward_flash_ff(params, tensor);
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_FLASH_ATTN:
+ case GGML_OP_FLASH_ATTN_EXT:
{
struct ggml_tensor * flash_grad = NULL;
if (src0->grad || src1->grad || tensor->src[2]->grad) {
n_tasks = n_threads;
} break;
case GGML_OP_FLASH_ATTN:
+ case GGML_OP_FLASH_ATTN_EXT:
{
n_tasks = n_threads;
} break;
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
}
} break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ {
+ const int64_t ne00 = node->src[0]->ne[0]; // D
+
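+ // per-thread scratch used by ggml_compute_forward_flash_attn_ext_f16:
+ // D floats for the F32 result row (V32, also aliased as Q16) plus D halves
+ // for the F16 V accumulator (V16)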
+ cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
+ } break;
case GGML_OP_FLASH_FF:
{
if (node->src[1]->type == GGML_TYPE_F32) {
GGML_OP_LEAKY_RELU,
GGML_OP_FLASH_ATTN,
+ GGML_OP_FLASH_ATTN_EXT,
GGML_OP_FLASH_FF,
GGML_OP_FLASH_ATTN_BACK,
GGML_OP_SSM_CONV,
struct ggml_tensor * v,
bool masked);
+#define GGML_KQ_MASK_PAD 32
+
+ // q: [n_embd, n_batch, n_head, 1]
+ // k: [n_embd, n_kv, n_head_kv, 1]
+ // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
+ // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+ // res: [n_embd, n_head, n_batch, 1] !! permuted !!
+ GGML_API struct ggml_tensor * ggml_flash_attn_ext(
+ struct ggml_context * ctx,
+ struct ggml_tensor * q,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ struct ggml_tensor * mask,
+ float scale);
+
+ GGML_API void ggml_flash_attn_ext_set_prec(
+ struct ggml_tensor * a,
+ enum ggml_prec prec);
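+
+ // minimal usage sketch (illustrative only; q/k/v/kq_mask and n_embd_head are
+ // placeholder names, and the 1/sqrt(head_dim) scale is the usual choice
+ // rather than something mandated by the API):
+ //
+ // struct ggml_tensor * kqv = ggml_flash_attn_ext(ctx, q, k, v, kq_mask,
+ // 1.0f/sqrtf((float) n_embd_head));
+ // ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32); // optionally force F32 precision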
+
GGML_API struct ggml_tensor * ggml_flash_attn_back(
struct ggml_context * ctx,
struct ggml_tensor * q,