From: Diego Devesa Date: Thu, 13 Nov 2025 08:59:05 +0000 (-0800) Subject: ggml-cpu : use template for argsort (llama/17222) X-Git-Tag: upstream/1.8.3~314 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=6a91780c3b112a75705ad111de2e32c634c3afba;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp ggml-cpu : use template for argsort (llama/17222) --- diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 9f1e5f8d..09f53b47 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7665,6 +7665,18 @@ void ggml_compute_forward_timestep_embedding( // ggml_compute_forward_argsort +template +struct argsort_cmp { + const float * data; + bool operator()(int32_t a, int32_t b) const { + if constexpr (order == GGML_SORT_ORDER_ASC) { + return data[a] < data[b]; + } else { + return data[a] > data[b]; + } + } +}; + static void ggml_compute_forward_argsort_f32( const ggml_compute_params * params, ggml_tensor * dst) { @@ -7691,16 +7703,18 @@ static void ggml_compute_forward_argsort_f32( dst_data[j] = j; } - std::function cmp; - - // note: this might be causing memory allocations? ideally should be avoided if it's the case switch (order) { - case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break; - case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break; - default: GGML_ABORT("invalid sort order"); - } + case GGML_SORT_ORDER_ASC: + std::sort(dst_data, dst_data + ne0, argsort_cmp{src_data}); + break; - std::sort(dst_data, dst_data + ne0, cmp); + case GGML_SORT_ORDER_DESC: + std::sort(dst_data, dst_data + ne0, argsort_cmp{src_data}); + break; + + default: + GGML_ABORT("invalid sort order"); + } } }