]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : less KV padding when FA is off (#7257)
authorGeorgi Gerganov <redacted>
Mon, 13 May 2024 14:15:15 +0000 (17:15 +0300)
committerGitHub <redacted>
Mon, 13 May 2024 14:15:15 +0000 (17:15 +0300)
ggml-ci

llama.cpp

index adbcc07e20fc5079684fbba0b1c75e55a18490db..202bf94c80e144c9c4d3e280197d9b87f00ae6ba 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -2805,6 +2805,11 @@ static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
     cache.do_defrag = true;
 }
 
+static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
+    // the FA kernels require padding to avoid extra runtime boundary checks
+    return cparams.flash_attn ? 256u : 32u;
+}
+
 //
 // model loading and saving
 //
@@ -11510,7 +11515,8 @@ static int llama_decode_internal(
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
                 // if we start defragmenting the cache, the benefit from this will be more important
-                kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
+                const uint32_t pad = llama_kv_cache_get_padding(cparams);
+                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
                 //kv_self.n = llama_kv_cache_cell_max(kv_self);
             }
         }
@@ -15511,6 +15517,11 @@ struct llama_context * llama_new_context_with_model(
         return nullptr;
     }
 
+    if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
     llama_context * ctx = new llama_context(*model);
 
     const auto & hparams = model->hparams;
@@ -15534,7 +15545,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.rope_freq_scale  = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
     // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx            = GGML_PAD(cparams.n_ctx, 256);
+    cparams.n_ctx            = GGML_PAD(cparams.n_ctx, llama_kv_cache_get_padding(cparams));
 
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch          = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
@@ -15579,11 +15590,6 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 
-    if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) {
-        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
-        cparams.flash_attn = false;
-    }
-
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }