]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : multi-thread ggml_rope() (~3-4 times faster on M1) (#781)
authorGeorgi Gerganov <redacted>
Wed, 5 Apr 2023 19:11:03 +0000 (22:11 +0300)
committerGitHub <redacted>
Wed, 5 Apr 2023 19:11:03 +0000 (22:11 +0300)
ggml.c

diff --git a/ggml.c b/ggml.c
index a3a3314c746d16b211ce087e8d84d6376bc6203c..8a60bc3831c6d3f1c3f4731cafff11e5847d94f1 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -7238,7 +7238,6 @@ static void ggml_compute_forward_rope_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
     assert(src1->type == GGML_TYPE_I32);
     assert(ggml_nelements(src1) == 3);
 
@@ -7265,11 +7264,28 @@ static void ggml_compute_forward_rope_f32(
 
     assert(nb0 == sizeof(float));
 
-    // TODO: optimize
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // row index used to determine which thread to use
+    int ir = 0;
+
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
             const int p = (mode == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
+                if (ir++ < ir0) continue;
+                if (ir   > ir1) break;
+
                 for (int i0 = 0; i0 < n_dims; i0 += 2) {
                     const float theta = powf(10000.0, ((float)-i0)/n_dims);
 
@@ -7295,7 +7311,6 @@ static void ggml_compute_forward_rope_f16(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
     assert(src1->type == GGML_TYPE_I32);
     assert(ggml_nelements(src1) == 3);
 
@@ -7322,10 +7337,28 @@ static void ggml_compute_forward_rope_f16(
 
     assert(nb0 == sizeof(ggml_fp16_t));
 
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // row index used to determine which thread to use
+    int ir = 0;
+
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
             const int p = (mode == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
+                if (ir++ < ir0) continue;
+                if (ir   > ir1) break;
+
                 for (int i0 = 0; i0 < n_dims; i0 += 2) {
                     const float theta = powf(10000.0, ((float)-i0)/n_dims);
 
@@ -9424,7 +9457,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     } break;
                 case GGML_OP_ROPE:
                     {
-                        node->n_tasks = 1;
+                        node->n_tasks = n_threads;
                     } break;
                 case GGML_OP_CONV_1D_1S:
                 case GGML_OP_CONV_1D_2S: