]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
bench : pass memcpy threads from cli
authorGeorgi Gerganov <redacted>
Tue, 21 Nov 2023 20:27:22 +0000 (22:27 +0200)
committerGeorgi Gerganov <redacted>
Tue, 21 Nov 2023 20:27:22 +0000 (22:27 +0200)
whisper.cpp

index 03001902b07b8a37dbec8cd43e602ab4f93c57e1..2727bada1855152aeec91cecad8f5aa39ebe914e 100644 (file)
@@ -6138,7 +6138,7 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
 
     // multi-thread
 
-    for (uint32_t n_threads = 1; n_threads <= std::thread::hardware_concurrency(); n_threads++) {
+    for (uint32_t k = 1; k <= n_threads; k++) {
         char * src = (char *) malloc(size);
         char * dst = (char *) malloc(size);
 
@@ -6149,8 +6149,8 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
         double tsum = 0.0;
 
         auto helper = [&](int th) {
-            const int64_t i0 = (th + 0)*size/n_threads;
-            const int64_t i1 = (th + 1)*size/n_threads;
+            const int64_t i0 = (th + 0)*size/k;
+            const int64_t i1 = (th + 1)*size/k;
 
             for (size_t i = 0; i < n; i++) {
                 memcpy(dst + i0, src + i0, i1 - i0);
@@ -6161,14 +6161,14 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
 
         const int64_t t0 = ggml_time_us();
 
-        std::vector<std::thread> threads(n_threads - 1);
-        for (uint32_t th = 0; th < n_threads - 1; ++th) {
+        std::vector<std::thread> threads(k - 1);
+        for (uint32_t th = 0; th < k - 1; ++th) {
             threads[th] = std::thread(helper, th);
         }
 
-        helper(n_threads - 1);
+        helper(k - 1);
 
-        for (uint32_t th = 0; th < n_threads - 1; ++th) {
+        for (uint32_t th = 0; th < k - 1; ++th) {
             threads[th].join();
         }
 
@@ -6176,7 +6176,7 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
 
         tsum += (t1 - t0)*1e-6;
 
-        snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (%2d thread)\n", (double) (n*size)/(tsum*1e9), n_threads);
+        snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (%2d thread)\n", (double) (n*size)/(tsum*1e9), k);
         s += strbuf;
 
         // needed to prevent the compiler from optimizing the memcpy away