]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Added api for getting/setting the kv_cache (#685)
authorChristian Falch <redacted>
Sun, 2 Apr 2023 10:23:04 +0000 (12:23 +0200)
committerGitHub <redacted>
Sun, 2 Apr 2023 10:23:04 +0000 (12:23 +0200)
The api provides access methods for retrieving the current memory buffer for the kv_cache and its token number.
It also contains a method for setting the kv_cache from a memory buffer.

This makes it possible to load/save history - maybe support --cache-prompt paramater as well?

Co-authored-by: Pavol Rusnak <redacted>
llama.cpp
llama.h

index b0f53ca62b934cb8c42c7afd267cb4df15441396..878907185c4a007983219ba258255a675022d9ca 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1668,6 +1668,33 @@ int llama_model_quantize(
     return 0;
 }
 
+// Returns the KV cache that will contain the context for the
+// ongoing prediction with the model.
+const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
+    return ctx->model.kv_self.buf.data();
+}
+
+// Returns the size of the KV cache
+size_t llama_get_kv_cache_size(struct llama_context * ctx) {
+    return ctx->model.kv_self.buf.size();
+}
+
+int llama_get_kv_cache_token_count(struct llama_context * ctx) {
+    return ctx->model.kv_self.n;
+}
+
+// Sets the KV cache containing the current context for the model
+void llama_set_kv_cache(
+        struct llama_context * ctx,
+               const uint8_t * kv_cache,
+                      size_t   n_size,
+                         int   n_token_count) {
+    // Make sure we have the same kv cache setup
+    LLAMA_ASSERT(ctx->model.kv_self.buf.size() == n_size);
+    memcpy(ctx->model.kv_self.buf.data(), kv_cache, n_size);
+    ctx->model.kv_self.n = n_token_count;
+}
+
 int llama_eval(
         struct llama_context * ctx,
            const llama_token * tokens,
diff --git a/llama.h b/llama.h
index 258de5a94497639d2ed2fbda7ef76320d670147a..04e2bf71cd9c017128b851e370905eb1a8d91a60 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -83,6 +83,23 @@ extern "C" {
             const char * fname_out,
                    int   itype);
 
+    // Returns the KV cache that will contain the context for the
+    // ongoing prediction with the model.
+    LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx);
+
+    // Returns the size of the KV cache
+    LLAMA_API size_t llama_get_kv_cache_size(struct llama_context * ctx);
+
+    // Returns the number of tokens in the KV cache
+    LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx);
+
+    // Sets the KV cache containing the current context for the model
+    LLAMA_API void llama_set_kv_cache(
+            struct llama_context * ctx,
+                   const uint8_t * kv_cache,
+                          size_t   n_size,
+                             int   n_token_count);
+
     // Run the llama inference to obtain the logits and probabilities for the next token.
     // tokens + n_tokens is the provided batch of new tokens to process
     // n_past is the number of tokens to use from previous eval calls