return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
case GGML_OP_FLASH_ATTN_EXT:
{
- if (op->src[4]) {
- return false;
- }
-
const ggml_tensor * q = op->src[0];
const ggml_tensor * k = op->src[1];
const ggml_tensor * v = op->src[2];
static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) {
const ggml_tensor * v = dst->src[2];
const ggml_tensor * mask = dst->src[3];
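+ // src[4] carries the optional per-head attention sink logits; it is NULL when the model has none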
+ const ggml_tensor * sinks = dst->src[4];
GGML_ASSERT(q->extra);
GGML_ASSERT(k->extra);
GGML_ASSERT(v->extra);
if (mask) {
GGML_ASSERT(mask->extra);
}
+ if (sinks) {
+ GGML_ASSERT(sinks->extra);
+ }
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra;
ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra;
ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL;
+ ggml_tensor_extra_cl * extra_sinks = sinks ? (ggml_tensor_extra_cl *)sinks->extra : NULL;
cl_ulong offset_q = extra_q->offset + q->view_offs;
cl_ulong offset_k = extra_k->offset + k->view_offs;
cl_ulong offset_o = extra_o->offset + dst->view_offs;
cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL;
cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0;
+ cl_mem sinks_buffer = extra_sinks ? extra_sinks->data_device : NULL;
+ cl_ulong offset_sinks = extra_sinks ? extra_sinks->offset + sinks->view_offs : 0;
const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3];
const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3];
CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3));
CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2));
CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3));
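+ // new trailing kernel arguments: the sinks buffer (a NULL cl_mem when absent) and its byte offset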
+ CL_CHECK(clSetKernelArg(kernel, 38, sizeof(cl_mem), &sinks_buffer));
+ CL_CHECK(clSetKernelArg(kernel, 39, sizeof(cl_ulong), &offset_sinks));
if (n_q == 1) {
const size_t wg_size = 64;
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
- const int mask_ne3
+ const int mask_ne3,
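+ // optional attention sinks: base pointer (NULL if the node has no sinks) and byte offset into it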
+ const global void* sinks_void,
+ const ulong sinks_offset
) {
const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0);
}
if (my_query_row < n_q) {
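+ // fold the per-head sink logit into the online softmax: rescale the output
+ // accumulator to the new running max and add the sink's exp term to the normalizer l_i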
+ if (sinks_void != NULL) {
+ const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
+ const ACC_TYPE m_sink = sinks_ptr[head_idx];
+ const ACC_TYPE m_final = max(m_i, m_sink);
+
+ const ACC_TYPE scale_o = exp(m_i - m_final);
+ #pragma unroll
+ for (int i = 0; i < DV_VEC; ++i) {
+ o_acc[i] *= scale_o;
+ }
+
+ l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
+ }
+
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
- const int mask_ne3
+ const int mask_ne3,
+ const global void* sinks_void,
+ const ulong sinks_offset
) {
const int tid = get_local_id(0);
const int head_batch_idx = get_global_id(1);
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
- ACC_TYPE m_i = -INFINITY;
+ const global ACC_TYPE* sinks_ptr = NULL;
+ if (sinks_void != NULL) {
+ sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
+ }
+
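+ // seed the running max with the sink logit so m_final accounts for it; its
+ // contribution to the denominator is added after the workgroup reduction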
+ ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
- const ACC_TYPE l_final = local_l[0];
+ ACC_TYPE l_final = local_l[0];
+
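+ // the sink contributes exp(sink - m_final) to the softmax denominator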
+ if (sinks_ptr != NULL) {
+ l_final += exp(sinks_ptr[head_idx] - m_final);
+ }
if (l_final > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_final;
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
- const int mask_ne3
+ const int mask_ne3,
+ const global void* sinks_void,
+ const ulong sinks_offset
) {
const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0);
}
if (my_query_row < n_q) {
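+ // same sink fold-in as the kernel above, repeated for this data-type variant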
+ if (sinks_void != NULL) {
+ const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
+ const ACC_TYPE m_sink = sinks_ptr[head_idx];
+ const ACC_TYPE m_final = max(m_i, m_sink);
+
+ const ACC_TYPE scale_o = exp(m_i - m_final);
+ #pragma unroll
+ for (int i = 0; i < DV_VEC; ++i) {
+ o_acc[i] *= scale_o;
+ }
+
+ l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
+ }
+
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
- const int mask_ne3
+ const int mask_ne3,
+ const global void* sinks_void,
+ const ulong sinks_offset
) {
const int tid = get_local_id(0);
const int head_batch_idx = get_global_id(1);
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
- ACC_TYPE m_i = -INFINITY;
+ const global ACC_TYPE* sinks_ptr = NULL;
+ if (sinks_void != NULL) {
+ sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
+ }
+
+ ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
- const ACC_TYPE l_final = local_l[0];
+ ACC_TYPE l_final = local_l[0];
+
+ if (sinks_ptr != NULL) {
+ l_final += exp(sinks_ptr[head_idx] - m_final);
+ }
if (l_final > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_final;
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
- const int mask_ne3
+ const int mask_ne3,
+ const global void* sinks_void,
+ const ulong sinks_offset
) {
const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0);
}
if (my_query_row < n_q) {
+ if (sinks_void != NULL) {
+ const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
+ const ACC_TYPE m_sink = sinks_ptr[head_idx];
+ const ACC_TYPE m_final = max(m_i, m_sink);
+
+ const ACC_TYPE scale_o = exp(m_i - m_final);
+ #pragma unroll
+ for (int i = 0; i < DV_VEC; ++i) {
+ o_acc[i] *= scale_o;
+ }
+
+ l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
+ }
+
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
- const int mask_ne3
+ const int mask_ne3,
+ const global void* sinks_void,
+ const ulong sinks_offset
) {
const int tid = get_local_id(0);
const int head_batch_idx = get_global_id(1);
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
- ACC_TYPE m_i = -INFINITY;
+ const global ACC_TYPE* sinks_ptr = NULL;
+ if (sinks_void != NULL) {
+ sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
+ }
+
+ ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
- const ACC_TYPE l_final = local_l[0];
+ ACC_TYPE l_final = local_l[0];
+
+ if (sinks_ptr != NULL) {
+ l_final += exp(sinks_ptr[head_idx] - m_final);
+ }
if (l_final > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_final;