]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : replace (permute + reshape + view_1d) with (view_3d) (#2538)
authorGeorgi Gerganov <redacted>
Thu, 17 Aug 2023 07:47:09 +0000 (10:47 +0300)
committerGitHub <redacted>
Thu, 17 Aug 2023 07:47:09 +0000 (10:47 +0300)
ggml-ci

llama.cpp

index 3452439904bc0aa93655b5e4986d9b4a2fd532d1..b8cc229427620add977f403600ee0c48ae60efdf 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1609,11 +1609,11 @@ static struct ggml_cgraph * llama_build_graph(
             ggml_set_name(Q, "Q");
 
             struct ggml_tensor * K =
-                ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0,
-                            ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd_gqa, il*n_ctx*ggml_element_size(kv_self.k)*n_embd_gqa),
-                            n_embd_head, n_head_kv, n_past + N),
-                        0, 2, 1, 3);
+                ggml_view_3d(ctx0, kv_self.k,
+                        n_embd_head, n_past + N, n_head_kv,
+                        ggml_element_size(kv_self.k)*n_embd_gqa,
+                        ggml_element_size(kv_self.k)*n_embd_head,
+                        ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
             offload_func_kq(K);
             ggml_set_name(K, "K");
 
@@ -1642,9 +1642,9 @@ static struct ggml_cgraph * llama_build_graph(
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
                         n_past + N, n_embd_head, n_head_kv,
-                        n_ctx*ggml_element_size(kv_self.v),
-                        n_ctx*ggml_element_size(kv_self.v)*n_embd_head,
-                        n_ctx*ggml_element_size(kv_self.v)*n_embd_gqa*il);
+                        ggml_element_size(kv_self.v)*n_ctx,
+                        ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
+                        ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
             offload_func_v(V);
             ggml_set_name(V, "V");