]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : handle KV shift for recurrent models (#10402)
authorGeorgi Gerganov <redacted>
Thu, 21 Nov 2024 08:22:47 +0000 (10:22 +0200)
committerGitHub <redacted>
Thu, 21 Nov 2024 08:22:47 +0000 (10:22 +0200)
ggml-ci

src/llama.cpp

index c51b36e66042e214553fb9d08e24516dc4c5e598..001711037d5d19a12747d9a567fe713cadeab9dd 100644 (file)
@@ -18211,13 +18211,13 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
 static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     bool need_reserve = false;
 
-    // apply K-shift if needed
-    if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
+    if (lctx.kv_self.has_shift) {
         if (!llama_kv_cache_can_shift(&lctx)) {
-            GGML_ABORT("Deepseek2 does not support K-shift");
+            GGML_ABORT("The current context does not support K-shift");
         }
 
-        {
+        // apply K-shift if needed
+        if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
             ggml_backend_sched_reset(lctx.sched.get());
 
             ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
@@ -20463,7 +20463,7 @@ void llama_kv_cache_update(struct llama_context * ctx) {
 }
 
 bool llama_kv_cache_can_shift(struct llama_context * ctx) {
-    return ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
+    return !ctx->kv_self.recurrent && ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
 }
 
 // deprecated