]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix rope usage after ChatGLM change
authorGeorgi Gerganov <redacted>
Mon, 26 Jun 2023 21:37:13 +0000 (00:37 +0300)
committerGeorgi Gerganov <redacted>
Mon, 26 Jun 2023 21:37:33 +0000 (00:37 +0300)
examples/train-text-from-scratch/train-text-from-scratch.cpp
llama.cpp

index 5c6fd5738ee36eb96190e73bd9117217b38edb6c..a05881d1640e7a1a47aefba8eca40c6d9fa38d5a 100644 (file)
@@ -443,8 +443,8 @@ struct ggml_tensor * forward(
             // wk   shape [n_embd, n_embd, 1, 1]
             // Qcur shape [n_embd/n_head, n_head, N, 1]
             // Kcur shape [n_embd/n_head, n_head, N, 1]
-            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
-            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
+            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
+            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
 
             // store key and value to memory
             {
@@ -700,8 +700,8 @@ struct ggml_tensor * forward_batch(
             // wk   shape [n_embd, n_embd, 1, 1]
             // Qcur shape [n_embd/n_head, n_head, N, n_batch]
             // Kcur shape [n_embd/n_head, n_head, N, n_batch]
-            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
-            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
+            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
+            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
             assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
             assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
 
@@ -985,8 +985,8 @@ struct ggml_tensor * forward_batch_wo_cache(
             // wk   shape [n_embd, n_embd, 1, 1]
             // Qcur shape [n_embd/n_head, n_head, N, n_batch]
             // Kcur shape [n_embd/n_head, n_head, N, n_batch]
-            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
-            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
+            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
+            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
             assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
             assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
 
@@ -1207,8 +1207,8 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
             // compute Q and K and RoPE them
             // wq   shape [n_embd, n_embd, 1, 1]
             // wk   shape [n_embd, n_embd, 1, 1]
-            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
-            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
+            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
+            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
             assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
             assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
 
@@ -1607,10 +1607,10 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
         use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul          (ctx0, t02, t03));                               assert_shape_2d(t04, n_embd, N*n_batch);
         use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat      (ctx0, layer.wq, t04));                          assert_shape_2d(t05, n_embd, N*n_batch);
         use_buf(-1); struct ggml_tensor * t06 = expand(gf, ggml_reshape_4d   (ctx0, t05, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
-        use_buf(-1); struct ggml_tensor * t07 = expand(gf, ggml_rope_inplace (ctx0, t06, n_past, n_rot, rope_mode));          assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
+        use_buf(-1); struct ggml_tensor * t07 = expand(gf, ggml_rope_inplace (ctx0, t06, n_past, n_rot, rope_mode, 0));       assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
         use_buf(-1); struct ggml_tensor * t08 = expand(gf, ggml_mul_mat      (ctx0, layer.wk, t04));                          assert_shape_2d(t08, n_embd, N*n_batch);
         use_buf(-1); struct ggml_tensor * t09 = expand(gf, ggml_reshape_4d   (ctx0, t08, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch);
-        use_buf(-1); struct ggml_tensor * t10 = expand(gf, ggml_rope_inplace (ctx0, t09, n_past, n_rot, rope_mode));          assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
+        use_buf(-1); struct ggml_tensor * t10 = expand(gf, ggml_rope_inplace (ctx0, t09, n_past, n_rot, rope_mode, 0));       assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
         use_buf(-1); struct ggml_tensor * t11 = expand(gf, ggml_mul_mat      (ctx0, t04, layer.wv));                          assert_shape_2d(t11, N*n_batch, n_embd);
         use_buf(-1); struct ggml_tensor * t12 = expand(gf, ggml_reshape_4d   (ctx0, t11, N, n_batch, n_embd/n_head, n_head)); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head);
         use_buf(-1); struct ggml_tensor * t13 = expand(gf, ggml_permute      (ctx0, t07, 0, 2, 1, 3));                        assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch);
index 1a15844bcc7a4a1bd23ac149fdc0209db87f925d..2482bdd18d2e71da15ef76ff244691ef0ae1d0d7 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1491,11 +1491,11 @@ static bool llama_eval_internal(
             offload_func_kq(tmpq);
             ggml_set_name(tmpq, "tmpq");
 
-            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0);
+            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
             offload_func_kq(Kcur);
             ggml_set_name(Kcur, "Kcur");
 
-            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0);
+            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
             offload_func_kq(Qcur);
             ggml_set_name(Qcur, "Qcur");