]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
gpt-j : update inference to match latest llama.cpp insights
authorGeorgi Gerganov <redacted>
Tue, 11 Apr 2023 18:33:17 +0000 (21:33 +0300)
committerGeorgi Gerganov <redacted>
Tue, 11 Apr 2023 18:33:17 +0000 (21:33 +0300)
- Use F16 KV cache
- Store transposed V in the cache
- Avoid unnecessary Q copy

examples/gpt-j/main.cpp

index 4455242978199e9450847e77891c19b754dfa7a2..96615f1dcdb97b03f297e834ee6c287f1713512b 100644 (file)
@@ -185,8 +185,8 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab &
         ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_proj_w
         ctx_size += n_layer*(         n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
 
-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_k
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_v
 
         ctx_size += (5 + 10*n_layer)*256; // object overhead
 
@@ -283,8 +283,8 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab &
         const int n_mem      = n_layer*n_ctx;
         const int n_elements = n_embd*n_mem;
 
-        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
-        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
 
         const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
 
@@ -465,14 +465,17 @@ bool gptj_eval(
 
         // self-attention
         {
-            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur);
-            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur);
-            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].c_attn_v_proj_w, cur);
+            struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
+            struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
 
             // store key and value to memory
-            if (N >= 1) {
+            {
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_v_proj_w, cur));
+
                 struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
-                struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_2d(ctx0, model.memory_v, N, n_embd,
+                        (   n_ctx)*ggml_element_size(model.memory_v),
+                        (il*n_ctx)*ggml_element_size(model.memory_v)*n_embd + n_past*ggml_element_size(model.memory_v));
 
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
@@ -481,21 +484,15 @@ bool gptj_eval(
             // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_rope(ctx0,
-                            ggml_cpy(ctx0,
-                                Qcur,
-                                ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
-                            n_past, n_rot, 0),
+                        Qcur,
                         0, 2, 1, 3);
 
             // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
             struct ggml_tensor * K =
                 ggml_permute(ctx0,
-                        ggml_rope(ctx0,
-                            ggml_reshape_3d(ctx0,
-                                ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
-                                n_embd/n_head, n_head, n_past + N),
-                            n_past, n_rot, 1),
+                        ggml_reshape_3d(ctx0,
+                            ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
+                            n_embd/n_head, n_head, n_past + N),
                         0, 2, 1, 3);
 
             // K * Q
@@ -515,17 +512,15 @@ bool gptj_eval(
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
             // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
-            struct ggml_tensor * V_trans =
-                ggml_cpy(ctx0,
-                        ggml_permute(ctx0,
-                            ggml_reshape_3d(ctx0,
-                                ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
-                                n_embd/n_head, n_head, n_past + N),
-                            1, 2, 0, 3),
-                        ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, model.memory_v,
+                        n_past + N, n_embd/n_head, n_head,
+                        n_ctx*ggml_element_size(model.memory_v),
+                        n_ctx*ggml_element_size(model.memory_v)*n_embd/n_head,
+                        il*n_ctx*ggml_element_size(model.memory_v)*n_embd);
 
             // KQV = transpose(V) * KQ_soft_max
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
             // KQV_merged = KQV.permute(0, 2, 1, 3)
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);