]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix llama_copy_state_data with fragmented KV cache (#5840)
authorcompilade <redacted>
Sun, 3 Mar 2024 08:41:55 +0000 (03:41 -0500)
committerGitHub <redacted>
Sun, 3 Mar 2024 08:41:55 +0000 (10:41 +0200)
The row size of the saved states was based on kv_self.head while
it should be based on llama_kv_cache_cell_max.

Existing session files should still work.

* llama : fix llama_kv_cache_cell_max inability to return 1

I've also changed its return type to uint32_t,
because this function is always used to set the value of uint32_t variables,
and because the index already has this type.

* llama : fix state size calculation

Some bytes in the state were unaccounted for in llama_get_state_size.
Since the logits reserve so much space, it did not cause problems.

llama.cpp

index d4c7a965bf377fd1b5b96135333cf8ed7688c602..41d0000da7f9eeb4d54e2f41b1d19d04715983f6 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -2156,10 +2156,12 @@ static bool llama_kv_cache_find_slot(
 }
 
 // find how many cells are currently in use
-static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
-    for (uint32_t i = cache.size - 1; i > 0; --i) {
-        if (cache.cells[i].pos >= 0 && !cache.cells[i].is_empty()) {
-            return i + 1;
+static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
+    for (uint32_t i = cache.size; i > 0; --i) {
+        const llama_kv_cell & cell = cache.cells[i - 1];
+
+        if (cell.pos >= 0 && !cell.is_empty()) {
+            return i;
         }
     }
 
@@ -8178,7 +8180,7 @@ static int llama_decode_internal(
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
     // if we start defragmenting the cache, the benefit from this will be more important
-    kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+    kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
     //kv_self.n = llama_kv_cache_cell_max(kv_self);
 
     //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
@@ -12615,9 +12617,14 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     const size_t s_logits          = ctx->logits.capacity() * sizeof(float);
     const size_t s_embedding_size  = sizeof(size_t);
     const size_t s_embedding       = ctx->embedding.size() * sizeof(float);
-    const size_t s_kv_size         = sizeof(size_t);
-    const size_t s_kv_ntok         = sizeof(int);
+    const size_t s_kv_buf_size     = sizeof(size_t);
+    const size_t s_kv_head         = sizeof(uint32_t);
+    const size_t s_kv_size         = sizeof(uint32_t);
+    const size_t s_kv_used         = sizeof(uint32_t);
     const size_t s_kv              = ctx->kv_self.total_size();
+    // TODO: assume the max is more than 1 seq_id per KV cell
+    const size_t s_kv_cell         = sizeof(llama_pos) + sizeof(size_t) + sizeof(llama_seq_id);
+    const size_t s_kv_cells        = ctx->kv_self.size * s_kv_cell;
 
     const size_t s_total = (
         + s_rng_size
@@ -12626,9 +12633,12 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
         + s_logits
         + s_embedding_size
         + s_embedding
+        + s_kv_buf_size
+        + s_kv_head
         + s_kv_size
-        + s_kv_ntok
+        + s_kv_used
         + s_kv
+        + s_kv_cells
     );
 
     return s_total;
@@ -12728,15 +12738,13 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
     {
         const auto & kv_self = ctx->kv_self;
         const auto & hparams = ctx->model.hparams;
-        const auto & cparams = ctx->cparams;
 
         const uint32_t n_layer      = hparams.n_layer;
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-        const uint32_t n_ctx        = cparams.n_ctx;
 
         const size_t   kv_buf_size = kv_self.total_size();
-        const uint32_t kv_head     = kv_self.head;
+        const uint32_t kv_head     = llama_kv_cache_cell_max(kv_self);
         const uint32_t kv_size     = kv_self.size;
         const uint32_t kv_used     = kv_self.used;
 
@@ -12756,7 +12764,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
 
                 // v is not contiguous, copy row by row
                 const size_t v_row_size   = ggml_row_size(kv_self.v_l[il]->type, kv_head);
-                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx);
+                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
 
                 tmp_buf.resize(v_row_size);
                 for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
@@ -12766,7 +12774,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
             }
         }
 
-        for (uint32_t i = 0; i < kv_size; ++i) {
+        for (uint32_t i = 0; i < kv_head; ++i) {
             const auto & cell = kv_self.cells[i];
 
             const llama_pos pos         = cell.pos;
@@ -12842,12 +12850,10 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
     {
         const auto & kv_self = ctx->kv_self;
         const auto & hparams = ctx->model.hparams;
-        const auto & cparams = ctx->cparams;
 
         const uint32_t n_layer      = hparams.n_layer;
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-        const uint32_t n_ctx        = cparams.n_ctx;
 
         size_t   kv_buf_size;
         uint32_t kv_head;
@@ -12870,7 +12876,7 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
 
                 // v is not contiguous, copy row by row
                 const size_t v_row_size   = ggml_row_size(kv_self.v_l[il]->type, kv_head);
-                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx);
+                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
 
                 for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
                     ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size);
@@ -12879,13 +12885,15 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
             }
         }
 
+        GGML_ASSERT(kv_self.size == kv_size);
+
         ctx->kv_self.head = kv_head;
         ctx->kv_self.size = kv_size;
         ctx->kv_self.used = kv_used;
 
         ctx->kv_self.cells.resize(kv_size);
 
-        for (uint32_t i = 0; i < kv_size; ++i) {
+        for (uint32_t i = 0; i < kv_head; ++i) {
             llama_pos pos;
             size_t    seq_id_size;
 
@@ -12901,6 +12909,11 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
                 ctx->kv_self.cells[i].seq_id.insert(seq_id);
             }
         }
+
+        for (uint32_t i = kv_head; i < kv_size; ++i) {
+            ctx->kv_self.cells[i].pos = -1;
+            ctx->kv_self.cells[i].seq_id.clear();
+        }
     }
 
     const size_t nread    = inp - src;