]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml: cache sin/cos for RoPE (llama/4908)
authorJohannes Gäßler <redacted>
Sat, 13 Jan 2024 20:41:37 +0000 (21:41 +0100)
committerGeorgi Gerganov <redacted>
Sat, 13 Jan 2024 22:11:45 +0000 (00:11 +0200)
ggml.c

diff --git a/ggml.c b/ggml.c
index de6ef34bdde59cb216ffdd1b524c9eaf323fb0b6..bcfb6652c10326f5be23f0b59cb3ce8c80d72168 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -11638,6 +11638,21 @@ static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, fl
     return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
 }
 
+static void ggml_rope_cache_init(
+     float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+     float * cache, float sin_sign, float theta_scale
+) {
+    float theta = theta_base;
+    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+        rope_yarn(
+            theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+        );
+        cache[i0 + 1] *= sin_sign;
+
+        theta *= theta_scale;
+    }
+}
+
 void ggml_rope_yarn_corr_dims(
     int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
 ) {
@@ -11720,6 +11735,12 @@ static void ggml_compute_forward_rope_f32(
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = 0; i2 < ne2; i2++) {
             const int64_t p = pos[i2];
+
+            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
+            if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
+                ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+            }
+
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
@@ -11753,18 +11774,13 @@ static void ggml_compute_forward_rope_f32(
                     }
                 } else if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        float cos_theta, sin_theta;
-                        rope_yarn(
-                            theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
-                        );
-                        sin_theta *= sin_sign;
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
 
                         // zeta scaling for xPos only:
                         float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
                         if (xpos_down) zeta = 1.0f / zeta;
 
-                        theta_base *= theta_scale;
-
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                               float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
@@ -11888,6 +11904,12 @@ static void ggml_compute_forward_rope_f16(
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = 0; i2 < ne2; i2++) {
             const int64_t p = pos[i2];
+
+            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
+            if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
+                ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+            }
+
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
@@ -11921,13 +11943,8 @@ static void ggml_compute_forward_rope_f16(
                     }
                 } else if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        float cos_theta, sin_theta;
-                        rope_yarn(
-                            theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
-                        );
-                        sin_theta *= sin_sign;
-
-                        theta_base *= theta_scale;
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
 
                         const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                               ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
@@ -16722,6 +16739,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     }
                 } break;
             case GGML_OP_SOFT_MAX:
+            case GGML_OP_ROPE:
                 {
                     cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                 } break;