From: Georgi Gerganov
Date: Fri, 5 Sep 2025 07:39:22 +0000 (+0300)
Subject: kv-cache : fix SWA checks + disable cacheless iSWA (#15811)
X-Git-Tag: upstream/0.0.6527~138
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=c610b6c11b1ef7d678671dcf15acd7187a7ad8f3;p=pkg%2Fggml%2Fsources%2Fllama.cpp

kv-cache : fix SWA checks + disable cacheless iSWA (#15811)

ggml-ci
---

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 7ce2960e..4abb6008 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -297,6 +297,9 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
     float * data = (float *) kq_mask->data;
 
+    // [TAG_NO_CACHE_ISWA]
+    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
+
     for (int h = 0; h < 1; ++h) {
         for (int i1 = 0; i1 < n_tokens; ++i1) {
             const llama_seq_id s1 = ubatch->seq_id[i1][0];
@@ -315,9 +318,10 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                     continue; // skip future tokens for causal attention
                 }
 
-                if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
-                    continue; // skip masked tokens for SWA
-                }
+                // TODO: this does not take into account that some layers are SWA and others are not (i.e. iSWA) [TAG_NO_CACHE_ISWA]
+                //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
+                //    continue; // skip masked tokens for SWA
+                //}
 
                 // TODO: reimplement this like in llama_kv_cache_unified
                 if (hparams.use_alibi) {
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 4b7a6536..c04ac58f 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -180,7 +180,7 @@ uint32_t llama_hparams::n_layer_kv() const {
     return res;
 }
 
-bool llama_hparams::is_masked_swa(llama_pos p0, llama_pos p1) const {
+bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
     assert(p0 >= 0 && p1 >= 0);
 
     switch (swa_type) {
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 06d1e51d..89f5c7ab 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -229,7 +229,10 @@ struct llama_hparams {
     // number of layers for which has_kv() returns true
     uint32_t n_layer_kv() const;
 
-    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
+    // note that this function uses different SWA parameters from those in the hparams
+    // TODO: think of a better place for this function
+    // TODO: pack the SWA params in a struct?
+    static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
diff --git a/src/llama-kv-cache-iswa.cpp b/src/llama-kv-cache-iswa.cpp
index 7c51a198..d7342914 100644
--- a/src/llama-kv-cache-iswa.cpp
+++ b/src/llama-kv-cache-iswa.cpp
@@ -60,14 +60,14 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
     kv_base = std::make_unique<llama_kv_cache>(
             model, type_k, type_v,
             v_trans, offload, unified, size_base, n_seq_max, n_pad,
-            0, filter_base, reuse);
+            0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
 
     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache>(
             model, type_k, type_v,
             v_trans, offload, unified, size_swa, n_seq_max, n_pad,
-            hparams.n_swa, filter_swa, reuse);
+            hparams.n_swa, hparams.swa_type, filter_swa, reuse);
 }
 
 void llama_kv_cache_iswa::clear(bool data) {
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 1564faed..ae35f742 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -27,10 +27,11 @@ llama_kv_cache::llama_kv_cache(
                  uint32_t   n_seq_max,
                  uint32_t   n_pad,
                  uint32_t   n_swa,
+          llama_swa_type   swa_type,
     const layer_filter_cb & filter,
     const layer_reuse_cb  & reuse) :
     model(model), hparams(model.hparams), v_trans(v_trans),
-    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa) {
+    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
 
     GGML_ASSERT(kv_size % n_pad == 0);
 
@@ -1392,7 +1393,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
 }
 
 bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
-    return hparams.is_masked_swa(p0, p1);
+    return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
 }
 
 void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index 55ee355b..b545bf5b 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -89,6 +89,7 @@ public:
                  uint32_t   n_seq_max,
                  uint32_t   n_pad,
                  uint32_t   n_swa,
+          llama_swa_type   swa_type,
     const layer_filter_cb & filter,
     const layer_reuse_cb  & reuse);
 
@@ -211,6 +212,9 @@ private:
     // env: LLAMA_KV_CACHE_DEBUG
     int debug = 0;
 
+    // this is the SWA type of the cache - not to be confused with the model SWA type
+    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
+
     std::vector<ggml_context_ptr>        ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp
index 38f92da1..ba61ebaa 100644
--- a/src/llama-memory-hybrid.cpp
+++ b/src/llama-memory-hybrid.cpp
@@ -17,6 +17,7 @@ llama_memory_hybrid::llama_memory_hybrid(
                  uint32_t   kv_size,
                  uint32_t   n_pad,
                  uint32_t   n_swa,
+          llama_swa_type   swa_type,
     /* recurrent */
                 ggml_type   type_r,
                 ggml_type   type_s,
@@ -40,6 +41,7 @@ llama_memory_hybrid::llama_memory_hybrid(
             n_seq_max,
             n_pad,
             n_swa,
+            swa_type,
             filter_attn == nullptr ?
                [&](int32_t il) { return !hparams.is_recurrent(il); }
                : filter_attn,
diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h
index 0eb63f5e..11a35651 100644
--- a/src/llama-memory-hybrid.h
+++ b/src/llama-memory-hybrid.h
@@ -27,6 +27,7 @@ public:
                  uint32_t   kv_size,
                  uint32_t   n_pad,
                  uint32_t   n_swa,
+          llama_swa_type   swa_type,
     /* recurrent */
                 ggml_type   type_r,
                 ggml_type   type_s,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index afc4cb48..1813f06d 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -11084,7 +11084,8 @@ struct llm_build_gemma_embedding_iswa : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_no_cache();
+        // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA]
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -18632,7 +18633,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
-        case LLM_ARCH_GEMMA_EMBEDDING:
+        //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
             {
@@ -18681,6 +18682,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         /* attn_kv_size     */ cparams.n_ctx,
                         /* attn_n_pad       */ padding,
                         /* attn_n_swa       */ hparams.n_swa,
+                        /* attn_swa_type    */ hparams.swa_type,
                         /* recurrent_type_k */ GGML_TYPE_F32,
                         /* recurrent_type_v */ GGML_TYPE_F32,
                         /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
@@ -18750,6 +18752,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         cparams.n_seq_max,
                         padding,
                         hparams.n_swa,
+                        hparams.swa_type,
                         nullptr,
                         nullptr);
                 }
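
Net effect of the patch: llama_hparams::is_masked_swa becomes a static helper that receives the SWA window size and SWA type as explicit arguments, and llama_kv_cache stores its own swa_type, so the base cache of an iSWA pair can pass LLAMA_SWA_TYPE_NONE while the SWA cache passes the model's hparams.swa_type. The sketch below only illustrates the behaviour of such a predicate for the standard sliding-window case; the enum, the function name, and the exact masking rule are illustrative assumptions, not the upstream implementation.

    // sketch_swa_mask.cpp - standalone illustration, not the llama.cpp sources
    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    using llama_pos = int32_t;

    enum llama_swa_type_sketch {
        SWA_NONE,     // full attention: the window never masks anything
        SWA_STANDARD, // a query may only attend to the last n_swa positions
    };

    // p0 = position of the key being attended to, p1 = position of the query
    static bool is_masked_swa_sketch(uint32_t n_swa, llama_swa_type_sketch swa_type, llama_pos p0, llama_pos p1) {
        assert(p0 >= 0 && p1 >= 0);

        switch (swa_type) {
            case SWA_NONE:
                return false;
            case SWA_STANDARD:
                // mask keys that fall outside the query's sliding window
                return p1 - p0 >= (llama_pos) n_swa;
        }

        return false;
    }

    int main() {
        // with a window of 4, the query at position 10 can see positions 7..10 but not 6
        printf("%d\n", is_masked_swa_sketch(4, SWA_STANDARD, 6, 10)); // 1 (masked)
        printf("%d\n", is_masked_swa_sketch(4, SWA_STANDARD, 7, 10)); // 0 (visible)
        return 0;
    }

Passing the window parameters explicitly is what lets each cache apply its own masking rule; the cacheless (no-KV-cache) attention path cannot do this per layer yet, which is why the GGML_ASSERT above rejects SWA there and LLM_ARCH_GEMMA_EMBEDDING is temporarily routed away from the cacheless code path.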