]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : KV cache view API + better KV cache management (#4170)
authorGeorgi Gerganov <redacted>
Thu, 23 Nov 2023 17:07:56 +0000 (19:07 +0200)
committerGitHub <redacted>
Thu, 23 Nov 2023 17:07:56 +0000 (19:07 +0200)
* llama : keep track of used KV cells + better KV cache management

* llama : zero KV cache used upon clear

ggml-ci

* llama : allow exporting a view of the KV cache (#4180)

* Allow exporting a view of the KV cache

* Allow dumping the sequences per cell in common

* Track max contiguous cells value and position as well

* Fix max contiguous empty cells index calculation

Make dump functions deal with lengths or sequences counts > 10 better

* Fix off by one error in dump_kv_cache_view

* Add doc comments for KV cache view functions

Eliminate cell sequence struct; use llama_seq_id directly

Minor cleanups

* common : add -dkvc arg for enabling kv cache dumps

---------

Co-authored-by: Kerfuffle <redacted>
common/common.cpp
common/common.h
examples/parallel/parallel.cpp
llama.cpp
llama.h

index eec704b99f888ec7720e49aaa754678e768c93c3..1dcc235eac0e6b4d78f57b1f0051e1d6eb92ecef 100644 (file)
@@ -12,6 +12,7 @@
 #include <regex>
 #include <sstream>
 #include <string>
+#include <unordered_map>
 #include <unordered_set>
 #include <vector>
 #include <cinttypes>
@@ -495,6 +496,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             params.chatml = true;
         } else if (arg == "--infill") {
             params.infill = true;
+        } else if (arg == "-dkvc" || arg == "--dump-kv-cache") {
+            params.dump_kv_cache = true;
         } else if (arg == "--multiline-input") {
             params.multiline_input = true;
         } else if (arg == "--simple-io") {
@@ -835,6 +838,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
 #endif // GGML_USE_CUBLAS
 #endif
     printf("  --verbose-prompt      print prompt before generation\n");
+    printf("  -dkvc, --dump-kv-cache\n");
+    printf("                        verbose print of the KV cache\n");
     printf("  --simple-io           use basic IO for better compatibility in subprocesses and limited consoles\n");
     printf("  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
     printf("  --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
@@ -1386,3 +1391,77 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
     fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
 }
+
+//
+// KV cache utils
+//
+
+void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) {
+    static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
+
+    printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
+        view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
+
+    llama_kv_cache_view_cell * c_curr = view.cells;
+    llama_seq_id * cs_curr = view.cells_sequences;
+
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
+        if (i % row_size == 0) {
+            printf("\n%5d: ", i);
+        }
+        int seq_count = 0;
+        for (int j = 0; j < view.n_max_seq; j++) {
+            if (cs_curr[j] >= 0) { seq_count++; }
+        }
+        putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
+    }
+
+    printf("\n=== Done dumping\n");
+}
+
+void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
+    static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
+
+    printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
+        view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
+
+    std::unordered_map<llama_seq_id, size_t> seqs;
+    llama_kv_cache_view_cell * c_curr = view.cells;
+    llama_seq_id * cs_curr = view.cells_sequences;
+
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
+        for (int j = 0; j < view.n_max_seq; j++) {
+            if (cs_curr[j] < 0) { continue; }
+            if (seqs.find(cs_curr[j]) == seqs.end()) {
+                if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
+                seqs[cs_curr[j]] = seqs.size();
+            }
+        }
+        if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
+    }
+
+    printf("=== Sequence legend: ");
+    for (const auto & it : seqs) {
+        printf("%zu=%d, ", it.second, it.first);
+    }
+    printf("'+'=other sequence ids");
+
+    c_curr = view.cells;
+    cs_curr = view.cells_sequences;
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
+        if (i % row_size == 0) {
+            printf("\n%5d: ", i);
+        }
+        for (int j = 0; j < view.n_max_seq; j++) {
+            if (cs_curr[j] >= 0) {
+                const auto & it = seqs.find(cs_curr[j]);
+                putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');
+            } else {
+                putchar('.');
+            }
+        }
+        putchar(' ');
+    }
+
+    printf("\n=== Done dumping\n");
+}
index 88fa13fc067c2b1553c532ff5d05a6d163829106..2f6fe48ab53d3527df02f2589e7ca3d63d4117c8 100644 (file)
@@ -122,6 +122,7 @@ struct gpt_params {
     bool numa              = false; // attempt optimizations that help on some NUMA systems
     bool verbose_prompt    = false; // print prompt tokens before generation
     bool infill            = false; // use infill mode
+    bool dump_kv_cache     = false; // dump the KV cache contents for debugging purposes
 
     // multimodal models (see examples/llava)
     std::string mmproj = ""; // path to multimodal projector
@@ -218,3 +219,13 @@ std::string get_sortable_timestamp();
 void dump_non_result_info_yaml(
     FILE * stream, const gpt_params & params, const llama_context * lctx,
     const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
+
+//
+// KV cache utils
+//
+
+// Dump the KV cache view with the number of sequences per cell.
+void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
+
+// Dump the KV cache view showing individual sequences in each cell (long output).
+void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
index 9b89bdfec78cef6a04b4e068281826803124a433..d2e074d9e12b0a3c1539a9bdd7f5f22310826988 100644 (file)
@@ -113,6 +113,8 @@ int main(int argc, char ** argv) {
     // insert new requests as soon as the previous one is done
     const bool cont_batching = params.cont_batching;
 
+    const bool dump_kv_cache = params.dump_kv_cache;
+
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("parallel", "log"));
     LOG_TEE("Log start\n");
@@ -172,6 +174,8 @@ int main(int argc, char ** argv) {
     int32_t n_total_gen    = 0;
     int32_t n_cache_miss   = 0;
 
+    struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, n_clients);
+
     const auto t_main_start = ggml_time_us();
 
     LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__);
@@ -201,6 +205,11 @@ int main(int argc, char ** argv) {
     LOG_TEE("Processing requests ...\n\n");
 
     while (true) {
+        if (dump_kv_cache) {
+            llama_kv_cache_view_update(ctx, &kvc_view);
+            dump_kv_cache_view_seqs(kvc_view, 40);
+        }
+
         llama_batch_clear(batch);
 
         // decode any currently ongoing sequences
index c2ad0486994727add8f6b1e7aa5441855bdef8a9..9fb7244b41cf52b319e6c635c65e738fdfb207b2 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1280,6 +1280,7 @@ struct llama_kv_cache {
     // cannot be freely changed after a slot has been allocated.
     uint32_t head = 0;
     uint32_t size = 0;
+    uint32_t used = 0; // used cells (i.e. at least one seq_id)
 
     // computed before each graph build
     uint32_t n = 0;
@@ -1504,6 +1505,7 @@ static bool llama_kv_cache_init(
 
     cache.head = 0;
     cache.size = n_ctx;
+    cache.used = 0;
 
     cache.cells.clear();
     cache.cells.resize(n_ctx);
@@ -1605,6 +1607,8 @@ static bool llama_kv_cache_find_slot(
         }
     }
 
+    cache.used += n_tokens;
+
     return true;
 }
 
@@ -1625,6 +1629,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
         cache.cells[i].seq_id.clear();
     }
     cache.head = 0;
+    cache.used = 0;
 }
 
 static void llama_kv_cache_seq_rm(
@@ -1647,6 +1652,9 @@ static void llama_kv_cache_seq_rm(
                 continue;
             }
             if (cache.cells[i].seq_id.empty()) {
+                // keep count of the number of used cells
+                if (cache.cells[i].pos >= 0) cache.used--;
+
                 cache.cells[i].pos = -1;
                 if (new_head == cache.size) new_head = i;
             }
@@ -1654,7 +1662,7 @@ static void llama_kv_cache_seq_rm(
     }
 
     // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cache.size) cache.head = new_head;
+    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
 }
 
 static void llama_kv_cache_seq_cp(
@@ -1680,6 +1688,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
 
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (!cache.cells[i].has_seq_id(seq_id)) {
+            if (cache.cells[i].pos >= 0) cache.used--;
             cache.cells[i].pos = -1;
             cache.cells[i].seq_id.clear();
             if (new_head == cache.size) new_head = i;
@@ -1690,7 +1699,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
     }
 
     // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cache.size) cache.head = new_head;
+    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
 }
 
 static void llama_kv_cache_seq_shift(
@@ -1711,6 +1720,7 @@ static void llama_kv_cache_seq_shift(
             cache.cells[i].delta += delta;
 
             if (cache.cells[i].pos < 0) {
+                if (!cache.cells[i].seq_id.empty()) cache.used--;
                 cache.cells[i].pos = -1;
                 cache.cells[i].seq_id.clear();
                 if (new_head == cache.size) new_head = i;
@@ -5469,6 +5479,12 @@ static int llama_decode_internal(
         batch.seq_id = seq_id_arr.data();
     }
 
+    // if we have enough unused cells before the current head ->
+    //   better to start searching from the beginning of the cache, hoping to fill it
+    if (kv_self.head > kv_self.used + 2*n_tokens) {
+        kv_self.head = 0;
+    }
+
     if (!llama_kv_cache_find_slot(kv_self, batch)) {
         return 1;
     }
@@ -5479,7 +5495,7 @@ static int llama_decode_internal(
     //kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32));   // TODO: this might be better for CUDA?
     kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));
 
-    //printf("kv_self.n = %d\n", kv_self.n);
+    //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
     ggml_allocr_reset(lctx.alloc);
 
@@ -8789,8 +8805,107 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
     }
 }
 
+struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) {
+    struct llama_kv_cache_view result = {
+        /*.n_cells            = */ 0,
+        /*.n_max_seq          = */ n_max_seq,
+        /*.token_count        = */ 0,
+        /*.used_cells         = */ llama_get_kv_cache_used_cells(ctx),
+        /*.max_contiguous     = */ 0,
+        /*.max_contiguous_idx = */ -1,
+        /*.cells              = */ nullptr,
+        /*.cells_sequences    = */ nullptr,
+    };
+    return result;
+}
+
+void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
+    if (view->cells != nullptr) {
+        free(view->cells);
+        view->cells = nullptr;
+    }
+    if (view->cells_sequences != nullptr) {
+        free(view->cells_sequences);
+        view->cells_sequences = nullptr;
+    }
+}
+
+void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
+    if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
+        view->n_cells = int32_t(ctx->kv_self.size);
+        void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
+        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
+        view->cells = (struct llama_kv_cache_view_cell *)p;
+        p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_max_seq * view->n_cells);
+        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
+        view->cells_sequences = (llama_seq_id *)p;
+    }
+
+    const std::vector<llama_kv_cell> & kv_cells = ctx->kv_self.cells;
+    llama_kv_cache_view_cell * c_curr = view->cells;
+    llama_seq_id * cs_curr = view->cells_sequences;
+    int32_t used_cells = 0;
+    int32_t token_count = 0;
+    int32_t curr_contig_idx = -1;
+    uint32_t max_contig = 0;
+    int32_t max_contig_idx = -1;
+
+    for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_max_seq) {
+        const size_t curr_size = kv_cells[i].seq_id.size();
+        token_count += curr_size;
+        c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
+
+        if (curr_size > 0) {
+            if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
+                max_contig = i - curr_contig_idx;
+                max_contig_idx = curr_contig_idx;
+            }
+            curr_contig_idx = -1;
+        } else if (curr_contig_idx < 0) {
+            curr_contig_idx = i;
+        }
+
+        int seq_idx = 0;
+        for (const llama_seq_id it : kv_cells[i].seq_id) {
+            if (seq_idx >= view->n_max_seq) {
+                break;
+            }
+            cs_curr[seq_idx] = it;
+            seq_idx++;
+        }
+        if (seq_idx != 0) {
+            used_cells++;
+        }
+        for (; seq_idx < view->n_max_seq; seq_idx++) {
+            cs_curr[seq_idx] = -1;
+        }
+    }
+    if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
+        max_contig_idx = curr_contig_idx;
+        max_contig = kv_cells.size() - curr_contig_idx;
+    }
+    view->max_contiguous = max_contig;
+    view->max_contiguous_idx = max_contig_idx;
+    view->token_count = token_count;
+    view->used_cells = used_cells;
+    if (uint32_t(used_cells) != ctx->kv_self.used) {
+        LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
+            __func__, ctx->kv_self.used, used_cells);
+    }
+}
+
 int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
