]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : avoid powf() calls in ggml_rope()
authorGeorgi Gerganov <redacted>
Fri, 14 Apr 2023 10:32:27 +0000 (13:32 +0300)
committerGeorgi Gerganov <redacted>
Fri, 14 Apr 2023 10:32:36 +0000 (13:32 +0300)
src/ggml.c

index 15a37108e8d0d479e80876ac61d709f7a73c8418..d99aca21a864dadd2e428ba0e5bc565ae3c893f5 100644 (file)
@@ -7509,6 +7509,8 @@ static void ggml_compute_forward_rope_f32(
     // row index used to determine which thread to use
     int ir = 0;
 
+    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
             const int p = (mode == 0 ? n_past + i2 : i2);
@@ -7516,11 +7518,13 @@ static void ggml_compute_forward_rope_f32(
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
+                float theta = (float)p;
+
                 for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const float theta = powf(10000.0, ((float)-i0)/n_dims);
+                    const float cos_theta = cosf(theta);
+                    const float sin_theta = sinf(theta);
 
-                    const float cos_theta = cosf(p*theta);
-                    const float sin_theta = sinf(p*theta);
+                    theta *= theta_scale;
 
                     const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
                           float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -7582,6 +7586,8 @@ static void ggml_compute_forward_rope_f16(
     // row index used to determine which thread to use
     int ir = 0;
 
+    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
             const int p = (mode == 0 ? n_past + i2 : i2);
@@ -7589,11 +7595,13 @@ static void ggml_compute_forward_rope_f16(
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
+                float theta = (float)p;
+
                 for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const float theta = powf(10000.0, ((float)-i0)/n_dims);
+                    const float cos_theta = cosf(theta);
+                    const float sin_theta = sinf(theta);
 
-                    const float cos_theta = cosf(p*theta);
-                    const float sin_theta = sinf(p*theta);
+                    theta *= theta_scale;
 
                     const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
                           ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);