]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add check for KV cache shifts (#10401)
authorGeorgi Gerganov <redacted>
Tue, 19 Nov 2024 11:29:26 +0000 (13:29 +0200)
committerGitHub <redacted>
Tue, 19 Nov 2024 11:29:26 +0000 (13:29 +0200)
ggml-ci

common/common.cpp
include/llama.h
src/llama.cpp

index 93037462127d519c77b7b0109110bbdd3fd97dd3..d314523db4c62510842c6d84beaf10a67a7a9355 100644 (file)
@@ -875,6 +875,12 @@ struct common_init_result common_init_from_params(common_params & params) {
         return iparams;
     }
 
+    if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
+        LOG_ERR("%s: KV cache shifting is not supported for this model (--no-context-shift to disable)'\n", __func__);
+        llama_free_model(model);
+        return iparams;
+    }
+
     if (!params.control_vectors.empty()) {
         if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
         if (params.control_vector_layer_end   <= 0) params.control_vector_layer_end   = llama_n_layer(model);
index bc268e7996baa57088b80b4876eec97abb3d619f..90791d5f5ea126a53840030d9e95348739ba4360 100644 (file)
@@ -667,6 +667,9 @@ extern "C" {
     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
     LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
 
+    // Check if the context supports KV cache shifting
+    LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx);
+
     //
     // State / sessions
     //
index 4f31f25b1d355d0246a3d4c389da3c9e42f3d257..c51b36e66042e214553fb9d08e24516dc4c5e598 100644 (file)
@@ -18213,7 +18213,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
 
     // apply K-shift if needed
     if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
-        if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA
+        if (!llama_kv_cache_can_shift(&lctx)) {
             GGML_ABORT("Deepseek2 does not support K-shift");
         }
 
@@ -20462,6 +20462,10 @@ void llama_kv_cache_update(struct llama_context * ctx) {
     llama_kv_cache_update_internal(*ctx);
 }
 
+bool llama_kv_cache_can_shift(struct llama_context * ctx) {
+    return ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
+}
+
 // deprecated
 size_t llama_get_state_size(struct llama_context * ctx) {
     return llama_state_get_size(ctx);