]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix session saving/loading (#3400)
authorGeorgi Gerganov <redacted>
Tue, 3 Oct 2023 18:04:01 +0000 (21:04 +0300)
committerGitHub <redacted>
Tue, 3 Oct 2023 18:04:01 +0000 (21:04 +0300)
* llama : fix session saving/loading

* llama : temp fix for clearing "future" tokens from the KV cache

* llama : fix handling of "future" tokens when loading sessions

* llama : fix comments for llama_kv_cache API

examples/chat-persistent.sh
examples/main/main.cpp
examples/parallel/parallel.cpp
examples/server/server.cpp
examples/speculative/speculative.cpp
llama.cpp
llama.h

index e0c251e5b03cc7819320368c1a6ba6879665a941..22f5b83d3da06b0bf45049d0d3963d322cf3d1f5 100755 (executable)
@@ -9,7 +9,7 @@ if [[ -z "${PROMPT_CACHE_FILE+x}" || -z "${CHAT_SAVE_DIR+x}" ]]; then
     exit 1
 fi
 
-MODEL="${MODEL:-./models/13B/ggml-model-q4_0.bin}"
+MODEL="${MODEL:-./models/llama-13b/ggml-model-q4_0.gguf}"
 PROMPT_TEMPLATE="${PROMPT_TEMPLATE:-./prompts/chat.txt}"
 USER_NAME="${USER_NAME:-User}"
 AI_NAME="${AI_NAME:-ChatLLaMa}"
@@ -61,9 +61,9 @@ fi
 
 if [[ ! -e "$PROMPT_CACHE_FILE" ]]; then
     echo 'Prompt cache does not exist, building...'
-    # Default batch_size to 8 here for better user feedback during initial prompt processing
+    # Default batch_size to 64 here for better user feedback during initial prompt processing
     ./main 2>>"$LOG" \
-        --batch_size 8 \
+        --batch_size 64 \
         "${OPTS[@]}" \
         --prompt-cache "$PROMPT_CACHE_FILE" \
         --file "$CUR_PROMPT_FILE" \
@@ -132,7 +132,7 @@ while read -e line; do
     # HACK get num tokens from debug message
     # TODO get both messages in one go
     if  ! session_size_msg="$(tail -n30 "$LOG" | grep -oE "$SESSION_SIZE_MSG_PATTERN")" ||
-        ! sample_time_msg="$( tail -n10 "$LOG" | grep -oE "$SAMPLE_TIME_MSG_PATTERN")"; then
+        ! sample_time_msg="$(tail -n10 "$LOG" | grep -oE "$SAMPLE_TIME_MSG_PATTERN")"; then
         echo >&2 "Couldn't get number of tokens from ./main output!"
         exit 1
     fi
index 3a4ed3f7814f8ef5de86853f7a33cda4bfcae8dc..7367ae3625ce2220b0be9f505b6bdd6f63c10e97 100644 (file)
@@ -543,6 +543,9 @@ int main(int argc, char ** argv) {
                 if (i > 0) {
                     embd.erase(embd.begin(), embd.begin() + i);
                 }
+
+                // remove any "future" tokens that we might have inherited from the session from the KV cache
+                llama_kv_cache_tokens_rm(ctx, n_past, -1);
             }
 
             // evaluate tokens in batches
index 0434ded234b183cbd22fb23089e403664dc91518..ffd7b1db4abddd2d4b53ce955cd5f0591a2dd520 100644 (file)
@@ -332,7 +332,7 @@ int main(int argc, char ** argv) {
                     }
 
                     // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-                    llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, n_ctx);
+                    llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1);
 
                     const auto t_main_end = ggml_time_us();
 
index 6dda5e36b74388065e753e03515a5d2a2607f21d..921eb5da4812d0d2b45d61d7f9afef2bfc04fd87 100644 (file)
@@ -448,7 +448,7 @@ struct llama_server_context
         n_past = common_part(embd, prompt_tokens);
 
         // since #3228 we now have to manually manage the KV cache
-        llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx);
+        llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
 
         embd = prompt_tokens;
         if (n_past == num_prompt_tokens)
index c5e5b234f0f5cc45037c041a69e81912e1c7ce4c..75a2e5e22d04645ba499a6a8de845d325c44ee13 100644 (file)
@@ -172,7 +172,7 @@ int main(int argc, char ** argv) {
                 LOG("out of drafted tokens\n");
             }
 
-            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx);
+            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
             llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0));
             ++n_past_dft;
 
@@ -257,7 +257,7 @@ int main(int argc, char ** argv) {
             }
 
             // evaluate the drafted token on the draft model
-            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx);
+            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, -1);
             llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0));
             ++n_past_cur;
 
@@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
         }
 
         // evaluate the target model on the drafted tokens
