]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : fix running tasks with variable number of threads
authorGeorgi Gerganov <redacted>
Sat, 7 Jan 2023 17:20:18 +0000 (19:20 +0200)
committerGeorgi Gerganov <redacted>
Sat, 7 Jan 2023 17:20:18 +0000 (19:20 +0200)
ggml.c

diff --git a/ggml.c b/ggml.c
index e627164a80bd2c48570cd98f7af00fc3cca8b894..058241e7f9dcdba86b4e1845367a8485c5dc5587 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -4745,7 +4745,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         // TODO: do not support transposed src1
         assert(nb10/2 == sizeof(ggml_fp16_t));
 
-        // parallelize by src0 rows using ggml_vec_dot_f32
+        // parallelize by src0 rows using ggml_vec_dot_f16
 
         // total rows in src0
         const int nr = ne01*ne02*ne03;
@@ -4773,7 +4773,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
             const int i3 = i03;
 
             ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-            ggml_fp16_t * src1_col = wdata + (i13*ne12*ne11 + i12*ne11 + 0)*ne00;
+            ggml_fp16_t * src1_col =                                wdata + (       0 + i12*ne11 + i13*ne12*ne11)*ne00;
 
             float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
 
@@ -7142,7 +7142,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         }
 
         if (state->node) {
-            ggml_compute_forward(&state->params, state->node);
+            if (state->params.ith < state->params.nth) {
+                ggml_compute_forward(&state->params, state->node);
+            }
             state->node = NULL;
         } else {
             break;
@@ -7236,9 +7238,15 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     } break;
                 case GGML_OP_MUL_MAT:
                     {
-                        // TODO: use different scheduling for different matrix sizes
                         node->n_tasks = n_threads;
 
+                        // TODO: use different scheduling for different matrix sizes
+                        //const int nr0 = ggml_nrows(node->src0);
+                        //const int nr1 = ggml_nrows(node->src1);
+
+                        //node->n_tasks = MIN(n_threads, MAX(1, nr0/128));
+                        //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks);
+
                         size_t cur = 0;
 
                         // TODO: better way to determine if the matrix is transposed
@@ -7422,7 +7430,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 workers[j].params = (struct ggml_compute_params) {
                     .type  = GGML_TASK_COMPUTE,
                     .ith   = j + 1,
-                    .nth   = n_threads,
+                    .nth   = node->n_tasks,
                     .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
                     .wdata = cgraph->work ? cgraph->work->data : NULL,
                 };
@@ -7477,7 +7485,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 workers[j].params = (struct ggml_compute_params) {
                     .type  = GGML_TASK_FINALIZE,
                     .ith   = j + 1,
-                    .nth   = n_threads,
+                    .nth   = node->n_tasks,
                     .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
                     .wdata = cgraph->work ? cgraph->work->data : NULL,
                 };