]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : implementation of xPos RoPE (#441); also extends ggml_rope_back with additiona...
authorJan Ploski <redacted>
Tue, 22 Aug 2023 08:45:20 +0000 (10:45 +0200)
committerGitHub <redacted>
Tue, 22 Aug 2023 08:45:20 +0000 (11:45 +0300)
include/ggml/ggml.h
src/ggml.c
tests/CMakeLists.txt
tests/test-xpos.c [new file with mode: 0644]

index 58598ebfbe64e4084d7affe5b4fb2111632d8106..5c8386895827ae2142e8cfe849eb96f6c68ec446 100644 (file)
@@ -1217,6 +1217,15 @@ extern "C" {
             float                 freq_base,
             float                 freq_scale);
 
+    // xPos RoPE, in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_dims,
+            float                 scale_base,
+            bool                  downscale);
+
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
     GGML_API struct ggml_tensor * ggml_rope_back(
@@ -1225,7 +1234,11 @@ extern "C" {
             int                   n_past,
             int                   n_dims,
             int                   mode,
-            int                   n_ctx);
+            int                   n_ctx,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 xpos_base,
+            bool                  xpos_downscale);
 
     // alibi position embedding
     // in-place, returns view(a)
index 7b5922e3687b59e4ecae43dc39ca42adcc31283a..6a4b6ed46d6c0089898efa554f8710063cfe94c0 100644 (file)
@@ -6715,6 +6715,8 @@ static struct ggml_tensor * ggml_rope_impl(
         int                   n_ctx,
         float                 freq_base,
         float                 freq_scale,
+        float                 xpos_base,
+        bool                  xpos_downscale,
         bool                  inplace) {
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
@@ -6725,9 +6727,11 @@ static struct ggml_tensor * ggml_rope_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    int32_t params[6] = { n_past, n_dims, mode, n_ctx };
+    int32_t params[8] = { n_past, n_dims, mode, n_ctx };
     memcpy(params + 4, &freq_base,  sizeof(float));
     memcpy(params + 5, &freq_scale, sizeof(float));
+    memcpy(params + 6, &xpos_base, sizeof(float));
+    memcpy(params + 7, &xpos_downscale, sizeof(bool));
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_ROPE;
@@ -6744,7 +6748,7 @@ struct ggml_tensor * ggml_rope(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, false);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
 }
 
 struct ggml_tensor * ggml_rope_inplace(
@@ -6754,7 +6758,7 @@ struct ggml_tensor * ggml_rope_inplace(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
 }
 
 struct ggml_tensor * ggml_rope_custom(
@@ -6766,7 +6770,7 @@ struct ggml_tensor * ggml_rope_custom(
         int                   n_ctx,
         float                 freq_base,
         float                 freq_scale) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, false);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
 }
 
 struct ggml_tensor * ggml_rope_custom_inplace(
@@ -6778,7 +6782,17 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         int                   n_ctx,
         float                 freq_base,
         float                 freq_scale) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, true);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
+}
+
+struct ggml_tensor * ggml_rope_xpos_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_dims,
+        float                 scale_base,
+        bool                  downscale) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, scale_base, downscale, true);
 }
 
 // ggml_rope_back
