]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sycl: implement GGML_OP_TOP_K (llama/19242)
authorTamar <redacted>
Mon, 2 Feb 2026 13:05:51 +0000 (15:05 +0200)
committerGeorgi Gerganov <redacted>
Sat, 7 Feb 2026 08:37:38 +0000 (10:37 +0200)
src/ggml-sycl/ggml-sycl.cpp

index 12f1e7717b7cfad552d5260fe24809f68965125c..c5139fd3dd2d8c39b2b9e835e01efc55e1db3eeb 100644 (file)
@@ -1840,6 +1840,110 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
     }
 }
 
+static void top_k_f32_sycl(
+    const float * src,
+    int32_t * dst_indices,
+    const int64_t ncols,
+    const int64_t nrows,
+    const int k,
+    dpct::queue_ptr main_stream
+) {
+    const int block_size = 128;
+
+    const sycl::range<1> block_dims(block_size);
+    const sycl::range<1> grid_dims(nrows);
+
+    main_stream->submit([&](sycl::handler &cgh) {
+        sycl::local_accessor<float, 1> shared_vals(sycl::range<1>(block_size * k), cgh);
+        sycl::local_accessor<int, 1> shared_idx(sycl::range<1>(block_size * k), cgh);
+
+        cgh.parallel_for(
+            sycl::nd_range<1>(grid_dims * block_dims, block_dims),
+            [=](sycl::nd_item<1> item_ct1) {
+                const int row = item_ct1.get_group(0);
+                const int tid = item_ct1.get_local_id(0);
+
+                if (row >= nrows) return;
+
+                const float * src_row = src + row * ncols;
+                int32_t * dst_idx_row = dst_indices + row * k;
+
+                float local_vals[32];
+                int local_idx[32];
+
+                for (int i = 0; i < k; i++) {
+                    local_vals[i] = -FLT_MAX;
+                    local_idx[i] = -1;
+                }
+
+                for (int col = tid; col < ncols; col += block_size) {
+                    float val = src_row[col];
+
+                    if (val > local_vals[k-1]) {
+                        int pos = k - 1;
+                        while (pos > 0 && val > local_vals[pos - 1]) {
+                            pos--;
+                        }
+
+                        for (int i = k - 1; i > pos; i--) {
+                            local_vals[i] = local_vals[i - 1];
+                            local_idx[i] = local_idx[i - 1];
+                        }
+                        local_vals[pos] = val;
+                        local_idx[pos] = col;
+                    }
+                }
+
+                for (int i = 0; i < k; i++) {
+                    shared_vals[tid * k + i] = local_vals[i];
+                    shared_idx[tid * k + i] = local_idx[i];
+                }
+                item_ct1.barrier(sycl::access::fence_space::local_space);
+
+                if (tid == 0) {
+                    float final_vals[32];
+                    int final_idx[32];
+
+                    for (int i = 0; i < k; i++) {
+                        final_vals[i] = -FLT_MAX;
+                        final_idx[i] = -1;
+                    }
+
+                    for (int t = 0; t < block_size; t++) {
+                        for (int i = 0; i < k; i++) {
+                            float val = shared_vals[t * k + i];
+                            int idx = shared_idx[t * k + i];
+
+                            if (val > final_vals[k-1]) {
+                                int pos = k - 1;
+                                while (pos > 0 && val > final_vals[pos - 1]) {
+                                    pos--;
+                                }
+
+                                for (int j = k - 1; j > pos; j--) {
+                                    final_vals[j] = final_vals[j - 1];
+                                    final_idx[j] = final_idx[j - 1];
+                                }
+                                final_vals[pos] = val;
+                                final_idx[pos] = idx;
+                            }
+                        }
+                    }
+
+                    for (int i = 0; i < k; i++) {
+                        dst_idx_row[i] = final_idx[i];
+                    }
+
+                    if (k > 1) {
+                        int32_t temp = dst_idx_row[0];
+                        dst_idx_row[0] = dst_idx_row[1];
+                        dst_idx_row[1] = temp;
+                    }
+                }
+            });
+    });
+}
+
 static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
                                const int nrows, queue_ptr stream) {
     const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE);
@@ -2231,6 +2335,30 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor *
                          main_stream, ctx.device);
 }
 
+static void ggml_sycl_op_top_k(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+
+    const float * src0_dd = static_cast<const float *>(src0->data);
+    int32_t * dst_dd = static_cast<int32_t *>(dst->data);
+
+    const int k = dst->ne[0];
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    GGML_ASSERT(k > 0 && k <= 32);
+    GGML_ASSERT(k <= ncols);
+
+    top_k_f32_sycl(src0_dd, dst_dd, ncols, nrows, k, main_stream);
+}
+
 inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_I32);
@@ -4007,6 +4135,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_ARGSORT:
             ggml_sycl_argsort(ctx, dst);
             break;
+        case GGML_OP_TOP_K:
+            ggml_sycl_op_top_k(ctx, dst);
+            break;
         case GGML_OP_TIMESTEP_EMBEDDING:
             ggml_sycl_op_timestep_embedding(ctx, dst);
             break;
@@ -4710,6 +4841,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ARGSORT:
             return op->src[0]->ne[0] * sizeof(int) <=
                    ggml_sycl_info().devices[device].smpbo;
+        case GGML_OP_TOP_K: {
+            const ggml_tensor * src0 = op->src[0];
+            const int k = op->ne[0];
+            return src0 &&
+                op->type == GGML_TYPE_I32 &&
+                src0->type == GGML_TYPE_F32 &&
+                ggml_is_contiguous(src0) &&
+                k > 0 && k <= 32;
+        }
         case GGML_OP_POOL_2D:
         case GGML_OP_ACC:
             return true;