]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : fix bench regression + fix performance when using CPU BLAS (#1275)
authorGeorgi Gerganov <redacted>
Tue, 12 Sep 2023 10:54:04 +0000 (13:54 +0300)
committerGitHub <redacted>
Tue, 12 Sep 2023 10:54:04 +0000 (13:54 +0300)
* whisper : fix bench regression

* ggml : use sched_yield when using BLAS + add comment

ggml.c
whisper.cpp

diff --git a/ggml.c b/ggml.c
index 3f72379c3553e27bc3f685f8756f163a4aa5f860..dcdebd24cb8c23e94a1b17963fcc70c8f371b01f 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -17283,10 +17283,18 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         } else {
             // wait for other threads to finish
             const int last = node_n;
-            do {
-                //sched_yield();
+            while (true) {
+                // TODO: this sched_yield can have significant impact on the performance - either positive or negative
+                //       depending on the workload and the operating system.
+                //       since it is not clear what is the best approach, it should potentially become user-configurable
+                //       ref: https://github.com/ggerganov/ggml/issues/291
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+                sched_yield();
+#endif
+
                 node_n = atomic_load(&state->shared->node_n);
-            } while (node_n == last);
+                if (node_n != last) break;
+            };
         }
 
         // check if we should stop
index 5c14b43efdcc9bc2796891cf22eb43604d60c648..f5a9a71f2d3ff32ca41e9410b08f2f8644fc742a 100644 (file)
@@ -118,6 +118,21 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 #define WHISPER_USE_SCRATCH
 #define WHISPER_MAX_SCRATCH_BUFFERS 16
 
+//
+// ggml helpers
+//
+
+static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+
+    if (plan.work_size > 0) {
+        buf.resize(plan.work_size);
+        plan.work_data = buf.data();
+    }
+
+    ggml_graph_compute(graph, &plan);
+}
+
 // available whisper models
 enum e_model {
     MODEL_UNKNOWN,
@@ -666,6 +681,7 @@ struct whisper_state {
 
     // memory buffers used by encode / decode contexts
     std::vector<uint8_t> buf_compute;
+    std::vector<uint8_t> buf_work;
     std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
 
     int    buf_last = 0;
@@ -1830,8 +1846,8 @@ static bool whisper_encode_internal(
         {
             struct ggml_cgraph gf = {};
 
-            ggml_build_forward_expand  (&gf, cur);
-            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+            ggml_build_forward_expand(&gf, cur);
+            ggml_graph_compute_helper(wstate.buf_work, &gf, n_threads);
 
             //ggml_graph_print(&gf);
         }
@@ -1916,7 +1932,7 @@ static bool whisper_encode_internal(
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
         }
 
-        ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+        ggml_graph_compute_helper(wstate.buf_work, &gf, n_threads);
         //ggml_graph_print(&gf);
     }
 
@@ -2329,8 +2345,8 @@ static bool whisper_decode_internal(
 
     // run the computation
     {
-        ggml_build_forward_expand  (&gf, logits);
-        ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+        ggml_build_forward_expand(&gf, logits);
+        ggml_graph_compute_helper(wstate.buf_work, &gf, n_threads);
     }
 
     // extract logits for all N tokens
@@ -5225,7 +5241,8 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
     // b: N*N*sizeof(float)
     // c: N*N*sizeof(float)
     // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
-    std::vector<char> buf(4llu*N_max*N_max*sizeof(float) + 4*512);
+    std::vector<uint8_t> buf (3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead());
+    std::vector<uint8_t> work(1llu*N_max*N_max*sizeof(float) + 1*ggml_tensor_overhead());
 
     // put a bunch of random data in the buffer
     for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
@@ -5280,12 +5297,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
             double tsum = 0.0;
 
             // heat-up
-            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+            ggml_graph_compute_helper(work, &gf, n_threads);
 
             for (int i = 0; i < n_max; ++i) {
                 const int64_t t0 = ggml_time_us();
 
-                ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+                ggml_graph_compute_helper(work, &gf, n_threads);
 
                 const int64_t t1 = ggml_time_us();