}
}
+// Decides whether a candidate value must be placed ahead of an existing entry
+// (slot_val, slot_idx) of a descending top-k list.
+//   - An empty slot (slot_idx < 0) loses to every candidate. Emptiness is
+//     tracked through the index rather than a sentinel value, so inputs that
+//     are <= -FLT_MAX (e.g. -INFINITY) can still fill the list.
+//   - Otherwise the strictly larger value wins; on ties the entry that was
+//     inserted first stays in front.
+//   - A NaN entry loses to any non-NaN candidate, so NaNs are only kept when
+//     there are not enough ordinary values to fill the list.
+static inline bool top_k_outranks(const float val, const float slot_val, const int slot_idx) {
+    if (slot_idx < 0) {
+        return true;
+    }
+    if (val > slot_val) {
+        return true;
+    }
+    return slot_val != slot_val && val == val;
+}
+
+// Inserts (val, col) into the descending list vals[0..k) / idx[0..k) when it
+// outranks the current last entry; the last entry is dropped to make room.
+static inline void top_k_insert(float * vals, int * idx, const int k, const float val, const int col) {
+    if (!top_k_outranks(val, vals[k - 1], idx[k - 1])) {
+        return;
+    }
+
+    // Walk towards the front until an entry that the candidate does not beat.
+    int pos = k - 1;
+    while (pos > 0 && top_k_outranks(val, vals[pos - 1], idx[pos - 1])) {
+        pos--;
+    }
+
+    // Shift the tail down by one and place the candidate.
+    for (int i = k - 1; i > pos; i--) {
+        vals[i] = vals[i - 1];
+        idx[i]  = idx[i - 1];
+    }
+    vals[pos] = val;
+    idx[pos]  = col;
+}
+
+// Launches the top-k index kernel: for every row of `src` (nrows rows of ncols
+// floats, densely packed) the column indices of the k largest values are
+// written to `dst_indices` (k int32 entries per row, densely packed).
+//
+// One work-group of 128 work-items handles one row. Each work-item builds a
+// private top-k list over the columns tid, tid + 128, ...; the lists are
+// staged in work-group local memory and merged by work-item 0.
+//
+// Preconditions (not checked here): 0 < k <= 32, because the private lists
+// are fixed 32-entry arrays, and k <= ncols so that k valid indices exist.
+static void top_k_f32_sycl(
+    const float *   src,
+    int32_t *       dst_indices,
+    const int64_t   ncols,
+    const int64_t   nrows,
+    const int       k,
+    dpct::queue_ptr main_stream
+) {
+    constexpr int max_k      = 32;
+    const int     block_size = 128;
+
+    // Nothing to do, and avoids submitting a zero-sized nd_range.
+    if (nrows <= 0) {
+        return;
+    }
+
+    const sycl::range<1> block_dims(block_size);
+    const sycl::range<1> grid_dims(nrows);
+
+    main_stream->submit([&](sycl::handler & cgh) {
+        // One k-entry candidate list per work-item.
+        sycl::local_accessor<float, 1> shared_vals(sycl::range<1>(block_size * k), cgh);
+        sycl::local_accessor<int, 1>   shared_idx(sycl::range<1>(block_size * k), cgh);
+
+        cgh.parallel_for(
+            sycl::nd_range<1>(grid_dims * block_dims, block_dims),
+            [=](sycl::nd_item<1> item_ct1) {
+                // 64-bit row index so that row * k / row * ncols cannot overflow.
+                const int64_t row = item_ct1.get_group(0);
+                const int     tid = item_ct1.get_local_id(0);
+
+                // Uniform per work-group (row is the group id), so returning
+                // here cannot leave part of a group stuck at the barrier.
+                if (row >= nrows) return;
+
+                const float * src_row     = src + row * ncols;
+                int32_t *     dst_idx_row = dst_indices + row * k;
+
+                float local_vals[max_k];
+                int   local_idx[max_k];
+
+                for (int i = 0; i < k; i++) {
+                    local_vals[i] = -FLT_MAX;
+                    local_idx[i]  = -1;  // -1 marks an empty slot
+                }
+
+                // Phase 1: private top-k over this work-item's columns.
+                for (int col = tid; col < ncols; col += block_size) {
+                    top_k_insert(local_vals, local_idx, k, src_row[col], col);
+                }
+
+                // Phase 2: publish the private list to local memory.
+                for (int i = 0; i < k; i++) {
+                    shared_vals[tid * k + i] = local_vals[i];
+                    shared_idx[tid * k + i]  = local_idx[i];
+                }
+                item_ct1.barrier(sycl::access::fence_space::local_space);
+
+                // Phase 3: work-item 0 merges all candidate lists.
+                if (tid == 0) {
+                    float final_vals[max_k];
+                    int   final_idx[max_k];
+
+                    for (int i = 0; i < k; i++) {
+                        final_vals[i] = -FLT_MAX;
+                        final_idx[i]  = -1;
+                    }
+
+                    for (int t = 0; t < block_size; t++) {
+                        for (int i = 0; i < k; i++) {
+                            const int idx = shared_idx[t * k + i];
+                            // Empty slots sit at the end of each list and
+                            // must never be merged as real candidates.
+                            if (idx < 0) {
+                                break;
+                            }
+                            top_k_insert(final_vals, final_idx, k, shared_vals[t * k + i], idx);
+                        }
+                    }
+
+                    for (int i = 0; i < k; i++) {
+                        dst_idx_row[i] = final_idx[i];
+                    }
+
+                    // The first two result indices are exchanged on purpose.
+                    // NOTE(review): presumably mirrors the reference backend to
+                    // signal that the output order is unspecified - confirm.
+                    if (k > 1) {
+                        const int32_t temp = dst_idx_row[0];
+                        dst_idx_row[0]     = dst_idx_row[1];
+                        dst_idx_row[1]     = temp;
+                    }
+                }
+            });
+    });
+}
+
static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
const int nrows, queue_ptr stream) {
const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE);
main_stream, ctx.device);
}
+// Backend entry point for GGML_OP_TOP_K: validates the tensors and launches
+// top_k_f32_sycl on the context's stream.
+//
+// dst->src[0] must be a contiguous F32 tensor; dst must be a contiguous I32
+// tensor with the same number of rows, whose first dimension is k
+// (0 < k <= 32 and k <= number of source columns).
+static void ggml_sycl_op_top_k(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    // The kernel addresses dst as densely packed rows of k indices.
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    // Range-check k at full width before narrowing it to int.
+    const int64_t k_wide = dst->ne[0];
+    GGML_ASSERT(k_wide > 0 && k_wide <= 32);
+    GGML_ASSERT(k_wide <= ncols);
+    // The kernel writes one k-entry row per source row.
+    GGML_ASSERT(ggml_nrows(dst) == nrows);
+
+    const int k = static_cast<int>(k_wide);
+
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+
+    const float * src0_dd = static_cast<const float *>(src0->data);
+    int32_t *     dst_dd  = static_cast<int32_t *>(dst->data);
+
+    top_k_f32_sycl(src0_dd, dst_dd, ncols, nrows, k, main_stream);
+}
+
inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_I32);
case GGML_OP_ARGSORT:
ggml_sycl_argsort(ctx, dst);
break;
+ case GGML_OP_TOP_K:
+ ggml_sycl_op_top_k(ctx, dst);
+ break;
case GGML_OP_TIMESTEP_EMBEDDING:
ggml_sycl_op_timestep_embedding(ctx, dst);
break;
case GGML_OP_ARGSORT:
return op->src[0]->ne[0] * sizeof(int) <=
ggml_sycl_info().devices[device].smpbo;
+        case GGML_OP_TOP_K: {
+            // Report support only for inputs the TOP_K op accepts without
+            // tripping its own assertions.
+            const ggml_tensor * src0 = op->src[0];
+            if (!src0 ||
+                op->type != GGML_TYPE_I32 ||
+                src0->type != GGML_TYPE_F32 ||
+                !ggml_is_contiguous(src0) ||
+                !ggml_is_contiguous(op)) {
+                return false;
+            }
+
+            // k is the first dimension of the result; the op asserts
+            // 0 < k <= 32 and k <= number of source columns.
+            const int64_t k = op->ne[0];
+            if (k <= 0 || k > 32 || k > src0->ne[0]) {
+                return false;
+            }
+
+            // The kernel stages 128 candidate lists of k (float, int) pairs in
+            // work-group local memory; compare against the same device limit
+            // the ARGSORT case uses.
+            const size_t local_bytes = 128 * static_cast<size_t>(k) * (sizeof(float) + sizeof(int));
+            return local_bytes <= ggml_sycl_info().devices[device].smpbo;
+        }
case GGML_OP_POOL_2D:
case GGML_OP_ACC:
return true;