]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : pre-allocate input tensors in a separate buffer (#5100)
authorslaren <redacted>
Wed, 24 Jan 2024 11:48:14 +0000 (12:48 +0100)
committerGitHub <redacted>
Wed, 24 Jan 2024 11:48:14 +0000 (12:48 +0100)
ggml-alloc.c
llama.cpp

index 89b85d34870d76b058d62c9b87b833cb9bf70903..60141a34d8f6a22a13238934e66ba920f3ea99ac 100644 (file)
@@ -109,8 +109,8 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
         if (block->size >= size) {
             best_fit_block = alloc->n_free_blocks - 1;
         } else {
-            fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
-                    __func__, size, max_avail);
+            fprintf(stderr, "%s: not enough space in the buffer to allocate %s (needed %zu, largest block available %zu)\n",
+                    __func__, tensor->name, size, max_avail);
             GGML_ASSERT(!"not enough space in the buffer");
             return;
         }
index 582e82260ea853178a3c8a1352c511bf1141ab70..114046db9267560a2a54a0b4eb6c4d22e247634a 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1669,6 +1669,9 @@ struct llama_context {
         for (ggml_backend_t backend : backends) {
             ggml_backend_free(backend);
         }
+
+        ggml_backend_buffer_free(buf_input);
+        ggml_free(ctx_input);
     }
 
     llama_cparams cparams;
