]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : limit n_threads to the max n_tasks (llama/5238)
authorslaren <redacted>
Wed, 31 Jan 2024 12:43:03 +0000 (13:43 +0100)
committerGeorgi Gerganov <redacted>
Sat, 10 Feb 2024 07:55:46 +0000 (09:55 +0200)
ggml.c

diff --git a/ggml.c b/ggml.c
index f6e797d78f244c16dfc4df268f08c0f4f3d62af5..1286ea8e82d487b9df80795414090980b8f380c8 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -16985,12 +16985,16 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
     struct ggml_cplan cplan;
     memset(&cplan, 0, sizeof(struct ggml_cplan));
 
+    int max_tasks = 1;
+
     // thread scheduling for the different operations + work buffer size estimation
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * node = cgraph->nodes[i];
 
         const int n_tasks = ggml_get_n_tasks(node, n_threads);
 
+        max_tasks = MAX(max_tasks, n_tasks);
+
         size_t cur = 0;
 
         switch (node->op) {
@@ -17157,7 +17161,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
         work_size += CACHE_LINE_SIZE*(n_threads - 1);
     }
 
-    cplan.n_threads = n_threads;
+    cplan.n_threads = MIN(max_tasks, n_threads);
     cplan.work_size = work_size;
     cplan.work_data = NULL;