]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : reuse ggml_get_n_tasks() in ggml_graph_plan() (#4308)
authorGeorgi Gerganov <redacted>
Sun, 3 Dec 2023 13:56:35 +0000 (15:56 +0200)
committerGitHub <redacted>
Sun, 3 Dec 2023 13:56:35 +0000 (15:56 +0200)
* ggml : fix soft max out-of-bounds access

ggml-ci

* ggml : reuse ggml_get_n_tasks() in ggml_graph_plan()

ggml-ci

ggml.c

diff --git a/ggml.c b/ggml.c
index cecb12700ea858bb3feac2121d34feabb0f51260..f743df1f3709a0891be3a57046ab94958a5ff84a 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -15879,18 +15879,16 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
 
     // thread scheduling for the different operations + work buffer size estimation
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        int n_tasks = 1;
-
         struct ggml_tensor * node = cgraph->nodes[i];
 
+        const int n_tasks = ggml_get_n_tasks(node, n_threads);
+
         size_t cur = 0;
 
         switch (node->op) {
             case GGML_OP_CPY:
             case GGML_OP_DUP:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                     }
@@ -15898,16 +15896,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
             case GGML_OP_ADD:
             case GGML_OP_ADD1:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
                 } break;
             case GGML_OP_ACC:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
                     }
@@ -15935,16 +15929,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_OUT_PROD:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
                 } break;
             case GGML_OP_SOFT_MAX:
                 {
-                    n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
-
                     cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                 } break;
             case GGML_OP_CONV_TRANSPOSE_1D:
@@ -15974,7 +15964,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_IM2COL:
                 {
-                    n_tasks = n_threads;
                 } break;
             case GGML_OP_CONV_TRANSPOSE_2D:
                 {
@@ -15992,8 +15981,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_FLASH_ATTN:
                 {
-                    n_tasks = n_threads;
-
                     const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
 
                     if (node->src[1]->type == GGML_TYPE_F32) {
@@ -16006,8 +15993,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_FLASH_FF:
                 {
-                    n_tasks = n_threads;
-
                     if (node->src[1]->type == GGML_TYPE_F32) {
                         cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
                         cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
@@ -16018,8 +16003,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_FLASH_ATTN_BACK:
                 {
-                    n_tasks = n_threads;
-
                     const int64_t    D = node->src[0]->ne[0];
                     const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
                     const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
@@ -16034,8 +16017,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
 
             case GGML_OP_CROSS_ENTROPY_LOSS:
                 {
-                    n_tasks = n_threads;
-
                     cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
                 } break;
             case GGML_OP_COUNT: