]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : add ggml_top_k (llama/17365)
authorGeorgi Gerganov <redacted>
Tue, 25 Nov 2025 13:31:43 +0000 (15:31 +0200)
committerGeorgi Gerganov <redacted>
Fri, 12 Dec 2025 15:53:08 +0000 (17:53 +0200)
* ggml : add ggml_top_k

* cont : add ggml_argsort_top_k

* metal : add top_k support

* ggml : cleanup

* tests : add virtual err() function for test_case

* ggml : add comments

13 files changed:
ggml/include/ggml.h
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-cpu/ops.h
ggml/src/ggml-metal/ggml-metal-device.cpp
ggml/src/ggml-metal/ggml-metal-device.h
ggml/src/ggml-metal/ggml-metal-device.m
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal-ops.h
ggml/src/ggml-metal/ggml-metal.cpp
ggml/src/ggml-metal/ggml-metal.metal
ggml/src/ggml.c

index 605fcfcb9c29f7fe870c8f4b1288038a25d9d9f7..4dbca868bc74a8cdafa6f1a6fd621a00cb51454a 100644 (file)
@@ -530,6 +530,7 @@ extern "C" {
         GGML_OP_ARANGE,
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
+        GGML_OP_TOP_K,
         GGML_OP_LEAKY_RELU,
         GGML_OP_TRI,
         GGML_OP_FILL,
@@ -2258,18 +2259,25 @@ extern "C" {
             struct ggml_tensor  * a,
             enum ggml_sort_order  order);
 
-    GGML_API struct ggml_tensor * ggml_arange(
+    // similar to ggml_top_k but implemented as `argsort` + `view`
+    GGML_API struct ggml_tensor * ggml_argsort_top_k(
             struct ggml_context * ctx,
-            float                 start,
-            float                 stop,
-            float                 step);
+            struct ggml_tensor  * a,
+            int                   k);
 
     // top k elements per row
+    // note: the resulting top k indices are in no particular order
     GGML_API struct ggml_tensor * ggml_top_k(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             int                   k);
 
+    GGML_API struct ggml_tensor * ggml_arange(
+            struct ggml_context * ctx,
+            float                 start,
+            float                 stop,
+            float                 step);
+
 #define GGML_KQ_MASK_PAD 64
 
     // q:    [n_embd_k, n_batch,     n_head,    ne3 ]
index c7348cc26c10ca451a4a7a2693f261ee6d4c1611..3247af8bb039ccdfe0ca0cd7afe9a862ab1858e7 100644 (file)
@@ -1927,6 +1927,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_argsort(params, tensor);
             } break;
