]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : fix rope args order + assert (#2054)
authorGeorgi Gerganov <redacted>
Fri, 21 Jul 2023 11:51:34 +0000 (14:51 +0300)
committerGeorgi Gerganov <redacted>
Fri, 21 Jul 2023 11:51:34 +0000 (14:51 +0300)
examples/train-text-from-scratch/train-text-from-scratch.cpp
ggml.c
ggml.h
llama.cpp

index afbb4a77759fd032a23664eb9f6157746ef8b641..449b4e9ecdd549315dd9c952e5adf692f67bd89e 100644 (file)
@@ -1434,7 +1434,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     gf->perf_time_us = 0;
 
     const auto & hparams = model->hparams;
-    //const int n_ctx      = hparams.n_ctx;
+    const int n_ctx      = hparams.n_ctx;
     const int n_vocab    = hparams.n_vocab;
     const int n_embd     = hparams.n_embd;
     const int n_layer    = hparams.n_layer;
@@ -1863,10 +1863,10 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
         t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1));                                            assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head);
         t11->grad = expand(gb, ggml_reshape_2d(ctx0, ggml_cont(ctx0, t12->grad), N*n_batch, n_embd));                 assert_shape_2d(t11->grad, N*n_batch, n_embd);
         t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3));                                            assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch);
-        t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode));                            assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
+        t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx));                     assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
         t08->grad = expand(gb, ggml_reshape_2d(ctx0, t09->grad, n_embd, N*n_batch));                                  assert_shape_2d(t08->grad, n_embd, N*n_batch);
         t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3));                                            assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch);
-        t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode));                            assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
+        t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx));                     assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
         t05->grad = expand(gb, ggml_reshape_2d(ctx0, t06->grad, n_embd, N*n_batch));                                  assert_shape_2d(t05->grad, n_embd, N*n_batch);
         t04->grad = expand(gb, ggml_add_inplace(ctx0,
                         ggml_add_inplace(ctx0,
diff --git a/ggml.c b/ggml.c
index c56a3d0e0c0a2c695f76f326b9ca8ca4b9ba5ee2..7ecabc5de8bb8b3d7ea4f7a583d0062bc275ebe2 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -6956,9 +6956,9 @@ struct ggml_tensor * ggml_rope_impl(
         int                   n_past,
         int                   n_dims,
         int                   mode,
+        int                   n_ctx,
         float                 freq_base,
         float                 freq_scale,
-        int                   n_ctx,
         bool                  inplace) {
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
@@ -6997,7 +6997,7 @@ struct ggml_tensor * ggml_rope(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, false);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, false);
 }
 
 struct ggml_tensor * ggml_rope_inplace(
@@ -7007,7 +7007,7 @@ struct ggml_tensor * ggml_rope_inplace(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true);
 }
 
 struct ggml_tensor * ggml_rope_custom_inplace(
@@ -7016,10 +7016,10 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         int                   n_past,
         int                   n_dims,
         int                   mode,
+        int                   n_ctx,
         float                 freq_base,
-        float                 freq_scale,
-        int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true);
+        float                 freq_scale) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, true);
 }
 
 // ggml_rope_back
@@ -7029,7 +7029,8 @@ struct ggml_tensor * ggml_rope_back(
         struct ggml_tensor  * a,
         int                   n_past,
         int                   n_dims,
-        int                   mode) {
+        int                   mode,
+        int                   n_ctx) {
     GGML_ASSERT(n_past >= 0);
     GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
 
@@ -7043,12 +7044,13 @@ struct ggml_tensor * ggml_rope_back(
 
     ggml_scratch_save(ctx);
 
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
     ggml_set_name(b, "n_past, n_dims, mode");
 
     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = n_dims;
     ((int32_t *) b->data)[2] = mode;
+    ((int32_t *) b->data)[3] = n_ctx;
 
     ggml_scratch_load(ctx);
 
@@ -15740,13 +15742,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     const int n_past = ((int32_t *) src1->data)[0];
                     const int n_dims = ((int32_t *) src1->data)[1];
                     const int mode   = ((int32_t *) src1->data)[2];
+                    const int n_ctx  = ((int32_t *) src1->data)[3];
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
                             ggml_rope_back(ctx,
                                 tensor->grad,
                                 n_past,
                                 n_dims,
-                                mode),
+                                mode,
+                                n_ctx),
                             inplace);
                 }
                 if (src1->grad) {
@@ -15757,7 +15761,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 if (src0->grad) {
                     assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 3);
+                    assert(ggml_nelements(src1) == 4);
                     const int n_past = ((int32_t *) src1->data)[0];
                     const int n_dims = ((int32_t *) src1->data)[1];
                     const int mode   = ((int32_t *) src1->data)[2];
diff --git a/ggml.h b/ggml.h
index 24856a255c9a6d57683b024c979f159e7f78f64e..5023b165287888e1938e5498edc25694ff7214e1 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -1128,9 +1128,9 @@ extern "C" {
             int                   n_past,
             int                   n_dims,
             int                   mode,
+            int                   n_ctx,
             float                 freq_base,
-            float                 freq_scale,
-            int                   n_ctx);
+            float                 freq_scale);
 
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
@@ -1139,7 +1139,8 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   n_past,
             int                   n_dims,
-            int                   mode);
+            int                   mode,
+            int                   n_ctx);
 
     // alibi position embedding
     // in-place, returns view(a)
index 3b0024e1284793040b0c66ca4607446b0e731d25..0a381afd5b7265feff3dcdb1fe153780f8a66339 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1452,11 +1452,11 @@ static bool llama_eval_internal(
             offload_func_kq(tmpq);
             ggml_set_name(tmpq, "tmpq");
 
-            struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0);
+            struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale);
             offload_func_kq(Kcur);
             ggml_set_name(Kcur, "Kcur");
 
-            struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0);
+            struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale);
             offload_func_kq(Qcur);
             ggml_set_name(Qcur, "Qcur");