-        llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx);
+        llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, -1);
         llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0));
         ++n_past_tgt;
 
index aa1b4732c7c0ca56623ab3d9c2c78e61d6a86894..a40da68391853c62d183c3c0d26806dd1bda5461 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1283,8 +1283,8 @@ static bool llama_kv_cache_init(
 // find an empty slot of size "n_tokens" in the cache
 // updates the cache head
 static bool llama_kv_cache_find_slot(
-             struct llama_kv_cache & cache,
-          const struct llama_batch & batch) {
+           struct llama_kv_cache & cache,
+        const struct llama_batch & batch) {
     const uint32_t n_ctx    = cache.size;
     const uint32_t n_tokens = batch.n_tokens;
 
@@ -1352,10 +1352,13 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0,
 }
 
 static void llama_kv_cache_seq_rm(
-             struct llama_kv_cache & cache,
-                      llama_seq_id   seq_id,
-                         llama_pos   p0,
-                         llama_pos   p1) {
+        struct llama_kv_cache & cache,
+                 llama_seq_id   seq_id,
+                    llama_pos   p0,
+                    llama_pos   p1) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].seq_id.erase(seq_id);
@@ -1367,11 +1370,14 @@ static void llama_kv_cache_seq_rm(
 }
 
 static void llama_kv_cache_seq_cp(
-             struct llama_kv_cache & cache,
-                      llama_seq_id   seq_id_src,
-                      llama_seq_id   seq_id_dst,
-                         llama_pos   p0,
-                         llama_pos   p1) {
+        struct llama_kv_cache & cache,
+                 llama_seq_id   seq_id_src,
+                 llama_seq_id   seq_id_dst,
+                    llama_pos   p0,
+                    llama_pos   p1) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].seq_id.insert(seq_id_dst);
@@ -1389,11 +1395,14 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
 }
 
 static void llama_kv_cache_seq_shift(
-             struct llama_kv_cache & cache,
-                      llama_seq_id   seq_id,
-                         llama_pos   p0,
-                         llama_pos   p1,
-                         llama_pos   delta) {
+        struct llama_kv_cache & cache,
+                 llama_seq_id   seq_id,
+                    llama_pos   p0,
+                    llama_pos   p1,
+                    llama_pos   delta) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].pos += delta;
@@ -7209,16 +7218,6 @@ struct llama_data_file_context : llama_data_context {
  *
 */
 static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
-    // TODO: does not support multi-sequence states
-    {
-        const auto & kv_self = ctx->kv_self;
-        for (uint32_t i = 0; i < kv_self.head; ++i) {
-            GGML_ASSERT(kv_self.cells[i].pos == (int32_t) i);
-            GGML_ASSERT(kv_self.cells[i].seq_id.size() == 1);
-            GGML_ASSERT(kv_self.cells[i].has_seq_id(0));
-        }
-    }
-
     // copy rng
     {
         std::stringstream rng_ss;
@@ -7271,36 +7270,38 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
         const auto & hparams = ctx->model.hparams;
         const auto & cparams = ctx->cparams;
 
-        const int    n_layer = hparams.n_layer;
-        const int    n_embd  = hparams.n_embd_gqa();
-        const int    n_ctx   = cparams.n_ctx;
+        const auto   n_layer = hparams.n_layer;
+        const auto   n_embd  = hparams.n_embd_gqa();
+        const auto   n_ctx   = cparams.n_ctx;
 
-        const size_t kv_size = kv_self.buf.size;
-        const int    kv_ntok = kv_self.head;
+        const size_t   kv_buf_size = kv_self.buf.size;
+        const uint32_t kv_head     = kv_self.head;
+        const uint32_t kv_size     = kv_self.size;
 
-        data_ctx->write(&kv_size, sizeof(kv_size));
-        data_ctx->write(&kv_ntok, sizeof(kv_ntok));
+        data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
+        data_ctx->write(&kv_head,     sizeof(kv_head));
+        data_ctx->write(&kv_size,     sizeof(kv_size));
 
-        if (kv_size) {
+        if (kv_buf_size) {
             const size_t elt_size = ggml_element_size(kv_self.k);
 
             ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
             ggml_cgraph gf{};
 
-            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
             std::vector<uint8_t> kout3d_data(ggml_nbytes(kout3d), 0);
             kout3d->data = kout3d_data.data();
 
-            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
             std::vector<uint8_t> vout3d_data(ggml_nbytes(vout3d), 0);
             vout3d->data = vout3d_data.data();
 
             ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
-                n_embd, kv_ntok, n_layer,
+                n_embd, kv_head, n_layer,
                 elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
 
             ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
-                kv_ntok, n_embd, n_layer,
+                kv_head, n_embd, n_layer,
                 elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
 
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
@@ -7314,6 +7315,20 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
             data_ctx->write(kout3d_data.data(), kout3d_data.size());
             data_ctx->write(vout3d_data.data(), vout3d_data.size());
         }
+
+        for (uint32_t i = 0; i < kv_size; ++i) {
+            const auto & cell = kv_self.cells[i];
+
+            const llama_pos pos         = cell.pos;
+            const size_t    seq_id_size = cell.seq_id.size();
+
+            data_ctx->write(&pos,         sizeof(pos));
+            data_ctx->write(&seq_id_size, sizeof(seq_id_size));
+
+            for (auto seq_id : cell.seq_id) {
+                data_ctx->write(&seq_id, sizeof(seq_id));
+            }
+        }
     }
 }
 
@@ -7385,34 +7400,36 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
         const int    n_embd  = hparams.n_embd_gqa();
         const int    n_ctx   = cparams.n_ctx;
 
-        size_t kv_size;
-        int kv_ntok;
+        size_t   kv_buf_size;
+        uint32_t kv_head;
+        uint32_t kv_size;
 
-        memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
-        memcpy(&kv_ntok, inp, sizeof(kv_ntok)); inp += sizeof(kv_ntok);
+        memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
+        memcpy(&kv_head,     inp, sizeof(kv_head));     inp += sizeof(kv_head);
+        memcpy(&kv_size,     inp, sizeof(kv_size));     inp += sizeof(kv_size);
 
-        if (kv_size) {
-            GGML_ASSERT(kv_self.buf.size == kv_size);
+        if (kv_buf_size) {
+            GGML_ASSERT(kv_self.buf.size == kv_buf_size);
 
             const size_t elt_size = ggml_element_size(kv_self.k);
 
             ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
             ggml_cgraph gf{};
 
-            ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
             kin3d->data = (void *) inp;
             inp += ggml_nbytes(kin3d);
 
-            ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
             vin3d->data = (void *) inp;
             inp += ggml_nbytes(vin3d);
 
             ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
-                n_embd, kv_ntok, n_layer,
+                n_embd, kv_head, n_layer,
                 elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
 
             ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
-                kv_ntok, n_embd, n_layer,
+                kv_head, n_embd, n_layer,
                 elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
 
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
@@ -7422,8 +7439,27 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
             ggml_free(cpy_ctx);
         }
 
-        ctx->kv_self.head = kv_ntok;
+        ctx->kv_self.head = kv_head;
         ctx->kv_self.size = kv_size;
+
+        ctx->kv_self.cells.resize(kv_size);
+
+        for (uint32_t i = 0; i < kv_size; ++i) {
+            llama_pos pos;
+            size_t    seq_id_size;
+
+            memcpy(&pos,         inp, sizeof(pos));         inp += sizeof(pos);
+            memcpy(&seq_id_size, inp, sizeof(seq_id_size)); inp += sizeof(seq_id_size);
+
+            ctx->kv_self.cells[i].pos = pos;
+
+            llama_seq_id seq_id;
+
+            for (size_t j = 0; j < seq_id_size; ++j) {
+                memcpy(&seq_id, inp, sizeof(seq_id)); inp += sizeof(seq_id);
+                ctx->kv_self.cells[i].seq_id.insert(seq_id);
+            }
+        }
     }
 
     const size_t nread    = inp - src;
diff --git a/llama.h b/llama.h
index 0177d07a9104406ec242ece5ee20c5d3fedd3695..a78015adab30c37179a87249684df66963d62204 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -42,7 +42,7 @@
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 1
+#define LLAMA_SESSION_VERSION 2
 
 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
 // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
@@ -333,12 +333,16 @@ extern "C" {
             "avoid using this, it will be removed in the future, instead - count the tokens in user code");
 
     // Remove all tokens data of cells in [c0, c1)
+    // c0 < 0 : [0,  c1]
+    // c1 < 0 : [c0, inf)
     LLAMA_API void llama_kv_cache_tokens_rm(
             struct llama_context * ctx,
                          int32_t   c0,
                          int32_t   c1);
 
     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+    // p0 < 0 : [0,  p1]
+    // p1 < 0 : [p0, inf)
     LLAMA_API void llama_kv_cache_seq_rm(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,
@@ -347,6 +351,8 @@ extern "C" {
 
     // Copy all tokens that belong to the specified sequence to another sequence
     // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
+    // p0 < 0 : [0,  p1]
+    // p1 < 0 : [p0, inf)
     LLAMA_API void llama_kv_cache_seq_cp(
             struct llama_context * ctx,
                     llama_seq_id   seq_id_src,
@@ -361,6 +367,8 @@ extern "C" {
 
     // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
     // If the KV cache is RoPEd, the KV data is updated accordingly
+    // p0 < 0 : [0,  p1]
+    // p1 < 0 : [p0, inf)
     LLAMA_API void llama_kv_cache_seq_shift(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,