From: Georgi Gerganov Date: Wed, 12 Nov 2025 18:43:38 +0000 (+0200) Subject: ggml : use std::sort in ggml_argsort CPU implementation (#17211) X-Git-Tag: upstream/0.0.7446~407 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=374fe09cdd4a9d7ebaec6fde87ac2b9b75f019c4;p=pkg%2Fggml%2Fsources%2Fllama.cpp ggml : use std::sort in ggml_argsort CPU implementation (#17211) * ggml : use std::sort in ggml_argsort CPU implementation * cont : add missing header --- diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 5a272b9a..9f1e5f8d 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7,8 +7,9 @@ #include "unary-ops.h" #include "vec.h" -#include +#include #include +#include // ggml_compute_forward_dup @@ -7682,24 +7683,24 @@ static void ggml_compute_forward_argsort_f32( ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0); for (int64_t i = ith; i < nr; i += nth) { - int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); const float * src_data = (float *)((char *) src0->data + i*nb01); + int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); + for (int64_t j = 0; j < ne0; j++) { dst_data[j] = j; } - // C doesn't have a functional sort, so we do a bubble sort instead - for (int64_t j = 0; j < ne0; j++) { - for (int64_t k = j + 1; k < ne0; k++) { - if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) || - (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) { - int32_t tmp = dst_data[j]; - dst_data[j] = dst_data[k]; - dst_data[k] = tmp; - } - } + std::function cmp; + + // note: this might be causing memory allocations? ideally should be avoided if it's the case + switch (order) { + case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break; + case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break; + default: GGML_ABORT("invalid sort order"); } + + std::sort(dst_data, dst_data + ne0, cmp); } }