-    return ctx->kv_self.head;
+    int result = 0;
+
+    for (uint32_t i = 0; i < ctx->kv_self.size; i++) {
+        result += ctx->kv_self.cells[i].seq_id.size();
+    }
+
+    return result;
+}
+
+int llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
+    return ctx->kv_self.used;
 }
 
 void llama_kv_cache_clear(struct llama_context * ctx) {
@@ -8960,10 +9075,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
         const size_t   kv_buf_size = kv_self.buf.size;
         const uint32_t kv_head     = kv_self.head;
         const uint32_t kv_size     = kv_self.size;
+        const uint32_t kv_used     = kv_self.used;
 
         data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
         data_ctx->write(&kv_head,     sizeof(kv_head));
         data_ctx->write(&kv_size,     sizeof(kv_size));
+        data_ctx->write(&kv_used,     sizeof(kv_used));
 
         if (kv_buf_size) {
             const size_t elt_size = ggml_element_size(kv_self.k);
@@ -9086,10 +9203,12 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
         size_t   kv_buf_size;
         uint32_t kv_head;
         uint32_t kv_size;
+        uint32_t kv_used;
 
         memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
         memcpy(&kv_head,     inp, sizeof(kv_head));     inp += sizeof(kv_head);
         memcpy(&kv_size,     inp, sizeof(kv_size));     inp += sizeof(kv_size);
+        memcpy(&kv_used,     inp, sizeof(kv_used));     inp += sizeof(kv_used);
 
         if (kv_buf_size) {
             GGML_ASSERT(kv_self.buf.size == kv_buf_size);
@@ -9124,6 +9243,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
 
         ctx->kv_self.head = kv_head;
         ctx->kv_self.size = kv_size;
+        ctx->kv_self.used = kv_used;
 
         ctx->kv_self.cells.resize(kv_size);
 
diff --git a/llama.h b/llama.h
index 70e8fda4bf1b3522b35dc14ea11120824a0095a8..1a62058d1406bc32e9a97f1dc05c53038ad2ecf5 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -361,9 +361,60 @@ extern "C" {
     // KV cache
     //
 
-    // Returns the number of tokens in the KV cache
-    LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
-            "avoid using this, it will be removed in the future, instead - count the tokens in user code");
+    // Information associated with an individual cell in the KV cache view.
+    struct llama_kv_cache_view_cell {
+        // The position for this cell. Takes KV cache shifts into account.
+        // May be negative if the cell is not populated.
+        llama_pos pos;
+    };
+
+    // An updateable view of the KV cache.
+    struct llama_kv_cache_view {
+        // Number of KV cache cells. This will be the same as the context size.
+        int32_t n_cells;
+
+        // Maximum number of sequences that can exist in a cell. It's not an error
+        // if there are more sequences in a cell than this value, however they will
+        // not be visible in the view cells_sequences.
+        int32_t n_max_seq;
+
+        // Number of tokens in the cache. For example, if there are two populated
+        // cells, the first with 1 sequence id in it and the second with 2 sequence
+        // ids then you'll have 3 tokens.
+        int32_t token_count;
+
+        // Number of populated cache cells.
+        int32_t used_cells;
+
+        // Maximum contiguous empty slots in the cache.
+        int32_t max_contiguous;
+
+        // Index to the start of the max_contiguous slot range. Can be negative
+        // when cache is full.
+        int32_t max_contiguous_idx;
+
+        // Information for an individual cell.
+        struct llama_kv_cache_view_cell * cells;
+
+        // The sequences for each cell. There will be n_max_seq items per cell.
+        llama_seq_id * cells_sequences;
+    };
+
+    // Create an empty KV cache view. (use only for debugging purposes)
+    LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);
+
+    // Free a KV cache view. (use only for debugging purposes)
+    LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
+
+    // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
+    LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
+
+    // Returns the number of tokens in the KV cache (slow, use only for debug)
+    // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
+    LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
+
+    // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
+    LLAMA_API int llama_get_kv_cache_used_cells(const struct llama_context * ctx);
 
     // Clear the KV cache
     LLAMA_API void llama_kv_cache_clear(