// ggml_compute_forward_argsort
+template<enum ggml_sort_order order>
+struct argsort_cmp {
+ const float * data;
+ bool operator()(int32_t a, int32_t b) const {
+ if constexpr (order == GGML_SORT_ORDER_ASC) {
+ return data[a] < data[b];
+ } else {
+ return data[a] > data[b];
+ }
+ }
+};
+
static void ggml_compute_forward_argsort_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
dst_data[j] = j;
}
- std::function<bool(int32_t, int32_t)> cmp;
-
- // note: this might be causing memory allocations? ideally should be avoided if it's the case
switch (order) {
- case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break;
- case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break;
- default: GGML_ABORT("invalid sort order");
- }
+ case GGML_SORT_ORDER_ASC:
+ std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
+ break;
- std::sort(dst_data, dst_data + ne0, cmp);
+ case GGML_SORT_ORDER_DESC:
+ std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
+ break;
+
+ default:
+ GGML_ABORT("invalid sort order");
+ }
}
}