GGML_OP_SOFT_MAX_BACK,
GGML_OP_ROPE,
GGML_OP_ROPE_BACK,
- GGML_OP_ALIBI,
GGML_OP_CLAMP,
GGML_OP_CONV_TRANSPOSE_1D,
GGML_OP_IM2COL,
struct ggml_context * ctx,
struct ggml_tensor * a);
- // fused soft_max(a*scale + mask + pos[i]*(ALiBi slope))
+ // fused soft_max(a*scale + mask*(ALiBi slope))
// mask is optional
- // pos is required when max_bias > 0.0f
// max_bias = 0.0f for no ALiBi
GGML_API struct ggml_tensor * ggml_soft_max_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
- struct ggml_tensor * pos,
float scale,
float max_bias);
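
For reference, a minimal usage sketch of the updated call (the wrapper `soft_max_sketch` and the head size `n_embd_head` are illustrative, not part of this patch). Passing `max_bias = 0.0f` recovers the plain masked soft_max, while a positive `max_bias` enables the fused ALiBi term and then `mask` must be non-NULL:

```c
#include <math.h>
#include "ggml.h"

// Sketch only: kq is an attention-score tensor (e.g. [n_kv, n_tokens, n_head, 1]),
// kq_mask an optional F32 mask broadcast across heads.
static struct ggml_tensor * soft_max_sketch(
        struct ggml_context * ctx,
        struct ggml_tensor  * kq,
        struct ggml_tensor  * kq_mask,   // required when max_bias > 0.0f
        int                   n_embd_head,
        float                 max_bias) {
    const float kq_scale = 1.0f/sqrtf((float) n_embd_head);
    return ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, max_bias);
}
```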
float xpos_base,
bool xpos_down);
- // alibi position embedding
- // in-place, returns view(a)
- GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_alibi(
- struct ggml_context * ctx,
- struct ggml_tensor * a,
- int n_past,
- int n_head,
- float bias_max),
- "use ggml_soft_max_ext instead (will be removed in Mar 2024)");
-
// clamp
// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_clamp(
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * mask,
- float scale);
+ float scale,
+ float max_bias);
GGML_API void ggml_flash_attn_ext_set_prec(
struct ggml_tensor * a,
#include "ggml-cuda/common.cuh"
#include "ggml-cuda/acc.cuh"
-#include "ggml-cuda/alibi.cuh"
#include "ggml-cuda/arange.cuh"
#include "ggml-cuda/argsort.cuh"
#include "ggml-cuda/binbcast.cuh"
case GGML_OP_ROPE:
ggml_cuda_op_rope(ctx, dst);
break;
- case GGML_OP_ALIBI:
- ggml_cuda_op_alibi(ctx, dst);
- break;
case GGML_OP_IM2COL:
ggml_cuda_op_im2col(ctx, dst);
break;
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_ROPE:
- case GGML_OP_ALIBI:
case GGML_OP_IM2COL:
case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS:
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
const int stride_KV = nb11 / sizeof(half);
const int stride_KV2 = nb11 / sizeof(half2);
+ half slopeh = __float2half(1.0f);
+
+ // ALiBi
+ if (max_bias > 0.0f) {
+ const int h = blockIdx.y;
+
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ slopeh = __float2half(powf(base, exph));
+ }
+
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
constexpr int nwarps = D / WARP_SIZE;
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
for (int j = 0; j < ncols; ++j) {
sum2[j] = warp_reduce_sum(sum2[j]);
half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
- sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+ sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
if (ncols == 1) {
kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);
+ half slopeh = __float2half(1.0f);
+ half2 slope2 = make_half2(1.0f, 1.0f);
+
+ // ALiBi
+ if (max_bias > 0.0f) {
+ const int h = blockIdx.y;
+
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ slopeh = __float2half(powf(base, exph));
+ slope2 = make_half2(slopeh, slopeh);
+ }
+
frag_b Q_b[D/16][ncols/frag_n];
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
- KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+ KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
}
KQ_max_new = warp_reduce_max(KQ_max_new);
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
- KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+ KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
}
KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
const int shmem = 0;
- float scale;
- memcpy(&scale, KQV->op_params, sizeof(float));
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+ const uint32_t n_head = Q->ne[2];
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
- scale,
+ scale, max_bias, m0, m1, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
const int shmem = 0;
- float scale;
- memcpy(&scale, KQV->op_params, sizeof(float));
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+ const uint32_t n_head = Q->ne[2];
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
- scale,
+ scale, max_bias, m0, m1, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
- const int32_t precision = KQV->op_params[1];
+ const int32_t precision = KQV->op_params[2];
if (!fp16_mma_available(cc)) {
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
}
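
The launchers above pass `m0`, `m1` and `n_head_log2` so that each kernel can derive its per-head ALiBi slope locally. A host-side C sketch of the slope they encode (a hypothetical helper mirroring the kernel code, not part of the patch): powers of `m0` for the first `2^floor(log2(n_head))` heads, odd powers of `m1` for the remainder.

```c
#include <math.h>
#include <stdint.h>

// Per-head ALiBi slope as computed by the kernels above.
// For max_bias == 0.0f the slope is 1.0f and the mask is applied unscaled.
static float alibi_slope(float max_bias, uint32_t n_head, uint32_t h) {
    if (max_bias <= 0.0f) {
        return 1.0f;
    }

    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    return h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
}
```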
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
-static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = threadIdx.x;
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
- float slope = 0.0f;
+ float slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
const int h = rowx/nrows_y; // head index
const float base = h < n_head_log2 ? m0 : m1;
- const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
- slope = powf(base, exp);
+ slope = powf(base, exph);
}
extern __shared__ float data_soft_max_f32[];
const int64_t ix = (int64_t)rowx*ncols + col;
const int64_t iy = (int64_t)rowy*ncols + col;
- const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f);
+ const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
vals[col] = val;
max_val = max(max_val, val);
}
template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
- const uint32_t n_head_kv = nrows_x/nrows_y;
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+ const uint32_t n_head = nrows_x/nrows_y;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
switch (ncols_x) {
case 32:
- soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 64:
- soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 128:
- soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 256:
- soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 512:
- soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 1024:
- soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 2048:
- soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 4096:
- soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
default:
- soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
}
} else {
const size_t shmem_low = WARP_SIZE*sizeof(float);
- soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
}
}
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
- const ggml_tensor * src2 = dst->src[2];
const float * src0_d = (const float *)src0->data;
const void * src1_d = src1 ? (const void *)src1->data : nullptr;
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
- GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0);
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
- // positions tensor
- void * src2_d = nullptr;
-
- const bool use_src2 = src2 != nullptr;
-
- if (use_src2) {
- src2_d = (void *)src2->data;
- }
-
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
if (use_f16) {
const half * src1_dd = (const half *)src1_d;
- const half * src2_dd = (const half *)src2_d;
- soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+ soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
} else {
const float * src1_dd = (const float *)src1_d;
- const float * src2_dd = (const float *)src2_d;
- soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+ soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
}
}
case GGML_OP_SOFT_MAX:
{
float scale;
- memcpy(&scale, dst->op_params, sizeof(float));
+ float max_bias;
-#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
+ memcpy(&scale, (float *)dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
+
+#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
- GGML_ASSERT(src2 == nullptr);
+
+#pragma message("TODO: add ALiBi support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
+ GGML_ASSERT(max_bias == 0.0f);
ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
} break;
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_ROPE_F32,
GGML_METAL_KERNEL_TYPE_ROPE_F16,
- GGML_METAL_KERNEL_TYPE_ALIBI_F32,
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
case GGML_OP_GROUP_NORM:
return ctx->support_simdgroup_reduction;
case GGML_OP_NORM:
- case GGML_OP_ALIBI:
case GGML_OP_ROPE:
case GGML_OP_IM2COL:
return true;
case GGML_OP_SOFT_MAX:
{
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
- GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
int nth = 32; // SIMD width
id<MTLComputePipelineState> pipeline = nil;
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
if (ne00%4 == 0) {
while (nth < ne00/4 && nth < 256) {
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src0->ne[1];
- const uint32_t n_head_kv = nrows_x/nrows_y;
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+ const uint32_t n_head = nrows_x/nrows_y;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
}
- if (id_src2) {
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
- } else {
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
- }
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
- [encoder setBytes:&scale length:sizeof(scale) atIndex:7];
- [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+ [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
- case GGML_OP_ALIBI:
- {
- GGML_ASSERT((src0t == GGML_TYPE_F32));
-
- const int nth = MIN(1024, ne00);
-
- //const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_head = ((int32_t *) dst->op_params)[1];
-
- float max_bias;
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
- [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
- [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
case GGML_OP_ROPE:
{
GGML_ASSERT(ne10 == ne02);
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
- const int64_t ne31 = src3 ? src3->ne[1] : 0;
+ //const int64_t ne31 = src3 ? src3->ne[1] : 0;
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
float scale;
- memcpy(&scale, dst->op_params, sizeof(float));
+ float max_bias;
+
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+
+ const uint32_t n_head = src0->ne[2];
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
id<MTLComputePipelineState> pipeline = nil;
}
[encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
- [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
- [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
- [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
- [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
- [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
- [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
- [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
- [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
- [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
- [encoder setBytes:&scale length:sizeof( float) atIndex:27];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
+ [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
+ [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
+ [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
+ [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
+ [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
+ [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:21];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:22];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:23];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:24];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:25];
+ [encoder setBytes:&scale length:sizeof( float) atIndex:26];
+ [encoder setBytes:&max_bias length:sizeof( float) atIndex:27];
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:28];
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:29];
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:30];
if (!use_vec_kernel) {
// half8x8 kernel
kernel void kernel_soft_max(
device const char * src0,
device const char * src1,
- device const char * src2,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
- device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
- float slope = 0.0f;
+ float slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
float lmax = -INFINITY;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
}
// find the max value in the block
// parallel sum
float lsum = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
lsum += exp_psrc0;
pdst[i00] = exp_psrc0;
}
kernel void kernel_soft_max_4(
device const char * src0,
device const char * src1,
- device const char * src2,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
- device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
- float slope = 0.0f;
+ float slope = 1.0f;
if (max_bias > 0.0f) {
const int64_t h = i02;
float4 lmax4 = -INFINITY;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
}
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
}
}
-kernel void kernel_alibi_f32(
- device const float * src0,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- constant float & m0,
- constant float & m1,
- constant int & n_heads_log2_floor,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
-
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
-
- const int64_t k = i3*ne3 + i2;
-
- float m_k;
- if (k < n_heads_log2_floor) {
- m_k = pow(m0, k + 1);
- } else {
- m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
- }
-
- device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
- device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
- const float src_v = *(device float *)(src_row + i00*nb00);
- device float * dst_v = (device float *)(dst_row + i00*nb0);
- *dst_v = i00 * m_k + src_v;
- }
-}
-
static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
- constant int64_t & ne31,
constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant float & scale,
+ constant float & max_bias,
+ constant float & m0,
+ constant float & m1,
+ constant uint32_t & n_head_log2,
threadgroup half * shared,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
- constant int64_t & ne31,
constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant float & scale,
+ constant float & max_bias,
+ constant float & m0,
+ constant float & m1,
+ constant uint32_t & n_head_log2,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
// prepare diagonal scale matrix
simdgroup_float8x8 mscale(scale);
+ // prepare diagonal slope matrix
+ simdgroup_float8x8 mslope(1.0f);
+
+ // ALiBi
+ if (max_bias > 0.0f) {
+ const short h = iq2;
+
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ mslope = simdgroup_float8x8(pow(base, exph));
+ }
+
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
}
- // mqk = mqk*scale + mask
+ // mqk = mqk*scale + mask*slope
simdgroup_half8x8 mm;
simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
+ simdgroup_multiply(mm, mslope, mm);
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
- constant int64_t & ne31,
constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant float & scale,
+ constant float & max_bias,
+ constant float & m0,
+ constant float & m1,
+ constant uint32_t & n_head_log2,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
+ float slope = 1.0f;
+
+ // ALiBi
+ if (max_bias > 0.0f) {
+ const short h = iq2;
+
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ slope = pow(base, exph);
+ }
+
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
mqk += simd_shuffle_down(mqk, 2);
mqk += simd_shuffle_down(mqk, 1);
- // mqk = mqk*scale + mask
+ // mqk = mqk*scale + mask*slope
if (tiisg == 0) {
float4 mm = (float4) mp4[ic/4 + cc];
- mqk = mqk*scale + mm;
+ mqk = mqk*scale + mm*slope;
ss4[cc] = mqk;
}
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
- dst_data[i00] = src[0];
+ // TODO: is there a better way to handle -INFINITY?
+ dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
}
}
#define SYCL_SCALE_BLOCK_SIZE 256
#define SYCL_CLAMP_BLOCK_SIZE 256
#define SYCL_ROPE_BLOCK_SIZE 256
-#define SYCL_ALIBI_BLOCK_SIZE 32
#define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
#define SYCL_QUANTIZE_BLOCK_SIZE 256
#define SYCL_DEQUANTIZE_BLOCK_SIZE 256
dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
}
-static void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
- const int n_heads_log2_floor, const float m0, const float m1,
- const sycl::nd_item<3> &item_ct1) {
- const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
- item_ct1.get_local_id(2);
-
- if (col >= ncols) {
- return;
- }
-
- const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
- item_ct1.get_local_id(1);
- const int i = row*ncols + col;
-
- const int k = row/k_rows;
-
- float m_k;
- if (k < n_heads_log2_floor) {
- m_k = dpct::pow(m0, k + 1);
- } else {
- m_k = dpct::pow(m1, 2 * (k - n_heads_log2_floor) + 1);
- }
-
- dst[i] = col * m_k + x[i];
-}
-
static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
const sycl::nd_item<3> &item_ct1) {
const int row = item_ct1.get_group(1);
template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
+static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
const int nrows_y, const float scale, const float max_bias, const float m0,
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
- float slope = 0.0f;
+ float slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
- const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
+ const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
vals[col] = val;
max_val = sycl::max(max_val, val);
});
}
-static void alibi_f32_sycl(const float *x, float *dst, const int ncols,
- const int nrows, const int k_rows,
- const int n_heads_log2_floor, const float m0,
- const float m1, dpct::queue_ptr stream) {
- const sycl::range<3> block_dims(1, 1, SYCL_ALIBI_BLOCK_SIZE);
- const int num_blocks_x = (ncols + SYCL_ALIBI_BLOCK_SIZE - 1) / (SYCL_ALIBI_BLOCK_SIZE);
- const sycl::range<3> block_nums(1, nrows, num_blocks_x);
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) {
- alibi_f32(x, dst, ncols, k_rows,
- n_heads_log2_floor, m0, m1, item_ct1);
- });
-}
-
static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
const int nrows, dpct::queue_ptr stream) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
}
template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
+static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
const int nrows_y, const float scale, const float max_bias, const float m0,
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
const size_t n_local_scratch, dpct::queue_ptr stream) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
- soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, pos, dst, ncols_par,
+ soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
nrows_y, scale, max_bias, m0,
m1, n_head_log2, item_ct1,
local_buf_acc.get_pointer());
});
}
-static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos,
+static void soft_max_f32_sycl(const float * x, const float * mask,
float * dst, const int ncols_x, const int nrows_x,
const int nrows_y, const float scale, const float max_bias,
dpct::queue_ptr stream) {
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
if (n_local_scratch*sizeof(float) < local_mem_size) {
if (ncols_x > max_block_size) {
- soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
return;
}
switch (ncols_x) {
case 32:
- soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 64:
- soft_max_f32_submitter<true, 64, 64>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 128:
- soft_max_f32_submitter<true, 128, 128>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 256:
- soft_max_f32_submitter<true, 256, 256>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 512:
- soft_max_f32_submitter<true, 512, 512>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 1024:
- soft_max_f32_submitter<true, 1024, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 2048:
- soft_max_f32_submitter<true, 2048, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 4096:
- soft_max_f32_submitter<true, 4096, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
default:
- soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
}
} else {
- soft_max_f32_submitter<false, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, WARP_SIZE, stream);
}
(void) src1_dd;
}
-inline void ggml_sycl_op_alibi(const ggml_tensor *src0, const ggml_tensor *src1,
- ggml_tensor *dst, const float *src0_dd,
- const float *src1_dd, float *dst_dd,
- const dpct::queue_ptr &main_stream) {
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
- GGML_TENSOR_LOCALS_3(int64_t, ne0, src0, ne);
- const int64_t nrows = ggml_nrows(src0);
-
- //const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_head = ((int32_t *) dst->op_params)[1];
- float max_bias;
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
- //GGML_ASSERT(ne01 + n_past == ne00);
- GGML_ASSERT(n_head == ne02);
-
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
-
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
- alibi_f32_sycl(src0_dd, dst_dd, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, main_stream);
-
- (void) src1;
- (void) src1_dd;
-}
-
static void ggml_sycl_op_pool2d(const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
- const ggml_tensor * src2 = dst->src[2];
-
-#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
+#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
- GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0);
memcpy(&scale, dst->op_params + 0, sizeof(float));
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
- // positions tensor
- float * src2_dd = nullptr;
- sycl_pool_alloc<float> src2_f;
-
- const bool use_src2 = src2 != nullptr;
-
- if (use_src2) {
- const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
-
- if (src2_on_device) {
- ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
- src2_dd = (float *) src2_extra->data_device[g_main_device];
- } else {
- src2_dd = src2_f.alloc(ggml_nelements(src2));
- SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
- }
- }
-
- soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00,
+ soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
nrows_x, nrows_y, scale, max_bias, main_stream);
}
ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_rope);
}
-static void ggml_sycl_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_alibi);
-}
-
static void ggml_sycl_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_pool2d);
}
case GGML_OP_ROPE:
func = ggml_sycl_rope;
break;
- case GGML_OP_ALIBI:
- func = ggml_sycl_alibi;
- break;
case GGML_OP_IM2COL:
func = ggml_sycl_im2col;
break;
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_ROPE:
- case GGML_OP_ALIBI:
case GGML_OP_IM2COL:
case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS:
return nullptr;
case GGML_OP_SOFT_MAX:
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
- GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16);
- if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
+ if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_soft_max_f32;
}
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+#pragma message("TODO: src2 is no longer used in soft_max - should be removed and ALiBi calculation should be updated")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
+
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
ncols,
src1 != nullptr ? nrows_y : (uint32_t)0,
"SOFT_MAX_BACK",
"ROPE",
"ROPE_BACK",
- "ALIBI",
"CLAMP",
"CONV_TRANSPOSE_1D",
"IM2COL",
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
"soft_max_back(x)",
"rope(x)",
"rope_back(x)",
- "alibi(x)",
"clamp(x)",
"conv_transpose_1d(x)",
"im2col(x)",
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
- struct ggml_tensor * pos,
float scale,
float max_bias,
bool inplace) {
GGML_ASSERT(mask->ne[1] >= a->ne[1]);
}
- if (pos) {
- GGML_ASSERT(ggml_is_vector(pos));
- GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
- GGML_ASSERT(pos->ne[0] == a->ne[0]);
- }
-
- if (pos && mask) {
- GGML_ASSERT(pos->type == mask->type);
- }
-
if (max_bias > 0.0f) {
- GGML_ASSERT(pos);
+ GGML_ASSERT(mask);
}
bool is_node = false;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = mask;
- result->src[2] = pos;
return result;
}
struct ggml_tensor * ggml_soft_max(
struct ggml_context * ctx,
struct ggml_tensor * a) {
- return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, false);
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false);
}
struct ggml_tensor * ggml_soft_max_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
- return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, true);
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true);
}
struct ggml_tensor * ggml_soft_max_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
- struct ggml_tensor * pos,
float scale,
float max_bias) {
- return ggml_soft_max_impl(ctx, a, mask, pos, scale, max_bias, false);
+ return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
}
// ggml_soft_max_back
return result;
}
-// ggml_alibi
-
-struct ggml_tensor * ggml_alibi(
- struct ggml_context * ctx,
- struct ggml_tensor * a,
- int n_past,
- int n_head,
- float bias_max) {
- GGML_ASSERT(n_past >= 0);
- bool is_node = false;
-
- if (a->grad) {
- GGML_ASSERT(false); // TODO: implement backward
- is_node = true;
- }
-
- // TODO: when implement backward, fix this:
- //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
- struct ggml_tensor * result = ggml_view_tensor(ctx, a);
-
- int32_t op_params[3] = { n_past, n_head };
- memcpy(op_params + 2, &bias_max, sizeof(float));
- ggml_set_op_params(result, op_params, sizeof(op_params));
-
- result->op = GGML_OP_ALIBI;
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src[0] = a;
-
- return result;
-}
-
// ggml_clamp
struct ggml_tensor * ggml_clamp(
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * mask,
- float scale) {
+ float scale,
+ float max_bias) {
GGML_ASSERT(ggml_can_mul_mat(k, q));
// TODO: check if vT can be multiplied by (k*qT)
+
if (mask) {
GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(mask->ne[2] == 1);
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
}
+ if (max_bias > 0.0f) {
+ GGML_ASSERT(mask);
+ }
+
bool is_node = false;
if (q->grad || k->grad || v->grad) {
int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
- float params[] = { scale };
+ float params[] = { scale, max_bias };
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_FLASH_ATTN_EXT;
const int32_t prec_i32 = (int32_t) prec;
- ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
+ ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
}
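
With `max_bias` stored next to `scale`, the `GGML_OP_FLASH_ATTN_EXT` op_params layout used above becomes: `[0]` = scale (float), `[1]` = max_bias (float), `[2]` = precision (int32), which is why `ggml_flash_attn_ext_set_prec` now writes to index 2. A read-back sketch (the tensor name `kqv` is illustrative):

```c
#include <string.h>
#include <stdint.h>
#include "ggml.h"

// Sketch: recover the parameters of a GGML_OP_FLASH_ATTN_EXT node after this change.
static void read_flash_attn_ext_params(const struct ggml_tensor * kqv,
                                       float * scale, float * max_bias, int32_t * prec) {
    memcpy(scale,    (const float *) kqv->op_params + 0, sizeof(float)); // [0] = scale
    memcpy(max_bias, (const float *) kqv->op_params + 1, sizeof(float)); // [1] = max_bias
    *prec = ((const int32_t *) kqv->op_params)[2];                       // [2] = precision
}
```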
// ggml_flash_ff
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
- const struct ggml_tensor * src2 = dst->src[2];
assert(ggml_is_contiguous(dst));
assert(ggml_are_same_shape(src0, dst));
// TODO: is this supposed to be ceil instead of floor?
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
- const uint32_t n_head_kv = ne02;
- const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
+ const uint32_t n_head = ne02;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
- // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
- ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
- float * pos_f32 = src2 ? (float *) src2->data : src0->data;
-
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
for (int i1 = ir0; i1 < ir1; i1++) {
+ // ALiBi
+ const uint32_t h = (i1/ne01)%ne02; // head
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
if (mp_f32) {
if (use_f16) {
for (int i = 0; i < nc; ++i) {
- wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
+ wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
}
} else {
for (int i = 0; i < nc; ++i) {
- wp[i] += mp_f32[i];
- }
- }
- }
-
- // ALiBi bias
- if (max_bias > 0.0f) {
- const uint32_t h = (i1/ne01)%ne02; // head
- const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
-
- if (use_f16) {
- for (int i = 0; i < nc; ++i) {
- wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
- }
- } else {
- for (int i = 0; i < nc; ++i) {
- wp[i] += slope*pos_f32[i];
+ wp[i] += slope*mp_f32[i];
}
}
}
}
}
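
Taken together, the CPU path above and the GPU kernels earlier in the patch compute the same fused row-wise operation, with slope(h) as in the earlier host-side sketch (slope(h) = 1 when max_bias = 0, recovering the plain masked soft_max):

```latex
y_i = \operatorname{softmax}_i\bigl(\mathrm{scale}\cdot x_i + \mathrm{slope}(h)\cdot m_i\bigr)
```

Since the separate `pos` tensor is gone, any position-dependent ALiBi bias is expected to be folded into the mask values m_i by the caller; -INFINITY entries are unaffected because slope(h) > 0.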
-// ggml_compute_forward_alibi
-
-static void ggml_compute_forward_alibi_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- assert(params->ith == 0);
-
- if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
- return;
- }
-
- //const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_head = ((int32_t *) dst->op_params)[1];
- float max_bias;
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
- const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
- const int64_t ne1 = src0->ne[1]; // seq_len_without_past
- const int64_t ne2 = src0->ne[2]; // n_head -> this is k
- //const int64_t ne3 = src0->ne[3]; // 1 -> bsz
-
- const int64_t n = ggml_nrows(src0);
- const int64_t ne2_ne3 = n/ne1; // ne2*ne3
-
- const size_t nb0 = src0->nb[0];
- const size_t nb1 = src0->nb[1];
- const size_t nb2 = src0->nb[2];
- //const int nb3 = src0->nb[3];
-
- GGML_ASSERT(nb0 == sizeof(float));
- GGML_ASSERT(n_head == ne2);
-
- // add alibi to src0 (KQ_scaled)
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
-
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
- for (int64_t k = 0; k < ne2_ne3; k++) {
- // TODO: k*nb2 or k*nb3
- float m_k;
-
- if (k < n_heads_log2_floor) {
- m_k = powf(m0, k + 1);
- } else {
- m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
- }
-
- for (int64_t i = 0; i < ne0; i++) {
- for (int64_t j = 0; j < ne1; j++) {
- float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
- float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
- pdst[0] = i * m_k + src[0];
- }
- }
- }
-}
-
-static void ggml_compute_forward_alibi_f16(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- assert(params->ith == 0);
-
- if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
- return;
- }
-
- //const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_head = ((int32_t *) dst->op_params)[1];
- float max_bias;
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
- const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
- const int ne1 = src0->ne[1]; // seq_len_without_past
- const int ne2 = src0->ne[2]; // n_head -> this is k
- //const int ne3 = src0->ne[3]; // 1 -> bsz
-
- const int n = ggml_nrows(src0);
- const int ne2_ne3 = n/ne1; // ne2*ne3
-
- const int nb0 = src0->nb[0];
- const int nb1 = src0->nb[1];
- const int nb2 = src0->nb[2];
- //const int nb3 = src0->nb[3];
-
- GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
- //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
- GGML_ASSERT(n_head == ne2);
-
- // add alibi to src0 (KQ_scaled)
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
-
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
- for (int k = 0; k < ne2_ne3; k++) {
- // TODO: k*nb2 or k*nb3
- float m_k;
-
- if (k < n_heads_log2_floor) {
- m_k = powf(m0, k + 1);
- } else {
- m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
- }
-
- for (int i = 0; i < ne0; i++) {
- for (int j = 0; j < ne1; j++) {
- ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
- float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
-
- // we return F32
- pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
- }
- }
- }
-}
-
-static void ggml_compute_forward_alibi(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F16:
- {
- ggml_compute_forward_alibi_f16(params, dst);
- } break;
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_alibi_f32(params, dst);
- } break;
- case GGML_TYPE_BF16:
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q8_1:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_Q6_K:
- case GGML_TYPE_IQ2_XXS:
- case GGML_TYPE_IQ2_XS:
- case GGML_TYPE_IQ3_XXS:
- case GGML_TYPE_IQ1_S:
- case GGML_TYPE_IQ1_M:
- case GGML_TYPE_IQ4_NL:
- case GGML_TYPE_IQ4_XS:
- case GGML_TYPE_IQ3_S:
- case GGML_TYPE_IQ2_S:
- case GGML_TYPE_Q8_K:
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_I64:
- case GGML_TYPE_F64:
- case GGML_TYPE_COUNT:
- {
- GGML_ASSERT(false);
- } break;
- }
-}
-
// ggml_compute_forward_clamp
static void ggml_compute_forward_clamp_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
- float scale = 1.0f;
- memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+
+ const uint32_t n_head = neq2;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
// loop over n_batch and n_head
for (int ir = ir0; ir < ir1; ++ir) {
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+ const uint32_t h = iq2; // head
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
float S = 0.0f;
float M = -INFINITY;
// loop over n_kv and n_head_kv
// ref: https://arxiv.org/pdf/2112.05682.pdf
for (int64_t ic = 0; ic < nek1; ++ic) {
- const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+ const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
if (mv == -INFINITY) {
continue;
}
const struct ggml_tensor * v,
const struct ggml_tensor * mask,
struct ggml_tensor * dst) {
- switch (dst->op_params[1]) {
+ switch (dst->op_params[2]) {
case GGML_PREC_DEFAULT:
case GGML_PREC_F32:
{
{
ggml_compute_forward_rope_back(params, tensor);
} break;
- case GGML_OP_ALIBI:
- {
- ggml_compute_forward_alibi(params, tensor);
- } break;
case GGML_OP_CLAMP:
{
ggml_compute_forward_clamp(params, tensor);
zero_table);
}
} break;
- case GGML_OP_ALIBI:
- {
- GGML_ASSERT(false); // TODO: not implemented
- } break;
case GGML_OP_CLAMP:
{
GGML_ASSERT(false); // TODO: not implemented
{
n_tasks = n_threads;
} break;
- case GGML_OP_ALIBI:
- {
- n_tasks = 1; //TODO
- } break;
case GGML_OP_CLAMP:
{
n_tasks = 1; //TODO
if (this->mask) {
mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
}
- ggml_tensor * pos = nullptr;
- if (max_bias > 0.0f) {
- pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
- }
- ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias);
+ ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
return out;
}
};
const int64_t kv; // kv size
const int64_t nb; // batch size
+ const float max_bias; // ALiBi
+
std::string vars() override {
- return VARS_TO_STR4(hs, nh, kv, nb);
+ return VARS_TO_STR5(hs, nh, kv, nb, max_bias);
}
double max_nmse_err() override {
return 5e-4;
}
- test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
- : hs(hs), nh(nh), kv(kv), nb(nb) {}
+ test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, float max_bias = 0.0f)
+ : hs(hs), nh(nh), kv(kv), nb(nb), max_bias(max_bias) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
- ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
+ ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs), max_bias);
return out;
}
};
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
- kq = ggml_soft_max_ext(ctx, kq, kq_mask, nullptr, kq_scale, 0.0f);
+ kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, 0.0f);
// split cached v into n_head heads
struct ggml_tensor * v =
#endif
for (bool mask : {false, true}) {
for (float max_bias : {0.0f, 8.0f}) {
+ if (!mask && max_bias > 0.0f) continue;
for (float scale : {1.0f, 0.1f}) {
for (int64_t ne0 : {16, 1024}) {
for (int64_t ne1 : {16, 1024}) {
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 8.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
#else
for (int hs : { 64, 80, 128, 256, }) {
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
- for (int nh : { 32, }) {
- for (int kv : { 512, 1024, }) {
- for (int nb : { 1, 2, 4, 8, }) {
- test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
+ for (float max_bias : {0.0f, 8.0f}) {
+ for (int nh : { 32, }) {
+ for (int kv : { 512, 1024, }) {
+ for (int nb : { 1, 2, 4, 8, }) {
+ test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, max_bias));
+ }
}
}
}