]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : multi-thread mul and diag_mask ops (#1428)
authorGeorgi Gerganov <redacted>
Sat, 13 May 2023 13:48:03 +0000 (16:48 +0300)
committerGitHub <redacted>
Sat, 13 May 2023 13:48:03 +0000 (16:48 +0300)
ggml.c

diff --git a/ggml.c b/ggml.c
index 05746383974a0b8a1fa64a018e4f8141d9734093..e5b3528d8a742de5f4f4d83fe8230d757ea27ec2 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -7765,12 +7765,13 @@ static void ggml_compute_forward_mul_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
     assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
+    const int ith = params->ith;
+    const int nth = params->nth;
 
     const int nr  = ggml_nrows(src0);
     const int64_t ne0 = src0->ne[0];
@@ -7796,7 +7797,7 @@ static void ggml_compute_forward_mul_f32(
     GGML_ASSERT(nb00 == sizeof(float));
 
     if (nb10 == sizeof(float)) {
-        for (int ir = 0; ir < nr; ++ir) {
+        for (int ir = ith; ir < nr; ir += nth) {
             // src0, src1 and dst are same shape => same indices
             const int i3 = ir/(ne2*ne1);
             const int i2 = (ir - i3*ne2*ne1)/ne1;
@@ -7822,7 +7823,7 @@ static void ggml_compute_forward_mul_f32(
         }
     } else {
         // src1 is not contiguous
-        for (int ir = 0; ir < nr; ++ir) {
+        for (int ir = ith; ir < nr; ir += nth) {
             // src0, src1 and dst are same shape => same indices
             const int i3 = ir/(ne2*ne1);
             const int i2 = (ir - i3*ne2*ne1)/ne1;
@@ -10317,7 +10318,6 @@ static void ggml_compute_forward_diag_mask_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst,
         const float value) {
-    assert(params->ith == 0);
     assert(src1->type == GGML_TYPE_I32);
     assert(ggml_nelements(src1) == 2);
 
@@ -10325,6 +10325,9 @@ static void ggml_compute_forward_diag_mask_f32(
         return;
     }
 
+    const int ith = params->ith;
+    const int nth = params->nth;
+
     const int  n_past  =       ((int32_t *) src1->data)[0];
     const bool inplace = (bool)((int32_t *) src1->data)[1];
 
@@ -10343,7 +10346,7 @@ static void ggml_compute_forward_diag_mask_f32(
     assert(src0->nb[0] == sizeof(float));
 
     for (int k = 0; k < nz; k++) {
-        for (int j = 0; j < nr; j++) {
+        for (int j = ith; j < nr; j += nth) {
             for (int i = n_past; i < nc; i++) {
                 if (i > n_past + j) {
                     *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
@@ -13609,7 +13612,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         work_size = MAX(work_size, cur);
                     } break;
                 case GGML_OP_SUB:
-                case GGML_OP_MUL:
                 case GGML_OP_DIV:
                 case GGML_OP_SQR:
                 case GGML_OP_SQRT:
@@ -13626,18 +13628,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     {
                         node->n_tasks = 1;
                     } break;
+                case GGML_OP_MUL:
                 case GGML_OP_GELU:
-                    {
-                        node->n_tasks = n_threads;
-                    } break;
                 case GGML_OP_SILU:
-                    {
-                        node->n_tasks = n_threads;
-                    } break;
                 case GGML_OP_SILU_BACK:
-                    {
-                        node->n_tasks = n_threads;
-                    } break;
                 case GGML_OP_NORM:
                 case GGML_OP_RMS_NORM:
                 case GGML_OP_RMS_NORM_BACK:
@@ -13715,11 +13709,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 case GGML_OP_GET_ROWS:
                 case GGML_OP_GET_ROWS_BACK:
                 case GGML_OP_DIAG:
-                case GGML_OP_DIAG_MASK_INF:
                 case GGML_OP_DIAG_MASK_ZERO:
                     {
                         node->n_tasks = 1;
                     } break;
+                case GGML_OP_DIAG_MASK_INF:
                 case GGML_OP_SOFT_MAX:
                 case GGML_OP_ROPE:
                 case GGML_OP_ROPE_BACK: