]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : support ChatGLM-style RoPE (#305)
authorJiahao Li <redacted>
Mon, 26 Jun 2023 20:47:31 +0000 (04:47 +0800)
committerGitHub <redacted>
Mon, 26 Jun 2023 20:47:31 +0000 (23:47 +0300)
examples/dolly-v2/main.cpp
examples/gpt-j/main.cpp
examples/gpt-neox/main.cpp
include/ggml/ggml.h
src/ggml.c
tests/test-grad0.c

index 85faa707625e94b943ebe88211ebde0c812b6e2c..ce7d3d4fa2c52a1d85de9e97852149e38c6fcd79 100644 (file)
@@ -523,8 +523,8 @@ bool dollyv2_eval(
             struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd/n_head, n_head, N, cur->nb[1]/n_head, cur->nb[1], 2*sizeof(float)*n_embd/n_head));
 
             // using mode = 2 for GPT-NeoX mode
-            Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_rot, 2);
-            Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_rot, 2);
+            Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_rot, 2, 0);
+            Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_rot, 2, 0);
 
             // store key and value to memory
             {
index 3d956ffe876c34d0bfb1591bfa1bd7c372e9b84d..516e0998a4cc3060142f904f504551900e60fe66 100644 (file)
@@ -452,8 +452,8 @@ bool gptj_eval(
 
         // self-attention
         {
-            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
-            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
+            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
+            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
 
             // store key and value to memory
             {
index 1cd64a227fe76e9207955030a0f1c376ba27f159..649b51c0e4b8982346db9fc15645b18f1f81d6ac 100644 (file)
@@ -517,8 +517,8 @@ bool gpt_neox_eval(
             struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd/n_head, n_head, N, cur->nb[1]/n_head, cur->nb[1], 2*sizeof(float)*n_embd/n_head));
 
             // using mode = 2 for GPT-NeoX mode
-            Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_rot, 2);
-            Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_rot, 2);
+            Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_rot, 2, 0);
+            Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_rot, 2, 0);
 
             // store key and value to memory
             {
index 08025e57a951f7b7cdae0980568019f34a298dcc..459913222e0833e0a124be41efd73fbbdd124a9b 100644 (file)
@@ -1036,13 +1036,15 @@ extern "C" {
     // rotary position embedding
     // if mode & 1 == 1, skip n_past elements
     // if mode & 2 == 1, GPT-NeoX style
+    // if mode & 4 == 1, ChatGLM style
     // TODO: avoid creating a new tensor every time
     GGML_API struct ggml_tensor * ggml_rope(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             int                   n_past,
             int                   n_dims,
-            int                   mode);
+            int                   mode,
+            int                   n_ctx);
 
     // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_rope_inplace(
@@ -1050,7 +1052,8 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   n_past,
             int                   n_dims,
-            int                   mode);
+            int                   mode,
+            int                   n_ctx);
 
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
index c179bee9383880d4346c7d3875fefb52350a25cd..92faf03f746a1c4203f6ab98725ad3660b3e2298 100644 (file)
@@ -6778,6 +6778,7 @@ struct ggml_tensor * ggml_rope_impl(
         int                   n_past,
         int                   n_dims,
         int                   mode,
+        int                   n_ctx,
         bool                  inplace) {
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
@@ -6790,11 +6791,12 @@ struct ggml_tensor * ggml_rope_impl(
 
     ggml_scratch_save(ctx);
 
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
 
     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = n_dims;
     ((int32_t *) b->data)[2] = mode;
+    ((int32_t *) b->data)[3] = n_ctx;
 
     ggml_scratch_load(ctx);
 
@@ -6811,8 +6813,9 @@ struct ggml_tensor * ggml_rope(
         struct ggml_tensor  * a,
         int                   n_past,
         int                   n_dims,
-        int                   mode) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, false);
+        int                   mode,
+        int                   n_ctx) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false);
 }
 
 struct ggml_tensor * ggml_rope_inplace(
@@ -6820,8 +6823,9 @@ struct ggml_tensor * ggml_rope_inplace(
         struct ggml_tensor  * a,
         int                   n_past,
         int                   n_dims,
-        int                   mode) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, true);
+        int                   mode,
+        int                   n_ctx) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true);
 }
 
 // ggml_rope_back
@@ -12440,7 +12444,7 @@ static void ggml_compute_forward_rope_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 3);
+    GGML_ASSERT(ggml_nelements(src1) == 4);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -12449,6 +12453,7 @@ static void ggml_compute_forward_rope_f32(
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
+    const int n_ctx  = ((int32_t *) src1->data)[3];
 
     assert(n_past >= 0);
 
@@ -12493,6 +12498,7 @@ static void ggml_compute_forward_rope_f32(
     const float theta_scale = powf(10000.0, -2.0f/n_dims);
 
     const bool is_neox = mode & 2;
+    const bool is_glm  = mode & 4;
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
@@ -12503,7 +12509,32 @@ static void ggml_compute_forward_rope_f32(
 
                 float theta = (float)p;
 
-                if (!is_neox) {
+                if (is_glm) {
+                    theta = MIN(p, n_ctx - 2);
+                    float block_theta = MAX(p - (n_ctx - 2), 0);
+                    for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
+                        const float cos_theta = cosf(theta);
+                        const float sin_theta = sinf(theta);
+                        const float cos_block_theta = cosf(block_theta);
+                        const float sin_block_theta = sinf(block_theta);
+
+                        theta *= theta_scale;
+                        block_theta *= theta_scale;
+
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              float * dst_data  = (float *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        const float x0 = src[0];
+                        const float x1 = src[n_dims/2];
+                        const float x2 = src[n_dims];
+                        const float x3 = src[n_dims/2*3];
+
+                        dst_data[0]          = x0*cos_theta - x1*sin_theta;
+                        dst_data[n_dims/2]   = x0*sin_theta + x1*cos_theta;
+                        dst_data[n_dims]     = x2*cos_block_theta - x3*sin_block_theta;
+                        dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta;
+                    }
+                } else if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                         const float cos_theta = cosf(theta);
                         const float sin_theta = sinf(theta);
@@ -12553,7 +12584,7 @@ static void ggml_compute_forward_rope_f16(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 3);
+    GGML_ASSERT(ggml_nelements(src1) == 4);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -12562,6 +12593,7 @@ static void ggml_compute_forward_rope_f16(
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
+    const int n_ctx  = ((int32_t *) src1->data)[3];
 
     assert(n_past >= 0);
 
@@ -12606,6 +12638,7 @@ static void ggml_compute_forward_rope_f16(
     const float theta_scale = powf(10000.0, -2.0f/n_dims);
 
     const bool is_neox = mode & 2;
+    const bool is_glm  = mode & 4;
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
@@ -12616,7 +12649,32 @@ static void ggml_compute_forward_rope_f16(
 
                 float theta = (float)p;
 
-                if (!is_neox) {
+                if (is_glm) {
+                    theta = MIN(p, n_ctx - 2);
+                    float block_theta = MAX(p - (n_ctx - 2), 0);
+                    for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
+                        const float cos_theta = cosf(theta);
+                        const float sin_theta = sinf(theta);
+                        const float cos_block_theta = cosf(block_theta);
+                        const float sin_block_theta = sinf(block_theta);
+
+                        theta *= theta_scale;
+                        block_theta *= theta_scale;
+
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+                        const float x2 = GGML_FP16_TO_FP32(src[n_dims]);
+                        const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]);
+
+                        dst_data[0]          = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims/2]   = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                        dst_data[n_dims]     = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta);
+                        dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
+                    }
+                } if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                         const float cos_theta = cosf(theta);
                         const float sin_theta = sinf(theta);
@@ -16189,17 +16247,19 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 if (src0->grad) {
                     assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 3);
+                    assert(ggml_nelements(src1) == 4);
                     const int n_past = ((int32_t *) src1->data)[0];
                     const int n_dims = ((int32_t *) src1->data)[1];
                     const int mode   = ((int32_t *) src1->data)[2];
+                    const int n_ctx  = ((int32_t *) src1->data)[3];
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
                             ggml_rope(ctx,
                                 tensor->grad,
                                 n_past,
                                 n_dims,
-                                mode),
+                                mode,
+                                n_ctx),
                             inplace);
                 }
                 if (src1->grad) {
index b5a499c1db57e4dae69d92df31f5b4f32498b16a..ce5dc32bfbf49460ead4c02b7d4171cd43328d32 100644 (file)
@@ -1139,7 +1139,7 @@ int main(int argc, const char ** argv) {
             int n_rot = ne2[0];
 
             for (int ndims = 3; ndims <= 4; ++ndims) {
-                for (int mode = 0; mode < 4; ++mode) {
+                for (int mode = 0; mode < 8; ++mode) {
                     for (int n_past = 1; n_past < ne2[2]; ++n_past) {
                         x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
 
@@ -1154,7 +1154,7 @@ int main(int argc, const char ** argv) {
                             continue;
                         }
 
-                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode));
+                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0));
 
                         GGML_PRINT_DEBUG("rope: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
                         check_gradient("rope", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);