git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : use a simple std::thread in AMX without OpenMP (llama/20074)
authorAdrien Gallouët <redacted>
Wed, 4 Mar 2026 10:57:09 +0000 (11:57 +0100)
committerGeorgi Gerganov <redacted>
Mon, 16 Mar 2026 11:10:15 +0000 (13:10 +0200)
Disabling OpenMP generally provides better inference performance (at
least in my testing), but loading becomes slightly slower.

Benchmark results for `convert_B_packed_format()`:

Before this commit:

         N      K |  No OpenMP     OpenMP |    Diff |  Speedup
    ------------------------------------------------------------
       512   2880 |    640.9us    263.5us |  -58.9% |    0.41x
      2880   4096 |     2.55ms    261.7us |  -89.8% |    0.10x
    201088   2880 |   256.44ms    21.61ms |  -91.6% |    0.08x
    ------------------------------------------------------------

    Total: 325.43ms vs 31.05ms

After:

         N      K |  No OpenMP     OpenMP |    Diff |  Speedup
    ------------------------------------------------------------
       512   2880 |     1.49ms    263.5us |  -82.3% |    0.18x
      2880   4096 |     1.55ms    261.7us |  -83.1% |    0.17x
    201088   2880 |    24.03ms    21.61ms |  -10.1% |    0.90x
    ------------------------------------------------------------

    Total: 78.97ms vs 31.05ms

Tested with unsloth/gpt-oss-20b-GGUF:Q4_K_M.

Signed-off-by: Adrien Gallouët <redacted>
ggml/src/ggml-cpu/amx/common.h

index f392e898518a7ed03931c92eb4acca3aff527946..26a6ec1a2d0092fa24ca8b9d937ebf1fed183976 100644 (file)
@@ -9,6 +9,8 @@
 
 #if defined(GGML_USE_OPENMP)
 #include <omp.h>
+#else
+#include <thread>
 #endif
 
 #define TILE_M 16
@@ -56,18 +58,40 @@ inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
 }
 
 template <typename func_t>
-inline void parallel_for(int n, const func_t& f) {
+inline void parallel_for(int n, const func_t & f) {
+    if (n <= 0) {
+        return;
+    }
 #if defined(GGML_USE_OPENMP)
-#pragma omp parallel
-{
-    int nth = omp_get_num_threads();
-    int ith = omp_get_thread_num();
-    int tbegin, tend;
-    balance211(n, nth, ith, tbegin, tend);
-    f(tbegin, tend);
-}
+    #pragma omp parallel
+    {
+        int nth = omp_get_num_threads();
+        int ith = omp_get_thread_num();
+        int tbegin, tend;
+        balance211(n, nth, ith, tbegin, tend);
+        f(tbegin, tend);
+    }
 #else
-    f(0, n);
+    int nth = std::thread::hardware_concurrency();
+    if (nth <= 1) {
+        f(0, n);
+        return;
+    }
+    if (nth > n) {
+        nth = n;
+    }
+    std::vector<std::thread> threads;
+    threads.reserve(nth);
+    for (int ith = 0; ith < nth; ++ith) {
+        threads.emplace_back([&f, n, ith, nth] {
+            int tbegin, tend;
+            balance211(n, nth, ith, tbegin, tend);
+            f(tbegin, tend);
+        });
+    }
+    for (auto & t : threads) {
+        t.join();
+    }
 #endif
 }