From: Diego Devesa Date: Thu, 13 Nov 2025 08:59:05 +0000 (-0800) Subject: ggml-cpu : use template for argsort (llama/17222) X-Git-Tag: upstream/0.9.4.395~186 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=993832081c0fc7e12d8e586c8563619b2c226217;p=pkg%2Fggml%2Fsources%2Fggml ggml-cpu : use template for argsort (llama/17222) --- diff --git a/src/ggml-cpu/ops.cpp b/src/ggml-cpu/ops.cpp index 9f1e5f8d..09f53b47 100644 --- a/src/ggml-cpu/ops.cpp +++ b/src/ggml-cpu/ops.cpp @@ -7665,6 +7665,18 @@ void ggml_compute_forward_timestep_embedding( // ggml_compute_forward_argsort +template +struct argsort_cmp { + const float * data; + bool operator()(int32_t a, int32_t b) const { + if constexpr (order == GGML_SORT_ORDER_ASC) { + return data[a] < data[b]; + } else { + return data[a] > data[b]; + } + } +}; + static void ggml_compute_forward_argsort_f32( const ggml_compute_params * params, ggml_tensor * dst) { @@ -7691,16 +7703,18 @@ static void ggml_compute_forward_argsort_f32( dst_data[j] = j; } - std::function cmp; - - // note: this might be causing memory allocations? ideally should be avoided if it's the case switch (order) { - case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break; - case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break; - default: GGML_ABORT("invalid sort order"); - } + case GGML_SORT_ORDER_ASC: + std::sort(dst_data, dst_data + ne0, argsort_cmp{src_data}); + break; - std::sort(dst_data, dst_data + ne0, cmp); + case GGML_SORT_ORDER_DESC: + std::sort(dst_data, dst_data + ne0, argsort_cmp{src_data}); + break; + + default: + GGML_ABORT("invalid sort order"); + } } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 92c17ac4..38b7ddf2 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7631,6 +7631,8 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_sum(GGML_TYPE_F32, it)); } + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1})); + return test_cases; }