GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
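+ // Add the missing 32-columns-per-block tile config for the 576/512 head sizes.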
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 64, 64)
return 0;
}
sycl::half2 * const __restrict__ tile_KV,
const int stride_KV,
const int i_sup) {
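+ // Fetch the nd_item handle once at function scope; it was previously fetched inside the load lambda below.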
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
auto load = [&] (const int n) {
- auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int stride_j = warp_size >> n;
if (stride_j == 0) {
flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>
(K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
- item_ct1.barrier();
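+ // Fence only local (shared) memory; the default barrier also orders global memory and is more expensive.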
+ item_ct1.barrier(sycl::access::fence_space::local_space);
#ifdef SYCL_FAST_FP16
static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
}
if (k_KQ_0 + nbatch_K < DKQ) {
- item_ct1.barrier(); // Sync not needed on last iteration.
+ item_ct1.barrier(sycl::access::fence_space::local_space); // Sync not needed on last iteration.
}
}
const int k_VKQ_max,
const int col_Q_0,
float * KQ_max_new_shared) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
}
if constexpr (np == 1) {
- item_ct1.barrier();
+ item_ct1.barrier(sycl::access::fence_space::local_space);
} else {
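+ // np > 1: lane 0 of each warp publishes its running max to shared memory, then every
+ // thread reads the np values of its column group and reduces them within the warp.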
static_assert(cpw == 1, "bad cpw");
if (item_ct1.get_local_id(2) == 0) {
KQ_max_new_shared[item_ct1.get_local_id(1)] = KQ_max_new[0];
}
- item_ct1.barrier();
+ item_ct1.barrier(sycl::access::fence_space::local_space);
KQ_max_new[0] = KQ_max_new_shared[(item_ct1.get_local_id(1) & ~(np - 1)) + item_ct1.get_local_id(2) % np];
KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]);
}
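+ // Stream V through shared memory in nbatch_V-row slices of the KV tile.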
for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>
(V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
- item_ct1.barrier();
+ item_ct1.barrier(sycl::access::fence_space::local_space);
#ifdef SYCL_FAST_FP16
#pragma unroll
}
}
#endif // SYCL_FAST_FP16
- item_ct1.barrier();
+ item_ct1.barrier(sycl::access::fence_space::local_space);
}
}
}
}
- item_ct1.barrier();
+ item_ct1.barrier(sycl::access::fence_space::local_space);
// Main loop over KV cache:
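+ // KV_max, when provided, bounds how far into the KV cache this workgroup needs to scan;
+ // otherwise the full ne11 range is processed.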
const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11;
return;
}
- item_ct1.barrier();
+ item_ct1.barrier(sycl::access::fence_space::local_space);
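+ // Fold in the partial results of the remaining np-1 warps of this column group.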
#pragma unroll
for (int ip = 1; ip < np; ++ip) {
constexpr size_t nbytes_shared = 0;
- if constexpr (DV <= 256) {
- if (Q->ne[1] > 16/ncols2) {
- constexpr int cols_per_block = 32;
- const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
- const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
- launch_fattn<DV, cols_per_block/ncols2, ncols2,
- flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
- (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
- return;
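+ // For small batches (DV < 512, Q->ne[1] < 32), pick the smallest cols_per_block that
+ // still covers the batch; each if constexpr guard keeps ncols2 within the tile width.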
+ if (DV < 512 && Q->ne[1] < 32) {
+ if constexpr (ncols2 <= 32) {
+ if (Q->ne[1] > 16/ncols2) {
+ constexpr int cols_per_block = 32;
+ const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+ const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+ launch_fattn<DV, cols_per_block/ncols2, ncols2,
+ flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
+ (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+ return;
+ }
}
- }
-
- if (Q->ne[1] > 8/ncols2) {
- constexpr int cols_per_block = 16;
- const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
- const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
- launch_fattn<DV, cols_per_block/ncols2, ncols2,
- flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
- (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
- return;
- }
-
- if constexpr (ncols2 <= 8) {
- if (Q->ne[1] > 4/ncols2) {
- constexpr int cols_per_block = 8;
- const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
- const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
- launch_fattn<DV, cols_per_block/ncols2, ncols2,
- flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
- (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
- return;
+ if constexpr (ncols2 <= 16) {
+ if (Q->ne[1] > 8/ncols2) {
+ constexpr int cols_per_block = 16;
+ const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+ const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+ launch_fattn<DV, cols_per_block/ncols2, ncols2,
+ flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
+ (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+ return;
+ }
+ }
+ if constexpr (ncols2 <= 8) {
+ if (Q->ne[1] > 4/ncols2) {
+ constexpr int cols_per_block = 8;
+ const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+ const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+ launch_fattn<DV, cols_per_block/ncols2, ncols2,
+ flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
+ (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+ return;
+ }
}
}