]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml-cpu : use template for argsort (llama/17222)
authorDiego Devesa <redacted>
Thu, 13 Nov 2025 08:59:05 +0000 (00:59 -0800)
committerGeorgi Gerganov <redacted>
Mon, 17 Nov 2025 19:05:46 +0000 (21:05 +0200)
ggml/src/ggml-cpu/ops.cpp

index 9f1e5f8d6463216849d9e55bb4e5120c1ae4195d..09f53b470b26af8b20715c5f6fb3a059c980cf5e 100644 (file)
@@ -7665,6 +7665,18 @@ void ggml_compute_forward_timestep_embedding(
 
 // ggml_compute_forward_argsort
 
+template<enum ggml_sort_order order>
+struct argsort_cmp {
+    const float * data;
+    bool operator()(int32_t a, int32_t b) const {
+        if constexpr (order == GGML_SORT_ORDER_ASC) {
+            return data[a] < data[b];
+        } else {
+            return data[a] > data[b];
+        }
+    }
+};
+
 static void ggml_compute_forward_argsort_f32(
     const ggml_compute_params * params,
     ggml_tensor * dst) {
@@ -7691,16 +7703,18 @@ static void ggml_compute_forward_argsort_f32(
             dst_data[j] = j;
         }
 
-        std::function<bool(int32_t, int32_t)> cmp;
-
-        // note: this might be causing memory allocations? ideally should be avoided if it's the case
         switch (order) {
-            case GGML_SORT_ORDER_ASC:  cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break;
-            case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break;
-            default: GGML_ABORT("invalid sort order");
-        }
+            case GGML_SORT_ORDER_ASC:
+                std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
+                break;
 
-        std::sort(dst_data, dst_data + ne0, cmp);
+            case GGML_SORT_ORDER_DESC:
+                std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
+                break;
+
+            default:
+                GGML_ABORT("invalid sort order");
+        }
     }
 }