]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml-cpu : use template for argsort (llama/17222)
authorDiego Devesa <redacted>
Thu, 13 Nov 2025 08:59:05 +0000 (00:59 -0800)
committerGeorgi Gerganov <redacted>
Mon, 17 Nov 2025 13:34:43 +0000 (15:34 +0200)
src/ggml-cpu/ops.cpp
tests/test-backend-ops.cpp

index 9f1e5f8d6463216849d9e55bb4e5120c1ae4195d..09f53b470b26af8b20715c5f6fb3a059c980cf5e 100644 (file)
@@ -7665,6 +7665,18 @@ void ggml_compute_forward_timestep_embedding(
 
 // ggml_compute_forward_argsort
 
+template<enum ggml_sort_order order>
+struct argsort_cmp {
+    const float * data;
+    bool operator()(int32_t a, int32_t b) const {
+        if constexpr (order == GGML_SORT_ORDER_ASC) {
+            return data[a] < data[b];
+        } else {
+            return data[a] > data[b];
+        }
+    }
+};
+
 static void ggml_compute_forward_argsort_f32(
     const ggml_compute_params * params,
     ggml_tensor * dst) {
@@ -7691,16 +7703,18 @@ static void ggml_compute_forward_argsort_f32(
             dst_data[j] = j;
         }
 
-        std::function<bool(int32_t, int32_t)> cmp;
-
-        // note: this might be causing memory allocations? ideally should be avoided if it's the case
         switch (order) {
-            case GGML_SORT_ORDER_ASC:  cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break;
-            case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break;
-            default: GGML_ABORT("invalid sort order");
-        }
+            case GGML_SORT_ORDER_ASC:
+                std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
+                break;
 
-        std::sort(dst_data, dst_data + ne0, cmp);
+            case GGML_SORT_ORDER_DESC:
+                std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
+                break;
+
+            default:
+                GGML_ABORT("invalid sort order");
+        }
     }
 }
 
index 92c17ac4399c0ac62c85c5b5efa959482eb5836c..38b7ddf22178531cd8f301963c41b0a535b611ad 100644 (file)
@@ -7631,6 +7631,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         test_cases.emplace_back(new test_sum(GGML_TYPE_F32, it));
     }
 
+    test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
+
     return test_cases;
 }