#include "unary-ops.h"
#include "vec.h"
-#include <float.h>
+#include <cfloat>
#include <algorithm>
+#include <functional>
// ggml_compute_forward_dup
ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
for (int64_t i = ith; i < nr; i += nth) {
- int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
const float * src_data = (float *)((char *) src0->data + i*nb01);
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+
for (int64_t j = 0; j < ne0; j++) {
dst_data[j] = j;
}
- // C doesn't have a functional sort, so we do a bubble sort instead
- for (int64_t j = 0; j < ne0; j++) {
- for (int64_t k = j + 1; k < ne0; k++) {
- if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
- (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
- int32_t tmp = dst_data[j];
- dst_data[j] = dst_data[k];
- dst_data[k] = tmp;
- }
- }
+ std::function<bool(int32_t, int32_t)> cmp;
+
+ // note: std::function is type-erased and may heap-allocate when storing the comparator; ideally avoid it if that turns out to be the case
+ switch (order) {
+ case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break;
+ case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break;
+ default: GGML_ABORT("invalid sort order");
}
+
+ std::sort(dst_data, dst_data + ne0, cmp);
}
}