]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Extend llama_kv_cache_seq_rm to allow matching any sequence (#3843)
authorKerfuffle <redacted>
Sun, 29 Oct 2023 17:31:40 +0000 (11:31 -0600)
committerGitHub <redacted>
Sun, 29 Oct 2023 17:31:40 +0000 (11:31 -0600)
* Extend llama_kv_cache_seq_rm to allow matichng any sequence

* Replace llama_kv_cache_tokens_rm with llama_kv_cache_clear

Use llama_kv_cache_clear for cache clearing

Change calls to llama_kv_cache_tokens_rm that want to delete by position to use llama_kv_cache_seq_rm functionality

common/common.cpp
examples/batched-bench/batched-bench.cpp
examples/llama-bench/llama-bench.cpp
examples/main/main.cpp
examples/perplexity/perplexity.cpp
examples/server/server.cpp
llama.cpp
llama.h

index f81f4d354bc01755c0de2af0097e28888ae0a3c1..c187128d6ede3d576611d9c568385155072326f4 100644 (file)
@@ -889,7 +889,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 
         std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
-        llama_kv_cache_tokens_rm(lctx, -1, -1);
+        llama_kv_cache_clear(lctx);
         llama_reset_timings(lctx);
     }
 
index 43f9c971d18465fa7274b6bb40566eef8e63623e..533c55c17aad17da651014866a30b6d63a07ec3a 100644 (file)
@@ -185,7 +185,7 @@ int main(int argc, char ** argv) {
 
                 const auto t_pp_start = ggml_time_us();
 
-                llama_kv_cache_tokens_rm(ctx, -1, -1);
+                llama_kv_cache_clear(ctx);
 
                 if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
                     LOG_TEE("%s: llama_decode() failed\n", __func__);
index 20767d555b206d08d62f3641e39405d6b297dcef..780398184d2215929b005f2ccc919a42608c1445 100644 (file)
@@ -1037,7 +1037,7 @@ int main(int argc, char ** argv) {
 
         test t(inst, lmodel, ctx);
 
-        llama_kv_cache_tokens_rm(ctx, -1, -1);
+        llama_kv_cache_clear(ctx);
 
         // warmup run
         if (t.n_prompt > 0) {
@@ -1048,7 +1048,7 @@ int main(int argc, char ** argv) {
         }
 
         for (int i = 0; i < params.reps; i++) {
-            llama_kv_cache_tokens_rm(ctx, -1, -1);
+            llama_kv_cache_clear(ctx);
 
             uint64_t t_start = get_time_ns();
             if (t.n_prompt > 0) {
index 3d9f670b9da7f2b2eedb099f55fae5c6cf8e416a..8a43b6ab878a5f930ecf663508a2e00c7a8a0317 100644 (file)
@@ -298,7 +298,7 @@ int main(int argc, char ** argv) {
         }
 
         // remove any "future" tokens that we might have inherited from the previous session
-        llama_kv_cache_tokens_rm(ctx, n_matching_session_tokens, -1);
+        llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
     }
 
     LOGLN(
index 3c2542e8c105e73ff41059003a2321d3080d537d..bd2c73d87875fe321ec79bbd7bfd8517d4ec5402 100644 (file)
@@ -210,7 +210,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
-        llama_kv_cache_tokens_rm(ctx, -1, -1);
+        llama_kv_cache_clear(ctx);
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
@@ -339,7 +339,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
-        llama_kv_cache_tokens_rm(ctx, -1, -1);
+        llama_kv_cache_clear(ctx);
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
@@ -573,7 +573,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         }
 
         // clear the KV cache
-        llama_kv_cache_tokens_rm(ctx, -1, -1);
+        llama_kv_cache_clear(ctx);
 
         auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
         if (logits.empty()) {
index 5b7e4139de551e2d42b9d24241a4e3c4ec9fd42b..c163c7f8ec0dd12c306794ca8128c0184c997eb7 100644 (file)
@@ -857,7 +857,7 @@ struct llama_server_context
 
     void kv_cache_clear() {
         // clear the entire KV cache
-        llama_kv_cache_tokens_rm(ctx, -1, -1);
+        llama_kv_cache_clear(ctx);
         clean_kv_cache = false;
     }
 
index d8510a5cf01254f638cb6dda404eac6ccca7a68b..a4340d5277b09c0c63185f5e09ae9044295e07ac 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1466,17 +1466,12 @@ static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
     return 0;
 }
 
-static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
-    if (c0 < 0) c0 = 0;
-    if (c1 < 0) c1 = cache.size;
-
-    for (int32_t i = c0; i < c1; ++i) {
+static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
+    for (int32_t i = 0; i < cache.size; ++i) {
         cache.cells[i].pos = -1;
         cache.cells[i].seq_id.clear();
     }
-
-    // Searching for a free slot can start here since we know it will be empty.
-    cache.head = uint32_t(c0);
+    cache.head = 0;
 }
 
 static void llama_kv_cache_seq_rm(
@@ -1490,8 +1485,14 @@ static void llama_kv_cache_seq_rm(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
     for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-            cache.cells[i].seq_id.erase(seq_id);
+        if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+            if (seq_id < 0) {
+                cache.cells[i].seq_id.clear();
+            } else if (cache.cells[i].has_seq_id(seq_id)) {
+                cache.cells[i].seq_id.erase(seq_id);
+            } else {
+                continue;
+            }
             if (cache.cells[i].seq_id.empty()) {
                 cache.cells[i].pos = -1;
                 if (new_head == cache.size) new_head = i;
@@ -9207,8 +9208,8 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
     return ctx->kv_self.head;
 }
 
-void llama_kv_cache_tokens_rm(struct llama_context * ctx, int32_t c0, int32_t c1) {
-    llama_kv_cache_tokens_rm(ctx->kv_self, c0, c1);
+void llama_kv_cache_clear(struct llama_context * ctx) {
+    llama_kv_cache_clear(ctx->kv_self);
 }
 
 void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
@@ -9654,7 +9655,7 @@ int llama_eval(
                  llama_token * tokens,
                      int32_t   n_tokens,
                          int   n_past) {
-    llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
+    llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1);
 
     const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0));
     if (ret < 0) {
@@ -9669,7 +9670,7 @@ int llama_eval_embd(
                            float * embd,
                          int32_t   n_tokens,
                              int   n_past) {
-    llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
+    llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1);
 
     llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
 
diff --git a/llama.h b/llama.h
index 6927bd6010dd7a32ab8187564129c7451203ed3c..d727dbd9fd915dbe4dd46c6cc6c77c5536e93431 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -334,17 +334,14 @@ extern "C" {
     LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
             "avoid using this, it will be removed in the future, instead - count the tokens in user code");
 
-    // Remove all tokens data of cells in [c0, c1)
-    // c0 < 0 : [0,  c1]
-    // c1 < 0 : [c0, inf)
-    LLAMA_API void llama_kv_cache_tokens_rm(
-            struct llama_context * ctx,
-                         int32_t   c0,
-                         int32_t   c1);
+    // Clear the KV cache
+    LLAMA_API void llama_kv_cache_clear(
+            struct llama_context * ctx);
 
     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
-    // p0 < 0 : [0,  p1]
-    // p1 < 0 : [p0, inf)
+    // seq_id < 0 : match any sequence
+    // p0 < 0     : [0,  p1]
+    // p1 < 0     : [p0, inf)
     LLAMA_API void llama_kv_cache_seq_rm(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,