]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : fix rope calculation (!inplace + GPT-NeoX mode)
authorGeorgi Gerganov <redacted>
Sun, 14 May 2023 11:16:47 +0000 (14:16 +0300)
committerGeorgi Gerganov <redacted>
Sun, 14 May 2023 12:18:34 +0000 (15:18 +0300)
examples/gpt-j/main.cpp
src/ggml.c

index 5362407bee749d1394205a1c9077b389615e6f27..7c9197a991c40b36aac80b765cc921c3680b2a0f 100644 (file)
@@ -582,7 +582,7 @@ bool gptj_eval(
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print   (&gf);
-    //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
+    //    ggml_graph_dump_dot(&gf, NULL, "gpt-j.dot");
     //}
 
     //embd_w.resize(n_vocab*N);
index 7773cebc110325b47aabf82b31ebb54ab3ed7bb8..63da6799b8ff2461fe37ab8e1df244808ec7effe 100644 (file)
@@ -10710,28 +10710,32 @@ static void ggml_compute_forward_rope_f32(
 
     assert(n_past >= 0);
 
-    //const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
-    const int64_t ne3 = src0->ne[3];
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
 
-    const int nb0 = src0->nb[0];
-    const int nb1 = src0->nb[1];
-    const int nb2 = src0->nb[2];
-    const int nb3 = src0->nb[3];
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
 
-    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nr = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(dst);
 
-    GGML_ASSERT(n_dims <= nc);
+    GGML_ASSERT(n_dims <= ne0);
     GGML_ASSERT(n_dims % 2 == 0);
 
     // rows per thread
@@ -10750,37 +10754,50 @@ static void ggml_compute_forward_rope_f32(
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
+            const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
                 float theta = (float)p;
 
-                for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const float cos_theta = cosf(theta);
-                    const float sin_theta = sinf(theta);
+                if (!is_neox) {
+                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+                        const float cos_theta = cosf(theta);
+                        const float sin_theta = sinf(theta);
 
-                    theta *= theta_scale;
+                        theta *= theta_scale;
 
-                    if (!is_neox) {
-                        const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              float * dst_data  = (float *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
 
                         const float x0 = src[0];
                         const float x1 = src[1];
 
                         dst_data[0] = x0*cos_theta - x1*sin_theta;
                         dst_data[1] = x0*sin_theta + x1*cos_theta;
-                    } else {
-                        const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
-                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                    }
+                } else {
+                    // TODO: this is probably wrong, but I can't figure it out ..
+                    // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                            const float cos_theta = cosf(theta);
+                            const float sin_theta = sinf(theta);
 
-                        const float x0 = src[0];
-                        const float x1 = src[n_dims/2];
+                            theta *= theta_scale;
+
+                            const int64_t i0 = ib*n_dims + ic/2;
+
+                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                                  float * dst_data  = (float *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
+
+                            const float x0 = src[0];
+                            const float x1 = src[n_dims/2];
 
-                        dst_data[0]        = x0*cos_theta - x1*sin_theta;
-                        dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                            dst_data[0]        = x0*cos_theta - x1*sin_theta;
+                            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                        }
                     }
                 }
             }
@@ -10806,15 +10823,20 @@ static void ggml_compute_forward_rope_f16(
 
     assert(n_past >= 0);
 
-    //const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
-    const int64_t ne3 = src0->ne[3];
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
 
-    const int nb0 = src0->nb[0];
-    const int nb1 = src0->nb[1];
-    const int nb2 = src0->nb[2];
-    const int nb3 = src0->nb[3];
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -10824,10 +10846,9 @@ static void ggml_compute_forward_rope_f16(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nr = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(dst);
 
-    GGML_ASSERT(n_dims <= nc);
+    GGML_ASSERT(n_dims <= ne0);
     GGML_ASSERT(n_dims % 2 == 0);
 
     // rows per thread
@@ -10846,37 +10867,50 @@ static void ggml_compute_forward_rope_f16(
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
+            const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
                 float theta = (float)p;
 
-                for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const float cos_theta = cosf(theta);
-                    const float sin_theta = sinf(theta);
+                if (!is_neox) {
+                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+                        const float cos_theta = cosf(theta);
+                        const float sin_theta = sinf(theta);
 
-                    theta *= theta_scale;
+                        theta *= theta_scale;
 
-                    if (!is_neox) {
-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
                         const float x0 = GGML_FP16_TO_FP32(src[0]);
                         const float x1 = GGML_FP16_TO_FP32(src[1]);
 
                         dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
                         dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    } else {
-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
-                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                    }
+                } else {
+                    // TODO: this is probably wrong, but I can't figure it out ..
+                    // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                            const float cos_theta = cosf(theta);
+                            const float sin_theta = sinf(theta);
 
-                        const float x0 = GGML_FP16_TO_FP32(src[0]);
-                        const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+                            theta *= theta_scale;
+
+                            const int64_t i0 = ib*n_dims + ic/2;
+
+                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                                  ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                            const float x0 = GGML_FP16_TO_FP32(src[0]);
+                            const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
 
-                        dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                            dst_data[0]     = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                            dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                        }
                     }
                 }
             }
@@ -10929,15 +10963,21 @@ static void ggml_compute_forward_rope_back_f32(
 
     assert(n_past >= 0);
 
-    //const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
-    const int64_t ne3 = src0->ne[3];
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
 
-    const int nb0 = src0->nb[0];
-    const int nb1 = src0->nb[1];
-    const int nb2 = src0->nb[2];
-    const int nb3 = src0->nb[3];
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -10947,7 +10987,7 @@ static void ggml_compute_forward_rope_back_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nr = ggml_nrows(src0);
+    const int nr = ggml_nrows(dst);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -10965,37 +11005,48 @@ static void ggml_compute_forward_rope_back_f32(
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
+            const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
                 float theta = (float)p;
 
-                for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const float cos_theta = cosf(theta);
-                    const float sin_theta = sinf(theta);
+                if (!is_neox) {
+                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+                        const float cos_theta = cosf(theta);
+                        const float sin_theta = sinf(theta);
 
-                    theta *= theta_scale;
+                        theta *= theta_scale;
 
-                    if (!is_neox) {
-                        const float * const dy  = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-                              float *       dx  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                        const float * const dy  = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              float *       dx  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
                         const float dy0 = dy[0];
                         const float dy1 = dy[1];
 
                         dx[0] =   dy0*cos_theta + dy1*sin_theta;
                         dx[1] = - dy0*sin_theta + dy1*cos_theta;
-                    } else {
-                        const float * const dy  = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
-                              float *       dx  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                    }
+                } else {
+                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                            const float cos_theta = cosf(theta);
+                            const float sin_theta = sinf(theta);
 
-                        const float dy0 = dy[0];
-                        const float dy1 = dy[n_dims/2];
+                            theta *= theta_scale;
+
+                            const int64_t i0 = ib*n_dims + ic/2;
 
-                        dx[0]        =   dy0*cos_theta + dy1*sin_theta;
-                        dx[n_dims/2] = - dy0*sin_theta + dy1*cos_theta;
+                            const float * const dy  = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                                  float *       dx  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                            const float dy0 = dy[0];
+                            const float dy1 = dy[n_dims/2];
+
+                            dx[0]        =   dy0*cos_theta + dy1*sin_theta;
+                            dx[n_dims/2] = - dy0*sin_theta + dy1*cos_theta;
+                        }
                     }
                 }
             }
@@ -11025,15 +11076,21 @@ static void ggml_compute_forward_rope_back_f16(
 
     assert(n_past >= 0);
 
-    //const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
-    const int64_t ne3 = src0->ne[3];
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
 
-    const int nb0 = src0->nb[0];
-    const int nb1 = src0->nb[1];
-    const int nb2 = src0->nb[2];
-    const int nb3 = src0->nb[3];
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -11043,7 +11100,7 @@ static void ggml_compute_forward_rope_back_f16(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nr = ggml_nrows(src0);
+    const int nr = ggml_nrows(dst);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -11061,37 +11118,48 @@ static void ggml_compute_forward_rope_back_f16(
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
+            const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
                 float theta = (float)p;
 
-                for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const float cos_theta = cosf(theta);
-                    const float sin_theta = sinf(theta);
+                if (!is_neox) {
+                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+                        const float cos_theta = cosf(theta);
+                        const float sin_theta = sinf(theta);
 
-                    theta *= theta_scale;
+                        theta *= theta_scale;
 
-                    if (!is_neox) {
-                        const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-                              ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                        const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
                         const float dy0 = GGML_FP16_TO_FP32(dy[0]);
                         const float dy1 = GGML_FP16_TO_FP32(dy[1]);
 
                         dx[0] = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
                         dx[1] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
-                    } else {
-                        const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
-                              ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                    }
+                } else {
+                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                            const float cos_theta = cosf(theta);
+                            const float sin_theta = sinf(theta);
 
-                        const float dy0 = GGML_FP16_TO_FP32(dy[0]);
-                        const float dy1 = GGML_FP16_TO_FP32(dy[n_dims/2]);
+                            theta *= theta_scale;
+
+                            const int64_t i0 = ib*n_dims + ic/2;
 
-                        dx[0]        = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
-                        dx[n_dims/2] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
+                            const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                                  ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                            const float dy0 = GGML_FP16_TO_FP32(dy[0]);
+                            const float dy1 = GGML_FP16_TO_FP32(dy[n_dims/2]);
+
+                            dx[0]        = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
+                            dx[n_dims/2] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
+                        }
                     }
                 }
             }