@@ -6789,7 +6803,11 @@ struct ggml_tensor * ggml_rope_back(
         int                   n_past,
         int                   n_dims,
         int                   mode,
-        int                   n_ctx) {
+        int                   n_ctx,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 xpos_base,
+        bool                  xpos_downscale) {
     GGML_ASSERT(n_past >= 0);
     GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
 
@@ -6801,7 +6819,11 @@ struct ggml_tensor * ggml_rope_back(
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
-    int32_t params[] = { n_past, n_dims, mode, n_ctx };
+    int32_t params[8] = { n_past, n_dims, mode, n_ctx };
+    memcpy(params + 4, &freq_base, sizeof(float));
+    memcpy(params + 5, &freq_scale, sizeof(float));
+    memcpy(params + 6, &xpos_base, sizeof(float));
+    memcpy(params + 7, &xpos_downscale, sizeof(bool));
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_ROPE_BACK;
@@ -12065,7 +12087,6 @@ static void ggml_compute_forward_alibi(
     }
 }
 
-
 // ggml_compute_forward_clamp
 
 static void ggml_compute_forward_clamp_f32(
@@ -12154,12 +12175,18 @@ static void ggml_compute_forward_rope_f32(
     float freq_base;
     float freq_scale;
 
+    // these two only relevant for xPos RoPE:
+    float xpos_base;
+    bool xpos_downscale;
+
     const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode   = ((int32_t *) dst->op_params)[2];
     const int n_ctx  = ((int32_t *) dst->op_params)[3];
     memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+    memcpy(&xpos_base,  (int32_t *) dst->op_params + 6, sizeof(float));
+    memcpy(&xpos_downscale, (int32_t *) dst->op_params + 7, sizeof(bool));
 
     assert(n_past >= 0);
 
@@ -12231,6 +12258,9 @@ static void ggml_compute_forward_rope_f32(
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                         const float cos_theta = cosf(theta);
                         const float sin_theta = sinf(theta);
+                        // zeta scaling for xPos only:
+                        float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
+                        if (xpos_downscale) zeta = 1.0f / zeta;
 
                         theta *= theta_scale;
 
@@ -12240,8 +12270,8 @@ static void ggml_compute_forward_rope_f32(
                         const float x0 = src[0];
                         const float x1 = src[1];
 
-                        dst_data[0] = x0*cos_theta - x1*sin_theta;
-                        dst_data[1] = x0*sin_theta + x1*cos_theta;
+                        dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta;
+                        dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
                     }
                 } else {
                     // TODO: this is probably wrong, but I can't figure it out ..
@@ -12435,9 +12465,21 @@ static void ggml_compute_forward_rope_back_f32(
     // dx = rope_back(dy, src1)
     // src0 is dy, src1 contains options
 
+    float freq_base;
+    float freq_scale;
+
+    // these two only relevant for xPos RoPE:
+    float xpos_base;
+    bool xpos_downscale;
+
     const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode   = ((int32_t *) dst->op_params)[2];
+    const int n_ctx  = ((int32_t *) dst->op_params)[3];
+    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+    memcpy(&xpos_base,  (int32_t *) dst->op_params + 6, sizeof(float));
+    memcpy(&xpos_downscale, (int32_t *) dst->op_params + 7, sizeof(bool));
 
     assert(n_past >= 0);
 
@@ -12463,7 +12505,7 @@ static void ggml_compute_forward_rope_back_f32(
     // row index used to determine which thread to use
     int ir = 0;
 
-    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     const bool is_neox = mode & 2;
 
@@ -12474,12 +12516,15 @@ static void ggml_compute_forward_rope_back_f32(
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                float theta = (float)p;
+                float theta = freq_scale * (float)p;
 
                 if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                         const float cos_theta = cosf(theta);
                         const float sin_theta = sinf(theta);
+                        // zeta scaling for xPos only:
+                        float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
+                        if (xpos_downscale) zeta = 1.0f / zeta;
 
                         theta *= theta_scale;
 
@@ -12489,8 +12534,8 @@ static void ggml_compute_forward_rope_back_f32(
                         const float dy0 = dy[0];
                         const float dy1 = dy[1];
 
-                        dx[0] =   dy0*cos_theta + dy1*sin_theta;
-                        dx[1] = - dy0*sin_theta + dy1*cos_theta;
+                        dx[0] =   dy0*cos_theta*zeta + dy1*sin_theta*zeta;
+                        dx[1] = - dy0*sin_theta*zeta + dy1*cos_theta*zeta;
                     }
                 } else {
                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
@@ -15967,6 +16012,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     const int n_dims = ((int32_t *) tensor->op_params)[1];
                     const int mode   = ((int32_t *) tensor->op_params)[2];
                     const int n_ctx  = ((int32_t *) tensor->op_params)[3];
+                    float freq_base, freq_scale, xpos_base;
+                    bool xpos_downscale;
+                                       memcpy(&freq_base,  (int32_t *) tensor->op_params + 4, sizeof(float));
+                                       memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
+                                       memcpy(&xpos_base,  (int32_t *) tensor->op_params + 6, sizeof(float));
+                                       memcpy(&xpos_downscale, (int32_t *) tensor->op_params + 7, sizeof(bool));
+
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
                             ggml_rope_back(ctx,
@@ -15974,7 +16026,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 n_past,
                                 n_dims,
                                 mode,
-                                n_ctx),
+                                n_ctx,
+                                freq_base,
+                                freq_scale,
+                                xpos_base,
+                                xpos_downscale),
                             inplace);
                 }
             } break;
@@ -15985,14 +16041,26 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     const int n_dims = ((int32_t *) tensor->op_params)[1];
                     const int mode   = ((int32_t *) tensor->op_params)[2];
                     const int n_ctx  = ((int32_t *) tensor->op_params)[3];
+                    float freq_base, freq_scale, xpos_base;
+                    bool xpos_downscale;
+                                       memcpy(&freq_base,  (int32_t *) tensor->op_params + 4, sizeof(float));
+                                       memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
+                                       memcpy(&xpos_base,  (int32_t *) tensor->op_params + 6, sizeof(float));
+                                       memcpy(&xpos_downscale, (int32_t *) tensor->op_params + 7, sizeof(bool));
+
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
-                            ggml_rope(ctx,
+                            ggml_rope_impl(ctx,
                                 tensor->grad,
                                 n_past,
                                 n_dims,
                                 mode,
-                                n_ctx),
+                                n_ctx,
+                                freq_base,
+                                freq_scale,
+                                xpos_base,
+                                xpos_downscale,
+                                false),
                             inplace);
                 }
             } break;
