]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : fix backward rope after YaRN (#3974)
authorxaedes <redacted>
Tue, 7 Nov 2023 08:04:51 +0000 (09:04 +0100)
committerGitHub <redacted>
Tue, 7 Nov 2023 08:04:51 +0000 (10:04 +0200)
* fix backward process of rope

rope backward process was broken after YaRN RoPE (#2268) implementation, due to missing changes in backward functions.

the code for the backward process is nearly identically to the forward process:
the only difference is the sign of the sin-values.

to avoid future regressions remove the near-duplicate backward functions and reuse the forward code:

for this a new function argument `bool forward` was added to `ggml_compute_forward_rope_f32` and `ggml_compute_forward_rope_f16`.
the sin-values will be negated when forward is false.

* fix finetune rope call to use correct default attn_factor of 1.0f

* remove unused `ggml_rope_xpos_back`

it is better to have only one `ggml_rope_back` function that accepts all rope parameters, so that `ggml_compute_backward` can propagate all parameters without having to switch between different rope_back variants.

* fix comments explaining the sinus sign in ggml_forward_rope

* add missing function arguments in declaration

* fix function argument type in declaration

examples/finetune/finetune.cpp
ggml.c
ggml.h

index 649a3b7c1941e5fb4ab4a40752764b0bcc211b11..fa7dbe496b2c51be70d6a12a7c1fb46dd77c6cbf 100644 (file)
@@ -643,7 +643,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
 
         return ggml_rope_custom(ctx,
             t, KQ_pos, n_rot, rope_mode, n_ctx, 0,
-            rope_freq_base, rope_freq_scale, 0.0f, 0.0f, 0.0f, 0.0f
+            rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
         );
     };
 
diff --git a/ggml.c b/ggml.c
index 605a27940fc81ff92b588ca5263bf4e9d1fb49ff..009d5b3985e55fb7be98c1fdfd39262a7b81a90a 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -4970,8 +4970,13 @@ struct ggml_tensor * ggml_rope_back(
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
+        int                   n_orig_ctx,
         float                 freq_base,
         float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow,
         float                 xpos_base,
         bool                  xpos_down) {
     GGML_ASSERT(ggml_is_vector(b));
@@ -4988,11 +4993,15 @@ struct ggml_tensor * ggml_rope_back(
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
-    int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
-    memcpy(params + 4, &freq_base,  sizeof(float));
-    memcpy(params + 5, &freq_scale, sizeof(float));
-    memcpy(params + 6, &xpos_base,  sizeof(float));
-    memcpy(params + 7, &xpos_down,  sizeof(bool));
+    int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
+    memcpy(params +  5, &freq_base,    sizeof(float));
+    memcpy(params +  6, &freq_scale,   sizeof(float));
+    memcpy(params +  7, &ext_factor,   sizeof(float));
+    memcpy(params +  8, &attn_factor,  sizeof(float));
+    memcpy(params +  9, &beta_fast,    sizeof(float));
+    memcpy(params + 10, &beta_slow,    sizeof(float));
+    memcpy(params + 11, &xpos_base,    sizeof(float));
+    memcpy(params + 12, &xpos_down,    sizeof(bool));
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_ROPE_BACK;
@@ -10974,7 +10983,8 @@ static void ggml_compute_forward_rope_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor * dst) {
+        struct ggml_tensor * dst,
+        const bool forward) {
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
@@ -11033,6 +11043,11 @@ static void ggml_compute_forward_rope_f32(
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    // backward process uses inverse rotation by cos and sin.
+    // cos and sin build a rotation matrix, where the inverse is the transpose.
+    // this essentially just switches the sign of sin.
+    const float sin_sign = forward ? 1.0f : -1.0f;
+
     const int32_t * pos = (const int32_t *) src1->data;
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
@@ -11049,9 +11064,9 @@ static void ggml_compute_forward_rope_f32(
                     float block_theta = MAX(p - (n_ctx - 2), 0);
                     for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
                         const float cos_theta = cosf(theta_base);
-                        const float sin_theta = sinf(theta_base);
+                        const float sin_theta = sinf(theta_base) * sin_sign;
                         const float cos_block_theta = cosf(block_theta);
-                        const float sin_block_theta = sinf(block_theta);
+                        const float sin_block_theta = sinf(block_theta) * sin_sign;
 
                         theta_base *= theta_scale;
                         block_theta *= theta_scale;
@@ -11075,6 +11090,7 @@ static void ggml_compute_forward_rope_f32(
                         rope_yarn(
                             theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
                         );
+                        sin_theta *= sin_sign;
 
                         // zeta scaling for xPos only:
                         float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
@@ -11105,6 +11121,7 @@ static void ggml_compute_forward_rope_f32(
                                 theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
                                 &cos_theta, &sin_theta
                             );
+                            sin_theta *= sin_sign;
 
                             theta_base *= theta_scale;
 
@@ -11130,7 +11147,8 @@ static void ggml_compute_forward_rope_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor * dst) {
+        struct ggml_tensor * dst,
+        const bool forward) {
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
@@ -11182,6 +11200,11 @@ static void ggml_compute_forward_rope_f16(
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    // backward process uses inverse rotation by cos and sin.
+    // cos and sin build a rotation matrix, where the inverse is the transpose.
+    // this essentially just switches the sign of sin.
+    const float sin_sign = forward ? 1.0f : -1.0f;
+
     const int32_t * pos = (const int32_t *) src1->data;
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
@@ -11198,9 +11221,9 @@ static void ggml_compute_forward_rope_f16(
                     float block_theta = MAX(p - (n_ctx - 2), 0);
                     for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
                         const float cos_theta = cosf(theta_base);
-                        const float sin_theta = sinf(theta_base);
+                        const float sin_theta = sinf(theta_base) * sin_sign;
                         const float cos_block_theta = cosf(block_theta);
-                        const float sin_block_theta = sinf(block_theta);
+                        const float sin_block_theta = sinf(block_theta) * sin_sign;
 
                         theta_base *= theta_scale;
                         block_theta *= theta_scale;
@@ -11224,6 +11247,7 @@ static void ggml_compute_forward_rope_f16(
                         rope_yarn(
                             theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
                         );
+                        sin_theta *= sin_sign;
 
                         theta_base *= theta_scale;
 
@@ -11250,6 +11274,7 @@ static void ggml_compute_forward_rope_f16(
                                 theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
                                 &cos_theta, &sin_theta
                             );
+                            sin_theta *= sin_sign;
 
                             theta_base *= theta_scale;
 
@@ -11279,11 +11304,11 @@ static void ggml_compute_forward_rope(
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_rope_f16(params, src0, src1, dst);
+                ggml_compute_forward_rope_f16(params, src0, src1, dst, true);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rope_f32(params, src0, src1, dst);
+                ggml_compute_forward_rope_f32(params, src0, src1, dst, true);
             } break;
         default:
             {
@@ -11294,216 +11319,6 @@ static void ggml_compute_forward_rope(
 
 // ggml_compute_forward_rope_back
 
-static void ggml_compute_forward_rope_back_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-        struct ggml_tensor * dst) {
-
-    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    // y = rope(x, src1)
-    // dx = rope_back(dy, src1)
-    // src0 is dy, src1 contains options
-
-    float freq_base;
-    float freq_scale;
-
-    // these two only relevant for xPos RoPE:
-    float xpos_base;
-    bool xpos_down;
-
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode   = ((int32_t *) dst->op_params)[2];
-    const int n_ctx  = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
-    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
-    memcpy(&xpos_base,  (int32_t *) dst->op_params + 6, sizeof(float));
-    memcpy(&xpos_down,  (int32_t *) dst->op_params + 7, sizeof(bool));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
-    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
-
-    assert(nb0 == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = ggml_nrows(dst);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    // row index used to determine which thread to use
-    int ir = 0;
-
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
-    const bool is_neox = mode & 2;
-
-    const int32_t * pos = (const int32_t *) src1->data;
-
-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = 0; i2 < ne2; i2++) {
-            const int64_t p = pos[i2];
-            for (int64_t i1 = 0; i1 < ne1; i1++) {
-                if (ir++ < ir0) continue;
-                if (ir   > ir1) break;
-
-                float theta_base = freq_scale * (float)p;
-
-                if (!is_neox) {
-                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        const float cos_theta = cosf(theta_base);
-                        const float sin_theta = sinf(theta_base);
-
-                        // zeta scaling for xPos only:
-                        float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
-                        if (xpos_down) zeta = 1.0f / zeta;
-
-                        theta_base *= theta_scale;
-
-                        const float * const dy  = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              float *       dx  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-                        const float dy0 = dy[0];
-                        const float dy1 = dy[1];
-
-                        dx[0] =   dy0*cos_theta*zeta + dy1*sin_theta*zeta;
-                        dx[1] = - dy0*sin_theta*zeta + dy1*cos_theta*zeta;
-                    }
-                } else {
-                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
-                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                            const float cos_theta = cosf(theta_base);
-                            const float sin_theta = sinf(theta_base);
-
-                            theta_base *= theta_scale;
-
-                            const int64_t i0 = ib*n_dims + ic/2;
-
-                            const float * const dy  = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                                  float *       dx  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-                            const float dy0 = dy[0];
-                            const float dy1 = dy[n_dims/2];
-
-                            dx[0]        =   dy0*cos_theta + dy1*sin_theta;
-                            dx[n_dims/2] = - dy0*sin_theta + dy1*cos_theta;
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_rope_back_f16(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-        struct ggml_tensor * dst) {
-
-    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    // y = rope(x, src1)
-    // dx = rope_back(dy, src1)
-    // src0 is dy, src1 contains options
-
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode   = ((int32_t *) dst->op_params)[2];
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
-    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
-
-    assert(nb0 == sizeof(ggml_fp16_t));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = ggml_nrows(dst);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    // row index used to determine which thread to use
-    int ir = 0;
-
-    const float theta_scale = powf(10000.0, -2.0f/n_dims);
-
-    const bool is_neox = mode & 2;
-
-    const int32_t * pos = (const int32_t *) src1->data;
-
-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = 0; i2 < ne2; i2++) {
-            const int64_t p = pos[i2];
-            for (int64_t i1 = 0; i1 < ne1; i1++) {
-                if (ir++ < ir0) continue;
-                if (ir   > ir1) break;
-
-                float theta_base = (float)p;
-
-                if (!is_neox) {
-                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        const float cos_theta = cosf(theta_base);
-                        const float sin_theta = sinf(theta_base);
-
-                        theta_base *= theta_scale;
-
-                        const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-                        const float dy0 = GGML_FP16_TO_FP32(dy[0]);
-                        const float dy1 = GGML_FP16_TO_FP32(dy[1]);
-
-                        dx[0] = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
-                        dx[1] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
-                    }
-                } else {
-                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
-                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                            const float cos_theta = cosf(theta_base);
-                            const float sin_theta = sinf(theta_base);
-
-                            theta_base *= theta_scale;
-
-                            const int64_t i0 = ib*n_dims + ic/2;
-
-                            const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                                  ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-                            const float dy0 = GGML_FP16_TO_FP32(dy[0]);
-                            const float dy1 = GGML_FP16_TO_FP32(dy[n_dims/2]);
-
-                            dx[0]        = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
-                            dx[n_dims/2] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
 static void ggml_compute_forward_rope_back(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -11512,11 +11327,11 @@ static void ggml_compute_forward_rope_back(
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
+                ggml_compute_forward_rope_f16(params, src0, src1, dst, false);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
+                ggml_compute_forward_rope_f32(params, src0, src1, dst, false);
             } break;
         default:
             {
@@ -15559,17 +15374,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     //const int n_past = ((int32_t *) tensor->op_params)[0];
-                    const int n_dims = ((int32_t *) tensor->op_params)[1];
-                    const int mode   = ((int32_t *) tensor->op_params)[2];
-                    const int n_ctx  = ((int32_t *) tensor->op_params)[3];
-                    float freq_base;
-                    float freq_scale;
-                    float xpos_base;
-                    bool  xpos_down;
-                    memcpy(&freq_base,  (int32_t *) tensor->op_params + 4, sizeof(float));
-                    memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
-                    memcpy(&xpos_base,  (int32_t *) tensor->op_params + 6, sizeof(float));
-                    memcpy(&xpos_down,  (int32_t *) tensor->op_params + 7, sizeof(bool));
+                    const int n_dims     = ((int32_t *) tensor->op_params)[1];
+                    const int mode       = ((int32_t *) tensor->op_params)[2];
+                    const int n_ctx      = ((int32_t *) tensor->op_params)[3];
+                    const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
+                    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
+
+                    memcpy(&freq_base,   (int32_t *) tensor->op_params +  5, sizeof(float));
+                    memcpy(&freq_scale,  (int32_t *) tensor->op_params +  6, sizeof(float));
+                    memcpy(&ext_factor,  (int32_t *) tensor->op_params +  7, sizeof(float));
+                    memcpy(&attn_factor, (int32_t *) tensor->op_params +  8, sizeof(float));
+                    memcpy(&beta_fast,   (int32_t *) tensor->op_params +  9, sizeof(float));
+                    memcpy(&beta_slow,   (int32_t *) tensor->op_params + 10, sizeof(float));
+                    memcpy(&xpos_base,   (int32_t *) tensor->op_params + 11, sizeof(float));
+                    memcpy(&xpos_down,   (int32_t *) tensor->op_params + 12, sizeof(bool));
 
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
@@ -15579,8 +15397,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 n_dims,
                                 mode,
                                 n_ctx,
+                                n_orig_ctx,
                                 freq_base,
                                 freq_scale,
+                                ext_factor,
+                                attn_factor,
+                                beta_fast,
+                                beta_slow,
                                 xpos_base,
                                 xpos_down),
                             zero_table);
@@ -15590,17 +15413,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 if (src0->grad) {
                     //const int n_past = ((int32_t *) tensor->op_params)[0];
-                    const int n_dims = ((int32_t *) tensor->op_params)[1];
-                    const int mode   = ((int32_t *) tensor->op_params)[2];
-                    const int n_ctx  = ((int32_t *) tensor->op_params)[3];
-                    float freq_base;
-                    float freq_scale;
-                    float xpos_base;
-                    bool  xpos_down;
-                    memcpy(&freq_base,  (int32_t *) tensor->op_params + 4, sizeof(float));
-                    memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
-                    memcpy(&xpos_base,  (int32_t *) tensor->op_params + 6, sizeof(float));
-                    memcpy(&xpos_down,  (int32_t *) tensor->op_params + 7, sizeof(bool));
+                    const int n_dims     = ((int32_t *) tensor->op_params)[1];
+                    const int mode       = ((int32_t *) tensor->op_params)[2];
+                    const int n_ctx      = ((int32_t *) tensor->op_params)[3];
+                    const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
+                    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
+
+                    memcpy(&freq_base,   (int32_t *) tensor->op_params +  5, sizeof(float));
+                    memcpy(&freq_scale,  (int32_t *) tensor->op_params +  6, sizeof(float));
+                    memcpy(&ext_factor,  (int32_t *) tensor->op_params +  7, sizeof(float));
+                    memcpy(&attn_factor, (int32_t *) tensor->op_params +  8, sizeof(float));
+                    memcpy(&beta_fast,   (int32_t *) tensor->op_params +  9, sizeof(float));
+                    memcpy(&beta_slow,   (int32_t *) tensor->op_params + 10, sizeof(float));
+                    memcpy(&xpos_base,   (int32_t *) tensor->op_params + 11, sizeof(float));
+                    memcpy(&xpos_down,   (int32_t *) tensor->op_params + 12, sizeof(bool));
 
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
@@ -15609,14 +15435,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 src1,
                                 n_dims,
                                 mode,
-                                0,
                                 n_ctx,
+                                n_orig_ctx,
                                 freq_base,
                                 freq_scale,
-                                0.0f,
-                                1.0f,
-                                0.0f,
-                                0.0f,
+                                ext_factor,
+                                attn_factor,
+                                beta_fast,
+                                beta_slow,
                                 xpos_base,
                                 xpos_down,
                                 false),
diff --git a/ggml.h b/ggml.h
index 70eb25a6bf3afc70124ad090508b70d06085f000..26654fc8ecdc84c996e66c543729fdb83eddce6d 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -1372,8 +1372,13 @@ extern "C" {
             int                   n_dims,
             int                   mode,
             int                   n_ctx,
+            int                   n_orig_ctx,
             float                 freq_base,
             float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow,
             float                 xpos_base,
             bool                  xpos_down);