+        case GGML_OP_TOP_K:
+            {
+                ggml_compute_forward_top_k(params, tensor);
+            } break;
         case GGML_OP_LEAKY_RELU:
             {
                 ggml_compute_forward_leaky_relu(params, tensor);
@@ -2311,6 +2315,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
+        case GGML_OP_TOP_K:
         case GGML_OP_FLASH_ATTN_EXT:
         case GGML_OP_FLASH_ATTN_BACK:
         case GGML_OP_SSM_CONV:
@@ -2834,6 +2839,10 @@ struct ggml_cplan ggml_graph_plan(
                         cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
                         cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
                     } break;
+                case GGML_OP_TOP_K:
+                    {
+                        cur += sizeof(int32_t)*node->src[0]->ne[0]*n_tasks;
+                    } break;
                 case GGML_OP_FLASH_ATTN_EXT:
                     {
                         const int64_t ne10 = node->src[1]->ne[0]; // DK
index 41e89d83c2ec943e859944bb39669c9ea0cc9845..d405696539e1ce10c511fc88b5ef16f68811129f 100644 (file)
@@ -7794,7 +7794,7 @@ void ggml_compute_forward_timestep_embedding(
 // ggml_compute_forward_argsort
 
 template<enum ggml_sort_order order>
-struct argsort_cmp {
+struct cmp_argsort {
     const float * data;
     bool operator()(int32_t a, int32_t b) const {
         if constexpr (order == GGML_SORT_ORDER_ASC) {
@@ -7833,11 +7833,11 @@ static void ggml_compute_forward_argsort_f32(
 
         switch (order) {
             case GGML_SORT_ORDER_ASC:
-                std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
+                std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
                 break;
 
             case GGML_SORT_ORDER_DESC:
-                std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
+                std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
                 break;
 
             default:
@@ -7864,6 +7864,72 @@ void ggml_compute_forward_argsort(
     }
 }
 
+// ggml_compute_forward_top_k
+
+struct cmp_top_k {
+    const float * data;
+    bool operator()(int32_t a, int32_t b) const {
+        return data[a] > data[b];
+    }
+};
+
+static void ggml_compute_forward_top_k_f32(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(nb0 == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    const int top_k = ne0;
+
+    int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+    for (int64_t i = ith; i < nr; i += nth) {
+        const float * src_data = (float *)((char *) src0->data + i*nb01);
+
+        for (int64_t j = 0; j < ne00; j++) {
+            tmp[j] = j;
+        }
+
+        std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
+
+        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+
+        std::copy(tmp, tmp + top_k, dst_data);
+
+        // emphasize that the order is not important
+        if (top_k > 1) {
+            std::swap(dst_data[0], dst_data[1]);
+        }
+    }
+}
+
+void ggml_compute_forward_top_k(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_top_k_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_flash_attn_ext
 
 static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
index 98a0eec16d9fec13c77dc0061b9d77e5d3be78b0..0fdfee79766e4aca15ec0c6f1afe1abe3645c8cf 100644 (file)
@@ -81,6 +81,7 @@ void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct
 void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);
index 0eefc0b137b1ce8dffc74764a620ece03da20c2e..329500a03e0d760efb7916f40ec88a9966213bbc 100644 (file)
@@ -1009,6 +1009,64 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_l
     return res;
 }
 
+// note: reuse the argsort kernel for top_k
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_TOP_K);
+
+    char base[256];
+    char name[256];
+
+    // note: the top_k kernel is always descending order
+    ggml_sort_order order = GGML_SORT_ORDER_DESC;
+
+    const char * order_str = "undefined";
+    switch (order) {
+        case GGML_SORT_ORDER_ASC:  order_str = "asc";  break;
+        case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
+        default: GGML_ABORT("fatal error");
+    };
+
+    snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_TOP_K);
+
+    char base[256];
+    char name[256];
+
+    ggml_sort_order order = GGML_SORT_ORDER_DESC;
+
+    const char * order_str = "undefined";
+    switch (order) {
+        case GGML_SORT_ORDER_ASC:  order_str = "asc";  break;
+        case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
+        default: GGML_ABORT("fatal error");
+    };
+
+    snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
         ggml_metal_library_t lib,
         const struct ggml_tensor * op,
index 39ee6e3427e8f0546997339495b8c318d5e02887..3976e622b9b9a672e3e100cdc6771300ca58934c 100644 (file)
@@ -128,6 +128,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id         (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort           (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge     (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k             (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge       (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin               (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm           (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm        (ggml_metal_library_t lib, const struct ggml_tensor * op);
index acf9dfd5fc12c550dedfc5cef5acd514b4f6d956..09b1b50311828c0b8b45d643f788f0694fc744b2 100644 (file)
@@ -905,6 +905,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_LEAKY_RELU:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ARGSORT:
+        case GGML_OP_TOP_K:
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
index 0fae97029f14a9812919775186c3940107f73d5e..342dc4f8c378097f7d9a1878ccbea2f7fb06cab7 100644 (file)
@@ -832,14 +832,19 @@ typedef struct {
 } ggml_metal_kargs_leaky_relu;
 
 typedef struct {
-    int64_t  ne00;
-    int64_t  ne01;
-    int64_t  ne02;
-    int64_t  ne03;
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
     uint64_t nb00;
     uint64_t nb01;
     uint64_t nb02;
     uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    int32_t  top_k;
 } ggml_metal_kargs_argsort;
 
 typedef struct {
@@ -851,6 +856,11 @@ typedef struct {
     uint64_t nb01;
     uint64_t nb02;
     uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    int32_t  top_k;
     int32_t  len;
 } ggml_metal_kargs_argsort_merge;
 
index e46970dad476bb92d6073ebe0ffcdb401865d226..9871e976f23d149b30ae459b01622f52865f4231 100644 (file)
@@ -406,6 +406,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_argsort(ctx, idx);
             } break;
+        case GGML_OP_TOP_K:
+            {
+                n_fuse = ggml_metal_op_top_k(ctx, idx);
+            } break;
         case GGML_OP_LEAKY_RELU:
             {
                 n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
@@ -3678,14 +3682,19 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
     }
 
     ggml_metal_kargs_argsort args = {
-        /*.ne00 =*/ ne00,
-        /*.ne01 =*/ ne01,
-        /*.ne02 =*/ ne02,
-        /*.ne03 =*/ ne03,
-        /*.nb00 =*/ nb00,
-        /*.nb01 =*/ nb01,
-        /*.nb02 =*/ nb02,
-        /*.nb03 =*/ nb03,
+        /*.ne00  =*/ ne00,
+        /*.ne01  =*/ ne01,
+        /*.ne02  =*/ ne02,
+        /*.ne03  =*/ ne03,
+        /*.nb00  =*/ nb00,
+        /*.nb01  =*/ nb01,
+        /*.nb02  =*/ nb02,
+        /*.nb03  =*/ nb03,
+        /*.ne0   =*/ ne0,
+        /*.ne1   =*/ ne1,
+        /*.ne2   =*/ ne2,
+        /*.ne3   =*/ ne3,
+        /*.top_k =*/ nth,
     };
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
@@ -3705,15 +3714,20 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
         ggml_metal_op_concurrency_reset(ctx);
 
         ggml_metal_kargs_argsort_merge args_merge = {
-            .ne00 = ne00,
-            .ne01 = ne01,
-            .ne02 = ne02,
-            .ne03 = ne03,
-            .nb00 = nb00,
-            .nb01 = nb01,
-            .nb02 = nb02,
-            .nb03 = nb03,
-            .len  = len,
+            /*.ne00  =*/ ne00,
+            /*.ne01  =*/ ne01,
+            /*.ne02  =*/ ne02,
+            /*.ne03  =*/ ne03,
+            /*.nb00  =*/ nb00,
+            /*.nb01  =*/ nb01,
+            /*.nb02  =*/ nb02,
+            /*.nb03  =*/ nb03,
+            /*.ne0   =*/ ne0,
+            /*.ne1   =*/ ne1,
+            /*.ne2   =*/ ne2,
+            /*.ne3   =*/ ne3,
+            /*.top_k =*/ ne00,
+            /*.len   =*/ len,
         };
 
         // merges per row
@@ -3737,6 +3751,118 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
+
+    // bitonic sort requires the number of elements to be power of 2
+    int nth = 1;
+    while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        nth *= 2;
+    }
+
+    // blocks per row
+    const int npr = (ne00 + nth - 1)/nth;
+
+    const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
+
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
+    ggml_metal_buffer_id bid_tmp = bid_dst;
+    bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
+
+    if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
+        std::swap(bid_dst, bid_tmp);
+    }
+
+    const int top_k = ne0;
+
+    ggml_metal_kargs_argsort args = {
+        /*.ne00  =*/ ne00,
+        /*.ne01  =*/ ne01,
+        /*.ne02  =*/ ne02,
+        /*.ne03  =*/ ne03,
+        /*.nb00  =*/ nb00,
+        /*.nb01  =*/ nb01,
+        /*.nb02  =*/ nb02,
+        /*.nb03  =*/ nb03,
+        /*.ne0   =*/ ne0,
+        /*.ne1   =*/ ne1,
+        /*.ne2   =*/ ne2,
+        /*.ne3   =*/ ne3,
+        /*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
+    };
+
+    if (npr > 1) {
+        args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
+    }
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
+
+    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
+
+    ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
+
+    int len = args.top_k;
+
+    while (len < args.ne0) {
+        ggml_metal_op_concurrency_reset(ctx);
+
+        // merges per row
+        const int nm = (args.ne0 + 2*len - 1) / (2*len);
+
+        const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
+
+        ggml_metal_kargs_argsort_merge args_merge = {
+            /*.ne00  =*/ ne00,
+            /*.ne01  =*/ ne01,
+            /*.ne02  =*/ ne02,
+            /*.ne03  =*/ ne03,
+            /*.nb00  =*/ nb00,
+            /*.nb01  =*/ nb01,
+            /*.nb02  =*/ nb02,
+            /*.nb03  =*/ nb03,
+            /*.ne0   =*/ args.ne0,
+            /*.ne1   =*/ ne1,
+            /*.ne2   =*/ ne2,
+            /*.ne3   =*/ ne3,
+            /*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
+            /*.len   =*/ len,
+        };
+
+        ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
+        ggml_metal_encoder_set_bytes   (enc, &args_merge, sizeof(args_merge), 0);
+        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+        ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
+        ggml_metal_encoder_set_buffer  (enc, bid_tmp,  3);
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
+
+        std::swap(bid_dst, bid_tmp);
+
+        len <<= 1;
+    }
+
+    return 1;
+}
+
 int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
index 332e550ee703be985ad6f63d35bf0e128c869334..b5546146e13d4d613c11fa1a257eae9197fdeb1d 100644 (file)
@@ -81,6 +81,7 @@ int ggml_metal_op_arange            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_argmax            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_argsort           (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_top_k             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_leaky_relu        (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_opt_step_adamw    (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_opt_step_sgd      (ggml_metal_op_t ctx, int idx);
index f6033ddc97bfee7226ec8299d968ca75f1e7a4d5..70bf6f3d981f88f948355f00db6c836083e4712c 100644 (file)
@@ -202,6 +202,10 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
             {
                 res *= 2;
             } break;
+        case GGML_OP_TOP_K:
+            {
+                res = 2*sizeof(int32_t)*ggml_nelements(tensor->src[0]);
+            } break;
         default:
             break;
     }
index 59e5761704e685dee81e404d3fb5c9e78e8eb8f2..73b45c762d94c40e0f0a306ace0df8573e0d9cfe 100644 (file)
@@ -4670,11 +4670,12 @@ kernel void kernel_argsort_f32_i32(
         ushort3   ntg[[threads_per_threadgroup]]) {
     // bitonic sort
     const int col = tpitg[0];
+    const int ib  = tgpig[0] / args.ne01;
 
-    const int i00 = (tgpig[0]/args.ne01)*ntg.x;
-    const int i01 =  tgpig[0]%args.ne01;
-    const int i02 =  tgpig[1];
-    const int i03 =  tgpig[2];
+    const int i00 = ib*ntg.x;
+    const int i01 = tgpig[0] % args.ne01;
+    const int i02 = tgpig[1];
+    const int i03 = tgpig[2];
 
     device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
 
@@ -4710,9 +4711,11 @@ kernel void kernel_argsort_f32_i32(
         }
     }
 
+    const int64_t i0 = ib*args.top_k;
+
     // copy the result to dst without the padding
-    if (i00 + col < args.ne00) {
-        dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
+    if (i0 + col < args.ne0 && col < args.top_k) {
+        dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
 
         dst[col] = shmem_i32[col];
     }
@@ -4747,22 +4750,22 @@ kernel void kernel_argsort_merge_f32_i32(
 
     const int start = im * (2 * args.len);
 
-    const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
-    const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
+    const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
+    const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
 
     const int total = len0 + len1;
 
     device const int32_t * tmp0 = tmp + start
-        + i01*args.ne00
-        + i02*args.ne00*args.ne01
-        + i03*args.ne00*args.ne01*args.ne02;
+        + i01*args.ne0
+        + i02*args.ne0*args.ne01
+        + i03*args.ne0*args.ne01*args.ne02;
 
     device const int32_t * tmp1 = tmp0 + args.len;
 
     dst += start
-        + i01*args.ne00
-        + i02*args.ne00*args.ne01
-        + i03*args.ne00*args.ne01*args.ne02;
+        + i01*args.top_k
+        + i02*args.top_k*args.ne01
+        + i03*args.top_k*args.ne01*args.ne02;
 
     device const float * src0_row = (device const float *)(src0
         + args.nb01*i01
@@ -4776,7 +4779,11 @@ kernel void kernel_argsort_merge_f32_i32(
     const int chunk = (total + ntg.x - 1) / ntg.x;
 
     const int k0 = tpitg.x * chunk;
-    const int k1 = min(k0 + chunk, total);
+    const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
+
+    if (k0 >= args.top_k) {
+        return;
+    }
 
     if (k0 >= total) {
         return;
index a5846a23937ce73dae565d21452e6aa8ecf4ac4e..b99345a2e93b024a2296bd7c29d7a7affbaf0e8c 100644 (file)
@@ -990,6 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ARANGE",
     "TIMESTEP_EMBEDDING",
     "ARGSORT",
+    "TOP_K",
     "LEAKY_RELU",
     "TRI",
     "FILL",
@@ -1023,7 +1024,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
+static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1098,6 +1099,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "arange(start, stop, step)",
     "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
+    "top_k(x)",
     "leaky_relu(x)",
     "tri(x)",
     "fill(x, c)",
@@ -1131,7 +1133,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
+static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -5036,28 +5038,6 @@ struct ggml_tensor * ggml_roll(
     return result;
 }
 
-// ggml_arange
-
-struct ggml_tensor * ggml_arange(
-        struct ggml_context * ctx,
-        float                 start,
-        float                 stop,
-        float                 step) {
-    GGML_ASSERT(stop > start);
-
-    const int64_t steps = (int64_t) ceilf((stop - start) / step);
-
-    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
-
-    ggml_set_op_params_f32(result, 0, start);
-    ggml_set_op_params_f32(result, 1, stop);
-    ggml_set_op_params_f32(result, 2, step);
-
-    result->op = GGML_OP_ARANGE;
-
-    return result;
-}
-
 // ggml_timestep_embedding
 
 struct ggml_tensor * ggml_timestep_embedding(
@@ -5139,6 +5119,7 @@ struct ggml_tensor * ggml_argsort(
         struct ggml_tensor   * a,
         enum ggml_sort_order   order) {
     GGML_ASSERT(a->ne[0] <= INT32_MAX);
+
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
 
     ggml_set_op_params_i32(result, 0, (int32_t) order);
@@ -5149,9 +5130,9 @@ struct ggml_tensor * ggml_argsort(
     return result;
 }
 
-// ggml_top_k
+// ggml_argsort_top_k
 
-struct ggml_tensor * ggml_top_k(
+struct ggml_tensor * ggml_argsort_top_k(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         int                   k) {
@@ -5167,6 +5148,44 @@ struct ggml_tensor * ggml_top_k(
     return result;
 }
 
+// ggml_top_k
+
+struct ggml_tensor * ggml_top_k(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   k) {
+    GGML_ASSERT(a->ne[0] >= k);
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]);
+
+    result->op     = GGML_OP_TOP_K;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_arange
+
+struct ggml_tensor * ggml_arange(
+        struct ggml_context * ctx,
+        float                 start,
+        float                 stop,
+        float                 step) {
+    GGML_ASSERT(stop > start);
+
+    const int64_t steps = (int64_t) ceilf((stop - start) / step);
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
+
+    ggml_set_op_params_f32(result, 0, start);
+    ggml_set_op_params_f32(result, 1, stop);
+    ggml_set_op_params_f32(result, 2, step);
+
+    result->op = GGML_OP_ARANGE;
+
+    return result;
+}
+
 // ggml_flash_attn_ext
 
 struct ggml_tensor * ggml_flash_attn_ext(