index 33eda08f72b84fc889f5a4420119386807ea242f..8d9d1f79b8a8314eb2cbaebd6e58b4b30eb4153c 100644 (file)
@@ -325,3 +325,12 @@ if (MSVC)
 endif()
 add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
 set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
+
+#
+# test-xpos
+
+set(TEST_TARGET test-xpos)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
diff --git a/tests/test-xpos.c b/tests/test-xpos.c
new file mode 100644 (file)
index 0000000..a8c64e5
--- /dev/null
@@ -0,0 +1,87 @@
+#include "ggml/ggml.h"
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+bool is_close(float a, float b, float epsilon) {
+    return fabs(a - b) < epsilon;
+}
+
+int main(int argc, char ** argv) {
+    const int n_threads = 1;
+    const int n_embd_head = 4; // aka head_dim
+    const int n_head = 1;
+    const int N = 8;
+
+    struct ggml_init_params params = {
+        .mem_size   = 16*1024*1024,
+        .mem_buffer = NULL,
+    };
+   
+    // memory allocation happens here
+    struct ggml_context * ctx = ggml_init(params);
+
+    struct ggml_tensor * Q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, N);
+    struct ggml_tensor * K = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, N);
+
+    for (int i = 0; i < ggml_nelements(Q); i++) {
+        ((float*) Q->data)[i] = 2.0f;
+        ((float*) K->data)[i] = 2.0f;
+    }
+
+    struct ggml_tensor * Qx = ggml_rope_xpos_inplace(ctx, Q, 1, n_embd_head, 512.0f, false);
+    struct ggml_tensor * Kx = ggml_rope_xpos_inplace(ctx, K, 1, n_embd_head, 512.0f, true);
+   
+    struct ggml_cgraph gf = ggml_build_forward(Qx);
+    ggml_build_forward_expand(&gf, Kx);
+    ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
+
+       // expected output for Qx:
+    // -0.6009  2.7568  1.9782  2.0182 
+    // -2.6379  0.9815  1.9562  2.0361 
+    // -2.2457 -1.6853  1.9341  2.0538 
+    //  0.2043 -2.7934  1.9118  2.0712 
+    //  2.4550 -1.3341  1.8894  2.0884 
+    //  2.4430  1.3417  1.8668  2.1054 
+    //  0.1905  2.7739  1.8440  2.1221 
+    // -2.2257  1.6550  1.8212  2.1386 
+
+    for (int i = 0; i < ggml_nelements(Q); i++) {
+        if (((float*) Qx->data)[i] > 0) printf(" ");
+        printf("%.4f ", ((float*) Qx->data)[i]);
+        if ((i+1) % n_embd_head == 0) printf("\n");
+    }
+    printf("\n");
+
+    GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 0], -2.2257f, 0.0001f));
+    GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 1],  1.6550f, 0.0001f));
+    GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 2],  1.8212f, 0.0001f));
+    GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 3],  2.1386f, 0.0001f));
+
+    // expected output for Kx:
+       // -0.6038  2.7703  1.9816  2.0216 
+    // -2.6639  0.9911  1.9630  2.0431 
+    // -2.2789 -1.7103  1.9441  2.0644 
+    //  0.2083 -2.8486  1.9251  2.0856 
+    //  2.5158 -1.3671  1.9057  2.1065 
+    //  2.5158  1.3816  1.8862  2.1273 
+    //  0.1972  2.8705  1.8665  2.1479 
+    // -2.3146  1.7211  1.8465  2.1684 
+    
+    for (int i = 0; i < ggml_nelements(K); i++) {
+        if (((float*) Kx->data)[i] > 0) printf(" ");
+        printf("%.4f ", ((float*) Kx->data)[i]);
+        if ((i+1) % n_embd_head == 0) printf("\n");
+    }
+    printf("\n");
+
+    GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 0], -2.3146f, 0.0001f));
+    GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 1],  1.7211f, 0.0001f));
+    GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 2],  1.8465f, 0.0001f));
+    GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 3],  2.1684f, 0.0001f));
+
+    ggml_free(ctx);
+
+    return 0;
+}