@@ -1715,8 +1718,14 @@ struct llama_context {
     // allocator for the input tensors
     ggml_tallocr * alloc = nullptr;
 
-    // temporary buffer for copying data to/from the backend
-    std::vector<no_init<uint8_t>> buf_copy;
+    // input tensors
+    ggml_backend_buffer_t buf_input = nullptr;
+    ggml_context * ctx_input = nullptr;
+    struct ggml_tensor * inp_tokens;    // I32 [n_batch]
+    struct ggml_tensor * inp_embd;      // F32 [n_embd, n_batch]
+    struct ggml_tensor * inp_pos;       // I32 [n_batch]
+    struct ggml_tensor * inp_KQ_mask;   // F32 [n_ctx, n_batch]
+    struct ggml_tensor * inp_K_shift;   // I32 [n_ctx]
 
 #ifdef GGML_USE_MPI
     ggml_mpi_context * ctx_mpi = NULL;
@@ -4089,22 +4098,24 @@ static struct ggml_tensor * llm_build_inp_embd(
         const llama_hparams & hparams,
           const llama_batch & batch,
          struct ggml_tensor * tok_embd,
+         struct ggml_tensor * inp_tokens,
+         struct ggml_tensor * inp_embd,
          const llm_build_cb & cb) {
     const int64_t n_embd = hparams.n_embd;
 
     struct ggml_tensor * inpL;
 
     if (batch.token) {
-        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
+        struct ggml_tensor * inp_tokens_v = ggml_view_1d(ctx, inp_tokens, batch.n_tokens, 0);
         cb(inp_tokens, "inp_tokens", -1);
 
-        inpL = ggml_get_rows(ctx, tok_embd, inp_tokens);
+        inpL = ggml_get_rows(ctx, tok_embd, inp_tokens_v);
     } else {
 #ifdef GGML_USE_MPI
         GGML_ASSERT(false && "not implemented");
 #endif
 
-        inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+        inpL = ggml_view_2d(ctx, inp_embd, n_embd, batch.n_tokens, inp_embd->nb[1], 0);
     }
 
     return inpL;
@@ -4118,6 +4129,7 @@ static void llm_build_k_shift(
       const llama_cparams & cparams,
      const llama_kv_cache & kv,
        struct ggml_cgraph * graph,
+       struct ggml_tensor * K_shift,
             llm_rope_type   type,
                   int64_t   n_ctx,
                   float     freq_base,
@@ -4134,9 +4146,6 @@ static void llm_build_k_shift(
     const float   beta_fast     = cparams.yarn_beta_fast;
     const float   beta_slow     = cparams.yarn_beta_slow;
 
-    struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx);
-    cb(K_shift, "K_shift", -1);
-
     int rope_type = 0;
 
     switch (type) {
@@ -4457,6 +4466,7 @@ static struct ggml_tensor * llm_build_kv(
 
 struct llm_build_context {
     const llama_model    & model;
+    const llama_context  & lctx;
     const llama_hparams  & hparams;
     const llama_cparams  & cparams;
     const llama_batch    & batch;
@@ -4503,6 +4513,7 @@ struct llm_build_context {
     const llm_build_cb & cb,
                   bool   worst_case) :
         model            (lctx.model),
+        lctx             (lctx),
         hparams          (model.hparams),
         cparams          (lctx.cparams),
         batch            (batch),
@@ -4563,20 +4574,20 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
         cb(inp_pos, "inp_pos", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -4747,20 +4758,20 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
         cb(inp_pos, "inp_pos", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -4868,20 +4879,20 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
         cb(inp_pos, "inp_pos", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -4990,15 +5001,15 @@ struct llm_build_context {
         struct ggml_tensor * pos;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
         cb(inp_pos, "inp_pos", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
         pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
@@ -5087,19 +5098,19 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
         cb(inp_pos, "inp_pos", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -5294,11 +5305,11 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
         for (int il = 0; il < n_layer; ++il) {
@@ -5384,11 +5395,11 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
         inpL = llm_build_norm(ctx0, inpL, hparams,
@@ -5477,11 +5488,11 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
         for (int il = 0; il < n_layer; ++il) {
@@ -5573,20 +5584,20 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
         cb(inp_pos, "inp_pos", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -5696,20 +5707,20 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
         cb(inp_pos, "inp_pos", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -5810,20 +5821,20 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
         cb(inp_pos, "inp_pos", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -5931,20 +5942,20 @@ struct llm_build_context {
         struct ggml_tensor * ffn_output;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
         cb(inp_pos, "inp_pos", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -6053,20 +6064,20 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
         cb(inp_pos, "inp_pos", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -6160,15 +6171,15 @@ struct llm_build_context {
         struct ggml_tensor * pos;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
         cb(inp_pos, "inp_pos", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
         pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
@@ -6258,20 +6269,20 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
         cb(inp_pos, "inp_pos", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -6365,15 +6376,7 @@ static struct ggml_cgraph * llama_build_graph(
     // check if we should build the worst-case graph (for memory measurement)
     const bool worst_case = ggml_tallocr_is_measure(lctx.alloc);
 
-    // keep track of the input that has already been allocated
-    bool alloc_inp_tokens   = false;
-    bool alloc_inp_embd     = false;
-    bool alloc_inp_pos      = false;
-    bool alloc_inp_KQ_mask  = false;
-    bool alloc_inp_K_shift  = false;
-
     // this callback allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
-    // TODO: improve handling of input and output tensors, then replace this with ggml_set_name
     llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) {
         if (il >= 0) {
             ggml_format_name(cur, "%s-%d", name, il);
@@ -6381,126 +6384,78 @@ static struct ggml_cgraph * llama_build_graph(
             ggml_set_name(cur, name);
         }
 
-
         if (!lctx.cparams.offload_kqv) {
             if (strcmp(name, "kqv_merged_cont") == 0) {
                 // all nodes between the KV store and the attention output are run on the CPU
                 ggml_backend_sched_set_node_backend(lctx.sched, cur, lctx.backend_cpu);
             }
         }
+    };
 
-        //
-        // allocate input tensors and set input data
-        //
+    struct ggml_cgraph * result = NULL;
 
-        if (!alloc_inp_tokens && strcmp(name, "inp_tokens") == 0) {
-            ggml_tallocr_alloc(lctx.alloc, cur);
+    struct llm_build_context llm(lctx, batch, cb, worst_case);
 
-            if (!ggml_tallocr_is_measure(lctx.alloc) && batch.token) {
-                const int64_t n_tokens = cur->ne[0];
+    //
+    // set input data
+    //
 
-                ggml_backend_tensor_set(cur, batch.token, 0, n_tokens*ggml_element_size(cur));
-            }
+    if (!ggml_tallocr_is_measure(lctx.alloc)) {
+        if (batch.token) {
+            const int64_t n_tokens = batch.n_tokens;
 
-            alloc_inp_tokens = true;
+            ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
         }
 
-        if (!alloc_inp_embd && strcmp(name, "inp_embd") == 0 && batch.embd) {
-            ggml_tallocr_alloc(lctx.alloc, cur);
-
-            if (!ggml_tallocr_is_measure(lctx.alloc) && batch.embd) {
-                const int64_t n_embd   = cur->ne[0];
-                const int64_t n_tokens = cur->ne[1];
-
-                ggml_backend_tensor_set(cur, batch.embd, 0, n_tokens*n_embd*ggml_element_size(cur));
-            }
+        if (batch.embd) {
+            const int64_t n_embd   = llm.n_embd;
+            const int64_t n_tokens = batch.n_tokens;
 
-            alloc_inp_embd = true;
+            ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
         }
 
-        if (!alloc_inp_pos && strcmp(name, "inp_pos") == 0) {
-            ggml_tallocr_alloc(lctx.alloc, cur);
-
-            if (!ggml_tallocr_is_measure(lctx.alloc) && batch.pos) {
-                const int64_t n_tokens = cur->ne[0];
-
-                static_assert(std::is_same<llama_pos, int32_t>::value, "llama_pos must be int32_t");
-                ggml_backend_tensor_set(cur, batch.pos, 0, n_tokens*ggml_element_size(cur));
-            }
+        if (batch.pos) {
+            const int64_t n_tokens = batch.n_tokens;
 
-            alloc_inp_pos = true;
+            ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
         }
 
-        if (!alloc_inp_KQ_mask && strcmp(name, "KQ_mask") == 0) {
-            ggml_tallocr_alloc(lctx.alloc, cur);
+        {
+            const int64_t n_kv     = llm.n_kv;
+            const int64_t n_tokens = batch.n_tokens;
 
-            if (!ggml_tallocr_is_measure(lctx.alloc)) {
-                const int64_t n_kv     = cur->ne[0];
-                const int64_t n_tokens = cur->ne[1];
+            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+            float * data = (float *) lctx.inp_KQ_mask->data;
 
-                float * data;
-                if (ggml_backend_buffer_is_host(cur->buffer)) {
-                    data = (float *) cur->data;
-                } else {
-                    lctx.buf_copy.resize(ggml_nbytes(cur));
-                    data = (float *) lctx.buf_copy.data();
-                }
+            for (int h = 0; h < 1; ++h) {
+                for (int j = 0; j < n_tokens; ++j) {
+                    const llama_pos    pos    = batch.pos[j];
+                    const llama_seq_id seq_id = batch.seq_id[j][0];
 
-                for (int h = 0; h < 1; ++h) {
-                    for (int j = 0; j < n_tokens; ++j) {
-                        const llama_pos    pos    = batch.pos[j];
-                        const llama_seq_id seq_id = batch.seq_id[j][0];
-
-                        for (int i = 0; i < n_kv; ++i) {
-                            float f;
-                            if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
-                                f = -INFINITY;
-                            } else {
-                                f = 0;
-                            }
-                            data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                    for (int i = 0; i < n_kv; ++i) {
+                        float f;
+                        if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+                            f = -INFINITY;
+                        } else {
+                            f = 0;
                         }
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                     }
                 }
-
-                if (data != cur->data) {
-                    ggml_backend_tensor_set(cur, data, 0, ggml_nbytes(cur));
-                }
             }
-
-            alloc_inp_KQ_mask = true;
         }
 
-        if (!alloc_inp_K_shift && strcmp(name, "K_shift") == 0) {
-            ggml_tallocr_alloc(lctx.alloc, cur);
-
-            if (!ggml_tallocr_is_measure(lctx.alloc)) {
-                const int64_t n_ctx = cur->ne[0];
+        if (llm.do_rope_shift) {
+            const int64_t n_ctx = llm.n_ctx;
 
-                int32_t * data;
-                if (ggml_backend_buffer_is_host(cur->buffer)) {
-                    data = (int32_t *) cur->data;
-                } else {
-                    lctx.buf_copy.resize(ggml_nbytes(cur));
-                    data = (int32_t *) lctx.buf_copy.data();
-                }
+            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
+            int32_t * data = (int32_t *) lctx.inp_K_shift->data;
 
-                for (int i = 0; i < n_ctx; ++i) {
-                    data[i] = lctx.kv_self.cells[i].delta;
-                }
-
-                if (data != cur->data) {
-                    ggml_backend_tensor_set(cur, data, 0, ggml_nbytes(cur));
-                }
+            for (int i = 0; i < n_ctx; ++i) {
+                data[i] = lctx.kv_self.cells[i].delta;
             }
-
-            alloc_inp_K_shift = true;
         }
-    };
-
-    struct ggml_cgraph * result = NULL;
-
-    struct llm_build_context llm(lctx, batch, cb, worst_case);
+    }
 
     llm.init();
 
@@ -9964,6 +9919,35 @@ struct llama_context * llama_new_context_with_model(
             ctx->embedding.resize(hparams.n_embd);
         }
 
+        // graph inputs
+        {
+            ggml_init_params init_params = {
+                /* .mem_size   */ ggml_tensor_overhead()*5,
+                /* .mem_buffer */ nullptr,
+                /* .no_alloc   */ true,
+            };
+            ctx->ctx_input = ggml_init(init_params);
+
+            ctx->inp_tokens  = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
+            ctx->inp_embd    = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch);
+            ctx->inp_pos     = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
+            ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
+            ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
+
+            ggml_set_name(ctx->inp_tokens,  "inp_tokens");
+            ggml_set_name(ctx->inp_embd,    "inp_embd");
+            ggml_set_name(ctx->inp_pos,     "inp_pos");
+            ggml_set_name(ctx->inp_KQ_mask, "inp_KQ_mask");
+            ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
+
+            ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));
+
+            LLAMA_LOG_INFO("%s: %10s input buffer size   = %8.2f MiB\n", __func__,
+                    ggml_backend_buffer_name(ctx->buf_input),
+                    ggml_backend_buffer_get_size(ctx->buf_input) / 1024.0 / 1024.0);
+        }
+
+        // scheduler and compute buffers
         {
             // buffer types used for the compute buffer of each backend
             std::vector<ggml_backend_buffer_type_t> backend_buft;
@@ -9990,9 +9974,6 @@ struct llama_context * llama_new_context_with_model(
 
             // initialize scheduler with the worst-case graph
             ggml_backend_sched_init_measure(ctx->sched, gf);
-            // note: the number of splits during measure is higher than during inference due to the kv shift
-            int n_splits = ggml_backend_sched_get_n_splits(ctx->sched);
-            LLAMA_LOG_INFO("%s: graph splits (measure): %d\n", __func__, n_splits);
             ctx->alloc = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu);
 
             for (ggml_backend_t backend : ctx->backends) {
@@ -10001,6 +9982,10 @@ struct llama_context * llama_new_context_with_model(
                         ggml_backend_buffer_name(buf),
                         ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0);
             }
+
+            // note: the number of splits during measure is higher than during inference due to the kv shift
+            int n_splits = ggml_backend_sched_get_n_splits(ctx->sched);
+            LLAMA_LOG_INFO("%s: graph splits (measure): %d\n", __func__, n_splits);
         }
     }