]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : sync llama.cpp
authorGeorgi Gerganov <redacted>
Tue, 27 May 2025 14:08:24 +0000 (17:08 +0300)
committerGeorgi Gerganov <redacted>
Tue, 27 May 2025 15:03:00 +0000 (18:03 +0300)
ggml-ci

18 files changed:
examples/talk-llama/llama-batch.cpp
examples/talk-llama/llama-context.cpp
examples/talk-llama/llama-cparams.cpp
examples/talk-llama/llama-cparams.h
examples/talk-llama/llama-grammar.cpp
examples/talk-llama/llama-graph.cpp
examples/talk-llama/llama-graph.h
examples/talk-llama/llama-hparams.cpp
examples/talk-llama/llama-hparams.h
examples/talk-llama/llama-kv-cache.cpp
examples/talk-llama/llama-kv-cache.h
examples/talk-llama/llama-kv-cells.h [new file with mode: 0644]
examples/talk-llama/llama-memory.h
examples/talk-llama/llama-model.cpp
examples/talk-llama/llama-model.h
examples/talk-llama/llama-sampling.cpp
examples/talk-llama/llama-vocab.cpp
examples/talk-llama/llama.h

index a88b2fe3082c9447c1ed8cf66c90fb802855a959..b98e3256c390d7942644d30545fb3cea8af072f7 100644 (file)
@@ -1,5 +1,6 @@
 #include "llama-batch.h"
 
+#include <cassert>
 #include <cstring>
 #include <algorithm>
 
@@ -281,9 +282,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
     batch = in_batch;
     GGML_ASSERT(batch.n_tokens > 0);
     if (!batch.pos) {
+        assert(p0 >= 0);
         pos.resize(batch.n_tokens);
         for (int32_t i = 0; i < batch.n_tokens; i++) {
-            pos[i] = i + p0;
+            pos[i] = p0 + i;
         }
         batch.pos = pos.data();
     }
index a3b84a6a82e74dd1014f5abb88f6e95f0e3d9fe5..e153351af38093a0b788dfc576f80d8301db1946 100644 (file)
@@ -25,7 +25,11 @@ llama_context::llama_context(
 
     const auto & hparams = model.hparams;
 
-    cparams.n_seq_max        = std::max(1u, params.n_seq_max);
+    cparams.n_seq_max = std::max(1u, params.n_seq_max);
+    if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
+        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
+    }
+
     cparams.n_threads        = params.n_threads;
     cparams.n_threads_batch  = params.n_threads_batch;
     cparams.yarn_ext_factor  = params.yarn_ext_factor;
@@ -93,6 +97,7 @@ llama_context::llama_context(
     }
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+
     cparams.op_offload = params.op_offload;
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
@@ -176,8 +181,9 @@ llama_context::llama_context(
     // init the memory module
     if (!hparams.vocab_only) {
         llama_memory_params params_mem = {
-            /*.type_k =*/ params.type_k,
-            /*.type_v =*/ params.type_v,
+            /*.type_k   =*/ params.type_k,
+            /*.type_v   =*/ params.type_v,
+            /*.swa_full =*/ params.swa_full,
         };
 
         memory.reset(model.create_memory(params_mem, cparams));
@@ -687,12 +693,18 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
+    // TODO: move the validation to the llama_batch_allocr
     if (batch.token) {
         for (int32_t i = 0; i < n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
+
+            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
+                throw -1;
+            }
         }
     }
 
@@ -846,7 +858,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
 int llama_context::decode(llama_batch & inp_batch) {
     if (!memory) {
-        LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
+        LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
         return encode(inp_batch);
     }
 
@@ -855,11 +867,17 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -1;
     }
 
+    if (!inp_batch.pos) {
+        if (inp_batch.seq_id) {
+            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
+            return -1;
+        }
+    }
+
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
     // temporary allocate memory for the input batch if needed
-    // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
 
     const llama_batch & batch = batch_allocr.batch;
 
@@ -875,11 +893,17 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
+    // TODO: move the validation to the llama_batch_allocr
     if (batch.token) {
         for (int64_t i = 0; i < n_tokens_all; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
-                throw std::runtime_error("invalid token");
+                return -1;
+            }
+
+            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
+                return -1;
             }
         }
     }
@@ -947,8 +971,6 @@ int llama_context::decode(llama_batch & inp_batch) {
 
         // find KV slot
         if (!kv_self->find_slot(ubatch)) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
             return 1;
         }
 
@@ -2093,6 +2115,7 @@ llama_context_params llama_context_default_params() {
         /*.flash_attn                  =*/ false,
         /*.no_perf                     =*/ true,
         /*.op_offload                  =*/ true,
+        /*.swa_full                    =*/ true,
     };
 
     return result;
@@ -2287,65 +2310,51 @@ int32_t llama_apply_adapter_cvec(
     return res ? 0 : -1;
 }
 
-//
-// kv cache view
-//
-
-llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
-    const auto * kv = ctx->get_kv_self();
-    if (kv == nullptr) {
-        LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
-        return {};
-    }
-
-    return llama_kv_cache_view_init(*kv, n_seq_max);
-}
-
-void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
-    const auto * kv = ctx->get_kv_self();
-    if (kv == nullptr) {
-        LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
-        return;
-    }
-
-    llama_kv_cache_view_update(view, kv);
-}
-
 //
 // kv cache
 //
 
 // deprecated
-int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
-    return llama_kv_self_n_tokens(ctx);
-}
-
 int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
         return 0;
     }
 
-    return kv->get_n_tokens();
-}
+    int32_t res = 0;
 
-// deprecated
-int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
-    return llama_kv_self_used_cells(ctx);
+    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+        const llama_pos p0 = kv->seq_pos_min(s);
+        const llama_pos p1 = kv->seq_pos_max(s);
+
+        if (p0 >= 0) {
+            res += (p1 - p0) + 1;
+        }
+    }
+
+    return res;
 }
 
+// deprecated
+// note: this is the same as above - will be removed anyway, so it's ok
 int32_t llama_kv_self_used_cells(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
         return 0;
     }
 
-    return kv->get_used_cells();
-}
+    int32_t res = 0;
 
-// deprecated
-void llama_kv_cache_clear(llama_context * ctx) {
-    llama_kv_self_clear(ctx);
+    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+        const llama_pos p0 = kv->seq_pos_min(s);
+        const llama_pos p1 = kv->seq_pos_max(s);
+
+        if (p0 >= 0) {
+            res += (p1 - p0) + 1;
+        }
+    }
+
+    return res;
 }
 
 void llama_kv_self_clear(llama_context * ctx) {
@@ -2357,15 +2366,6 @@ void llama_kv_self_clear(llama_context * ctx) {
     kv->clear();
 }
 
-// deprecated
-bool llama_kv_cache_seq_rm(
-        llama_context * ctx,
-         llama_seq_id   seq_id,
-            llama_pos   p0,
-            llama_pos   p1) {
-    return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
-}
-
 bool llama_kv_self_seq_rm(
         llama_context * ctx,
          llama_seq_id   seq_id,
@@ -2379,16 +2379,6 @@ bool llama_kv_self_seq_rm(
     return kv->seq_rm(seq_id, p0, p1);
 }
 
-// deprecated
-void llama_kv_cache_seq_cp(
-        llama_context * ctx,
-         llama_seq_id   seq_id_src,
-         llama_seq_id   seq_id_dst,
-            llama_pos   p0,
-            llama_pos   p1) {
-    llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
-}
-
 void llama_kv_self_seq_cp(
         llama_context * ctx,
          llama_seq_id   seq_id_src,
@@ -2403,13 +2393,6 @@ void llama_kv_self_seq_cp(
     kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
 }
 
-// deprecated
-void llama_kv_cache_seq_keep(
-        llama_context * ctx,
-         llama_seq_id   seq_id) {
-    llama_kv_self_seq_keep(ctx, seq_id);
-}
-
 void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
     auto * kv = ctx->get_kv_self();
     if (!kv) {
@@ -2419,16 +2402,6 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
     kv->seq_keep(seq_id);
 }
 
-// deprecated
-void llama_kv_cache_seq_add(
-        llama_context * ctx,
-         llama_seq_id   seq_id,
-            llama_pos   p0,
-            llama_pos   p1,
-            llama_pos   delta) {
-    llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
-}
-
 void llama_kv_self_seq_add(
         llama_context * ctx,
          llama_seq_id   seq_id,
@@ -2443,16 +2416,6 @@ void llama_kv_self_seq_add(
     kv->seq_add(seq_id, p0, p1, delta);
 }
 
-// deprecated
-void llama_kv_cache_seq_div(
-        llama_context * ctx,
-         llama_seq_id   seq_id,
-            llama_pos   p0,
-            llama_pos   p1,
-                  int   d) {
-    llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
-}
-
 void llama_kv_self_seq_div(
         llama_context * ctx,
          llama_seq_id   seq_id,
@@ -2467,25 +2430,24 @@ void llama_kv_self_seq_div(
     kv->seq_div(seq_id, p0, p1, d);
 }
 
-// deprecated
-llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
-    return llama_kv_self_seq_pos_max(ctx, seq_id);
+llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return -1;
+    }
+
+    return kv->seq_pos_min(seq_id);
 }
 
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
-        return 0;
+        return -1;
     }
 
     return kv->seq_pos_max(seq_id);
 }
 
-// deprecated
-void llama_kv_cache_defrag(llama_context * ctx) {
-    llama_kv_self_defrag(ctx);
-}
-
 void llama_kv_self_defrag(llama_context * ctx) {
     auto * kv = ctx->get_kv_self();
     if (!kv) {
@@ -2496,11 +2458,6 @@ void llama_kv_self_defrag(llama_context * ctx) {
     kv->defrag_sched(-1.0f);
 }
 
-// deprecated
-bool llama_kv_cache_can_shift(const llama_context * ctx) {
-    return llama_kv_self_can_shift(ctx);
-}
-
 bool llama_kv_self_can_shift(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
@@ -2510,11 +2467,6 @@ bool llama_kv_self_can_shift(const llama_context * ctx) {
     return kv->get_can_shift();
 }
 
-// deprecated
-void llama_kv_cache_update(llama_context * ctx) {
-    llama_kv_self_update(ctx);
-}
-
 // llama state API
 
 // deprecated
@@ -2637,7 +2589,21 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
           llama_batch   batch) {
-    const int ret = ctx->decode(batch);
+    int ret = ctx->decode(batch);
+
+    // defrag and try again
+    // TODO: distinguish return code when we are sure that even after defrag there is no space available
+    if (ret == 1) {
+        llama_kv_self_defrag(ctx);
+        ret = ctx->decode(batch);
+
+        if (ret == 1) {
+            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+            return ret;
+        }
+    }
+
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
index 28369be365252724d929420f10f243bbfcd27937..f7b36590fe3e3f0a9dc9dfee9e7f749d2610aee7 100644 (file)
@@ -1 +1,5 @@
 #include "llama-cparams.h"
+
+size_t llama_max_parallel_sequences(void) {
+    return LLAMA_MAX_PARALLEL_SEQUENCES;
+}
index 246fa5777deea1f6d4b94581d9b07b258b9434a2..2871031ef09619bbce252126c1ddd13fb4681dcd 100644 (file)
@@ -4,6 +4,8 @@
 
 #include <cstdint>
 
+#define LLAMA_MAX_PARALLEL_SEQUENCES 64
+
 struct llama_cparams {
     uint32_t n_ctx;           // context size used during inference
     uint32_t n_batch;
index 973b47ae063b08a6faec73294a6ab05d78e4ac56..bed706bb248d139664d8024726948e9fb1ba4cb5 100644 (file)
@@ -1177,8 +1177,18 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
             for (const auto & trigger_pattern : grammar.trigger_patterns) {
                 if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
                     grammar.awaiting_trigger = false;
-                    // get from the first match to the end of the string
-                    auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
+                    // get from the first matched capturing group to the end of the string
+                    size_t start = std::string::npos;
+                    for (auto i = 1u; i < match.size(); i++) {
+                        if (match.length(i) > 0) {
+                            start = match.position(i);
+                            break;
+                        }
+                    }
+                    if (start == std::string::npos) {
+                        start = match.position(0);
+                    }
+                    auto constrained_str = grammar.trigger_buffer.substr(start);
                     // std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
                     grammar.trigger_buffer.clear();
                     llama_grammar_accept_str(grammar, constrained_str);
index b0e3f63597a76d0481b77e949f823cda749fd561..cdd5887de961c6f622e6c507efcf2775ad75709a 100644 (file)
@@ -9,33 +9,6 @@
 #include <cmath>
 #include <cstring>
 
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-    // TODO move to hparams if a T5 variant appears that uses a different value
-    const int64_t max_distance = 128;
-
-    if (bidirectional) {
-        n_buckets >>= 1;
-    }
-
-    const int64_t max_exact = n_buckets >> 1;
-
-    int32_t relative_position = x - y;
-    int32_t relative_bucket = 0;
-
-    if (bidirectional) {
-        relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
-    } else {
-        relative_position = -std::min<int32_t>(relative_position, 0);
-    }
-
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-
-    return relative_bucket;
-}
-
 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -110,22 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
-        const int64_t n_tokens = ubatch->n_tokens;
-
-        GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
-        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
-
-        int32_t * data = (int32_t *) pos_bucket->data;
-
-        const int64_t n_kv = kv_self->n;
-
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_kv; ++i) {
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
-                }
-            }
-        }
+        kv_self->set_input_pos_bucket(pos_bucket, ubatch);
     }
 }
 
@@ -403,99 +361,18 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask || self_kq_mask_swa) {
-        const int64_t n_kv         = kv_self->n;
-        const int64_t n_tokens     = ubatch->n_tokens;
-        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-        const int64_t n_seqs       = ubatch->n_seqs;
-
-        float * data     = nullptr;
-        float * data_swa = nullptr;
-
-        if (self_kq_mask) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-            data = (float *) self_kq_mask->data;
-        }
-
-        if (self_kq_mask_swa) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
-            data_swa = (float *) self_kq_mask_swa->data;
-        }
-
-        // Use only the previous KV cells of the correct sequence for each token of the ubatch.
-        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-        // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
-        //   Causal mask:
-        //      xxx-------
-        //      xxxx------
-        //      xxxxx-----
-        //   Non-causal mask:
-        //      xxxxx-----
-        //      xxxxx-----
-        //      xxxxx-----
-        // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-        for (int h = 0; h < 1; ++h) {
-            for (int s = 0; s < n_seqs; ++s) {
-                const llama_seq_id seq_id = ubatch->seq_id[s][0];
-
-                for (int j = 0; j < n_seq_tokens; ++j) {
-                    const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
-                    for (int i = 0; i < n_kv; ++i) {
-                        float f;
-                        // mask the token if:
-                        if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
-                            || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
-                        ) {
-                            f = -INFINITY;
-                        } else {
-                            if (hparams.use_alibi) {
-                                f = -std::abs(kv_self->cells[i].pos - pos);
-                            } else {
-                                f = 0.0f;
-                            }
-                        }
-
-                        if (data) {
-                            data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                        }
-
-                        // may need to cut off old tokens for sliding window
-                        // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
-                        if (data_swa) {
-                            if (hparams.n_attn_chunk) {
-                                llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
-                                if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
-                                    f = -INFINITY;
-                                }
-                            } else {
-                                if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
-                                    f = -INFINITY;
-                                }
-                            }
-                            data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                        }
-                    }
-                }
-            }
+    if (self_kq_mask) {
+        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
+}
 
-            // mask padded tokens
-            if (data) {
-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                    }
-                }
-            }
+void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
 
-            // mask padded tokens
-            if (data_swa) {
-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                    }
-                }
-            }
-        }
+    if (self_kq_mask_swa) {
+        kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }
 
@@ -545,7 +422,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_layer          (hparams.n_layer),
     n_rot            (hparams.n_rot),
     n_ctx            (cparams.n_ctx),
-    n_ctx_per_seq    (cparams.n_ctx / cparams.n_seq_max),
     n_head           (hparams.n_head()),
     n_head_kv        (hparams.n_head_kv()),
     n_embd_head_k    (hparams.n_embd_head_k),
@@ -1153,7 +1029,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
 
-    const auto n_kv = kv_self->n;
+    const auto n_kv = kv_self->get_n();
 
     auto & cur = inp->pos_bucket;
 
@@ -1188,16 +1064,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
          ggml_tensor * kq_b,
          ggml_tensor * kq_mask,
          ggml_tensor * v_mla,
-             bool      v_trans,
              float     kq_scale) const {
-  //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-  //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-  //const int64_t n_head    = hparams.n_head(il);
-  //const int64_t n_head_kv = hparams.n_head_kv(il);
+    const bool v_trans = v->nb[1] > v->nb[2];
 
-  //const auto & n_embd_head_k = hparams.n_embd_head_k;
-  //const auto & n_embd_head_v = hparams.n_embd_head_v;
+    q = ggml_permute(ctx0, q, 0, 2, 1, 3);
+    k = ggml_permute(ctx0, k, 0, 2, 1, 3);
+    v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
     const auto n_tokens = q->ne[1];
     const auto n_head   = q->ne[2];
@@ -1336,17 +1208,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask();
 
-    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = k_cur;
+    ggml_tensor * v = v_cur;
 
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1369,22 +1235,16 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
-    const auto n_kv = kv_self->n;
-
-    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
-    ggml_set_input(inp->self_kq_mask);
-
-    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-    if (hparams.n_swa_pattern > 1) {
-        GGML_ASSERT(hparams.n_swa > 0);
+        const auto n_kv = kv_self->get_n();
 
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
-        ggml_set_input(inp->self_kq_mask_swa);
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
 
-        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
     return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
@@ -1409,81 +1269,104 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, v_cur);
 
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
-    const auto & n_ctx = cparams.n_ctx;
 
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
 
-    const auto n_tokens = q_cur->ne[2];
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_self->get_k(ctx0, il);
+    ggml_tensor * v = kv_self->get_v(ctx0, il);
 
-    const bool v_trans = !cparams.flash_attn;
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
 
-    // store to KV cache
-    {
-        const auto kv_head = kv_self->head;
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
 
-        GGML_ASSERT(kv_self->size == n_ctx);
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
 
-        ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
-        //cb(k_cache_view, "k_cache_view", il);
+    return cur;
+}
 
-        // note: storing RoPE-ed version of K in the KV cache
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
 
-        v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
 
-        ggml_tensor * v_cache_view = nullptr;
+    {
+        const auto n_kv = kv_self->get_kv_base()->get_n();
 
-        if (!v_trans) {
-            v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
-        } else {
-            // note: the V cache is transposed when not using flash attention
-            v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
-                    (  n_ctx)*ggml_element_size(kv_self->v_l[il]),
-                    (kv_head)*ggml_element_size(kv_self->v_l[il]));
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
 
-            v_cur = ggml_transpose(ctx0, v_cur);
-        }
-        //cb(v_cache_view, "v_cache_view", il);
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
 
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
+        const auto n_kv = kv_self->get_kv_swa()->get_n();
+
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
     }
 
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_unified_iswa * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+            float     kq_scale,
+            int       il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
     const bool is_swa = hparams.is_swa(il);
 
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+
+    const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
+    }
+
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
-    const auto n_kv = kv_self->n;
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv->get_k(ctx0, il);
+    ggml_tensor * v = kv->get_v(ctx0, il);
 
-    const int64_t n_head_kv = hparams.n_head_kv(il);
-
-    const auto & n_embd_head_k = hparams.n_embd_head_k;
-    const auto & n_embd_head_v = hparams.n_embd_head_v;
-
-    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    ggml_tensor * k =
-        ggml_view_3d(ctx0, kv_self->k_l[il],
-                n_embd_head_k, n_kv, n_head_kv,
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
-                0);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = !v_trans ?
-        ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_embd_head_v, n_kv, n_head_kv,
-                ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
-                0) :
-        ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_kv, n_embd_head_v, n_head_kv,
-                ggml_element_size(kv_self->v_l[il])*n_ctx,
-                ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
-                0);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1534,17 +1417,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask_cross();
 
-    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = k_cur;
+    ggml_tensor * v = v_cur;
 
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1712,3 +1589,30 @@ void llm_graph_context::build_pooling(
 
     ggml_build_forward_expand(gf, cur);
 }
+
+int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+    // TODO move to hparams if a T5 variant appears that uses a different value
+    const int64_t max_distance = 128;
+
+    if (bidirectional) {
+        n_buckets >>= 1;
+    }
+
+    const int64_t max_exact = n_buckets >> 1;
+
+    int32_t relative_position = x - y;
+    int32_t relative_bucket = 0;
+
+    if (bidirectional) {
+        relative_bucket += (relative_position > 0) * n_buckets;
+        relative_position = abs(relative_position);
+    } else {
+        relative_position = -std::min<int32_t>(relative_position, 0);
+    }
+
+    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+
+    return relative_bucket;
+}
index 832a8c09f2b80eb816326458d100caa2ca262764..2b85bb25befbac4e9c3294c06d09e641c856a7ef 100644 (file)
@@ -19,6 +19,7 @@ struct llama_cparams;
 
 class llama_memory_i;
 class llama_kv_cache_unified;
+class llama_kv_cache_unified_iswa;
 class llama_kv_cache_recurrent;
 
 // certain models (typically multi-modal) can produce different types of graphs
@@ -255,6 +256,31 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const llama_kv_cache_unified * kv_self;
+};
+
+class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_kv_unified_iswa(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_kv_cache_unified_iswa * kv_self) :
+        hparams(hparams),
+        cparams(cparams),
+        kv_self(kv_self) {
+    }
+    ~llm_graph_input_attn_kv_unified_iswa() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
     ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
@@ -266,7 +292,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache_unified_iswa * kv_self;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -378,7 +404,6 @@ struct llm_graph_context {
     const int64_t n_layer;
     const int64_t n_rot;
     const int64_t n_ctx;       // user-specified context size (can be different from n_ctx_train)
-    const int64_t n_ctx_per_seq;
     const int64_t n_head;
     const int64_t n_head_kv;
     const int64_t n_embd_head_k;
@@ -507,13 +532,12 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn_mha(
              ggml_cgraph * gf,
-             ggml_tensor * q,     // [n_embd_head_q, n_tokens, n_head_q]
-             ggml_tensor * k,     // [n_embd_head_k, n_tokens, n_head_k]
-             ggml_tensor * v,     // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
+             ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
+             ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
+             ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
              ggml_tensor * kq_b,
              ggml_tensor * kq_mask,
-             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-                    bool   v_trans,
+             ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                    float   kq_scale) const;
 
     llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
@@ -546,6 +570,21 @@ struct llm_graph_context {
                   float   kq_scale,
                     int   il) const;
 
+    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
+
+    ggml_tensor * build_attn(
+            llm_graph_input_attn_kv_unified_iswa * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+                  float   kq_scale,
+                    int   il) const;
+
     llm_graph_input_attn_cross * build_attn_inp_cross() const;
 
     ggml_tensor * build_attn(
@@ -596,3 +635,6 @@ struct llm_graph_context {
             ggml_tensor * cls_out,
             ggml_tensor * cls_out_b) const;
 };
+
+// TODO: better name
+int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);
index 90dfe7a7fcc00f2f2504f6f11cef64b198c27fee..1499eb08a5dd9246f182dc3545d4e23aecc5ca29 100644 (file)
@@ -2,6 +2,22 @@
 
 #include "ggml.h"
 
+void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
+    }
+}
+
+bool llama_hparams::is_swa_any() const {
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (swa_layers[il]) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
 uint32_t llama_hparams::n_head(uint32_t il) const {
     if (il < n_layer) {
         return n_head_arr[il];
@@ -72,7 +88,7 @@ uint32_t llama_hparams::n_embd_v_s() const {
 
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
-        return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
+        return swa_layers[il];
     }
 
     GGML_ABORT("fatal error");
index 7ee6a5b75ad1ef66a0e3d21a514e257ae6dcecf1..2d72eab180ad0c93cb797e194d73ff38e11547fd 100644 (file)
@@ -14,6 +14,12 @@ enum llama_expert_gating_func_type {
     LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
 };
 
+enum llama_swa_type {
+    LLAMA_SWA_TYPE_NONE     = 0,
+    LLAMA_SWA_TYPE_STANDARD = 1,
+    LLAMA_SWA_TYPE_CHUNKED  = 2,
+};
+
 struct llama_hparams_posnet {
     uint32_t n_embd;
     uint32_t n_layer;
@@ -35,8 +41,6 @@ struct llama_hparams {
     uint32_t n_embd_features = 0;
     uint32_t n_layer;
     uint32_t n_rot;
-    uint32_t n_swa = 0; // sliding window attention (SWA)
-    uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_expert = 0;
@@ -96,6 +100,15 @@ struct llama_hparams {
 
     std::array<int, 4> rope_sections;
 
+    // Sliding Window Attention (SWA)
+    llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
+    // the size of the sliding window (0 - no SWA)
+    uint32_t n_swa = 0;
+    // if swa_layers[il] == true, then layer il is SWA
+    // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
+    // by default, all layers are dense
+    std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
+
     // for State Space Models
     uint32_t ssm_d_conv  = 0;
     uint32_t ssm_d_inner = 0;
@@ -116,11 +129,10 @@ struct llama_hparams {
     bool causal_attn   = true;
     bool use_alibi     = false;
     bool attn_soft_cap = false;
+    bool use_kq_norm   = true;
 
+    // llama4
     uint32_t n_moe_layer_step        = 0;
-    bool     use_kq_norm             = true;
-    uint32_t n_attn_chunk            = 0;
-    // values below seems to be fixed on llama4
     uint32_t n_no_rope_layer_step    = 4;
     uint32_t n_attn_temp_floor_scale = 8192;
     float    f_attn_temp_scale       = 0.1;
@@ -133,6 +145,23 @@ struct llama_hparams {
     enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
     enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
 
+    // this value n_pattern means that every nth layer is dense (i.e. non-SWA)
+    // note that if n_pattern == 0, all layers are SWA
+    //           if n_pattern == 1, all layers are dense
+    // example: n_pattern = 3
+    //   il == 0: swa
+    //   il == 1: swa
+    //   il == 2: dense
+    //   il == 3: swa
+    //   il == 4: swa
+    //   il == 5: dense
+    //   il == 6: swa
+    //   etc ...
+    void set_swa_pattern(uint32_t n_pattern);
+
+    // return true if one of the layers is SWA
+    bool is_swa_any() const;
+
     uint32_t n_head(uint32_t il = 0) const;
 
     uint32_t n_head_kv(uint32_t il = 0) const;
index 265db2527c7ca5d012cd5c305bd8ecefdbba983a..4a42d6ecdc4556f7528810afce958d2854d50f6f 100644 (file)
@@ -23,32 +23,21 @@ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
 }
 
 llama_kv_cache_unified::llama_kv_cache_unified(
-        const llama_model & model,
-                ggml_type   type_k,
-                ggml_type   type_v,
-                     bool   v_trans,
-                     bool   offload,
-                 uint32_t   kv_size,
-                 uint32_t   padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
-    const int32_t n_layer = hparams.n_layer;
-
-    has_shift = false;
-    can_shift = true;
-
-    LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d, padding = %d\n",
-            __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift, padding);
-
-    GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");
-
-    head = 0;
-    size = kv_size;
-    used = 0;
-
-    this->type_k = type_k;
-    this->type_v = type_v;
-
-    cells.clear();
-    cells.resize(kv_size);
+        const llama_model &  model,
+          layer_filter_cb && filter,
+                ggml_type    type_k,
+                ggml_type    type_v,
+                     bool    v_trans,
+                     bool    offload,
+                 uint32_t    kv_size,
+                 uint32_t    n_seq_max,
+                 uint32_t    n_pad,
+                 uint32_t    n_swa,
+           llama_swa_type    swa_type) :
+    model(model), hparams(model.hparams), v_trans(v_trans),
+    n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
+
+    GGML_ASSERT(kv_size % n_pad == 0);
 
     // create a context for each buffer type
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
@@ -56,7 +45,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
+                /*.mem_size   =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
             };
@@ -75,37 +64,48 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         return it->second;
     };
 
-    k_l.reserve(n_layer);
-    v_l.reserve(n_layer);
+    head = 0;
 
-    for (int i = 0; i < n_layer; i++) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+    cells.resize(kv_size);
+
+    for (uint32_t il = 0; il < hparams.n_layer; il++) {
+        if (filter && !filter(il)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+            continue;
+        }
+
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
         const char * dev_name = "CPU";
 
         ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
 
         if (offload) {
-            auto * dev = model.dev_layer(i);
+            auto * dev = model.dev_layer(il);
             buft = ggml_backend_dev_buffer_type(dev);
 
             dev_name = ggml_backend_dev_name(dev);
         }
 
-        LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, i, dev_name);
+        LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);
 
         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
             throw std::runtime_error("failed to create ggml context for kv cache");
         }
 
-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
-        ggml_format_name(k, "cache_k_l%d", i);
-        ggml_format_name(v, "cache_v_l%d", i);
-        k_l.push_back(k);
-        v_l.push_back(v);
+        ggml_tensor * k;
+        ggml_tensor * v;
+
+        k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
+        v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
+
+        ggml_format_name(k, "cache_k_l%d", il);
+        ggml_format_name(v, "cache_v_l%d", il);
+
+        map_layer_ids[il] = layers.size();
+        layers.push_back({ il, k, v });
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
@@ -117,8 +117,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         if (!buf) {
             throw std::runtime_error("failed to allocate buffer for kv cache");
         }
-        ggml_backend_buffer_clear(buf, 0);
+
         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
+
+        ggml_backend_buffer_clear(buf, 0);
         bufs.emplace_back(buf);
     }
 
@@ -126,20 +128,17 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         const size_t memory_size_k = size_k_bytes();
         const size_t memory_size_v = size_v_bytes();
 
-        LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
                 ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
                 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
     }
 }
 
 void llama_kv_cache_unified::clear() {
-    for (int32_t i = 0; i < (int32_t) size; ++i) {
-        cells[i].pos = -1;
-        cells[i].seq_id.clear();
-    }
+    cells.reset();
+
     head = 0;
-    used = 0;
 
     for (auto & buf : bufs) {
         ggml_backend_buffer_clear(buf.get(), 0);
@@ -147,7 +146,7 @@ void llama_kv_cache_unified::clear() {
 }
 
 bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    uint32_t new_head = size;
+    uint32_t new_head = cells.size();
 
     if (p0 < 0) {
         p0 = 0;
@@ -157,32 +156,20 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].pos >= p0 && cells[i].pos < p1) {
-            if (seq_id < 0) {
-                cells[i].seq_id.clear();
-            } else if (cells[i].has_seq_id(seq_id)) {
-                cells[i].seq_id.erase(seq_id);
-            } else {
-                continue;
-            }
-            if (cells[i].is_empty()) {
-                // keep count of the number of used cells
-                if (cells[i].pos >= 0) {
-                    used--;
-                }
-
-                cells[i].pos = -1;
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.pos_in(i, p0, p1)) {
+            continue;
+        }
 
-                if (new_head == size) {
-                    new_head = i;
-                }
+        if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
+            if (new_head == cells.size()) {
+                new_head = i;
             }
         }
     }
 
     // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != size && new_head < head) {
+    if (new_head != cells.size() && new_head < head) {
         head = new_head;
     }
 
@@ -202,49 +189,40 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    // otherwise, this is the KV of a Transformer-like model
-    head = 0;
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.pos_in(i, p0, p1)) {
+            continue;
+        }
 
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) {
-            cells[i].seq_id.insert(seq_id_dst);
+        if (cells.seq_has(i, seq_id_src)) {
+            cells.seq_add(i, seq_id_dst);
         }
     }
 }
 
 void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
-    uint32_t new_head = size;
+    uint32_t new_head = cells.size();
 
-    for (uint32_t i = 0; i < size; ++i) {
-        if (!cells[i].has_seq_id(seq_id)) {
-            if (cells[i].pos >= 0) {
-                used--;
-            }
-
-            cells[i].pos = -1;
-            cells[i].seq_id.clear();
-
-            if (new_head == size){
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (cells.seq_keep(i, seq_id)) {
+            if (new_head == cells.size()) {
                 new_head = i;
             }
-        } else {
-            cells[i].seq_id.clear();
-            cells[i].seq_id.insert(seq_id);
         }
     }
 
     // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != size && new_head < head) {
+    if (new_head != cells.size() && new_head < head) {
         head = new_head;
     }
 }
 
-void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
-    if (delta == 0) {
+void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    if (shift == 0) {
         return;
     }
 
-    uint32_t new_head = size;
+    uint32_t new_head = cells.size();
 
     if (p0 < 0) {
         p0 = 0;
@@ -254,24 +232,19 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    // If there is no range then return early to avoid looping over the
+    // If there is no range then return early to avoid looping over all cells.
     if (p0 == p1) {
         return;
     }
 
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
-            has_shift = true;
-            cells[i].pos   += delta;
-            cells[i].delta += delta;
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.pos_in(i, p0, p1)) {
+            continue;
+        }
 
-            if (cells[i].pos < 0) {
-                if (!cells[i].is_empty()) {
-                    used--;
-                }
-                cells[i].pos = -1;
-                cells[i].seq_id.clear();
-                if (new_head == size) {
+        if (cells.seq_has(i, seq_id)) {
+            if (cells.pos_add(i, shift)) {
+                if (new_head == cells.size()) {
                     new_head = i;
                 }
             }
@@ -280,7 +253,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
 
     // If we freed up a slot, set head to it so searching can start there.
     // Otherwise we just start the next search from the beginning.
-    head = new_head != size ? new_head : 0;
+    head = new_head != cells.size() ? new_head : 0;
 }
 
 void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
@@ -301,66 +274,41 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
         return;
     }
 
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
-            has_shift = true;
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.pos_in(i, p0, p1)) {
+            continue;
+        }
 
-            {
-                llama_pos p_old = cells[i].pos;
-                cells[i].pos   /= d;
-                cells[i].delta += cells[i].pos - p_old;
-            }
+        if (cells.seq_has(i, seq_id)) {
+            cells.pos_div(i, d);
         }
     }
 }
 
-llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
-    llama_pos result = 0;
-
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].has_seq_id(seq_id)) {
-            result = std::max(result, cells[i].pos);
-        }
-    }
+llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
+    return cells.seq_pos_min(seq_id);
+}
 
-    return result;
+llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
+    return cells.seq_pos_max(seq_id);
 }
 
 void llama_kv_cache_unified::restore() {
-    if (pending.ranges.empty()) {
-        return;
+    for (auto & state : recovery.states) {
+        cells.set(state.i, state.cells);
     }
 
-    uint32_t new_head = size;
-
-    for (auto & range : pending.ranges) {
-        for (uint32_t i = range.c0; i < range.c1; ++i) {
-            cells[i].seq_id.clear();
-
-            // keep count of the number of used cells
-            if (cells[i].pos >= 0) {
-                used--;
-            }
-
-            cells[i].pos = -1;
-        }
-
-        new_head = std::min(new_head, range.c0);
-    }
-
-    if (new_head != size && new_head < head) {
-        head = new_head;
-    }
+    recovery.clear();
 }
 
 void llama_kv_cache_unified::commit() {
-    if (pending.ranges.empty()) {
-        LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
-                __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
+    if (recovery.states.empty()) {
+        LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
+                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
         return;
     }
 
-    pending.ranges.clear();
+    recovery.clear();
 }
 
 bool llama_kv_cache_unified::update(llama_context & lctx) {
@@ -368,7 +316,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 
     auto * sched = lctx.get_sched();
 
-    if (has_shift) {
+    if (cells.get_has_shift()) {
         if (!get_can_shift()) {
             GGML_ABORT("The current KV cache / model configuration does not support K-shift");
         }
@@ -392,13 +340,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
             need_reserve = true;
         }
 
-        {
-            has_shift = false;
-
-            for (uint32_t i = 0; i < size; ++i) {
-                cells[i].delta = 0;
-            }
-        }
+        cells.reset_shift();
     }
 
     if (do_defrag) {
@@ -429,7 +371,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 void llama_kv_cache_unified::defrag_sched(float thold) {
     // - do not defrag small contexts (i.e. < 2048 tokens)
     // - count the padding towards the number of used tokens
-    const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f;
+    const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f;
 
     // queue defragmentation for next llama_kv_cache_update
     if (fragmentation > thold) {
@@ -440,7 +382,7 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
 }
 
 void llama_kv_cache_unified::set_full() {
-    n = size;
+    n = cells.size();
 
     // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
     //   affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
@@ -450,51 +392,67 @@ void llama_kv_cache_unified::set_full() {
     head = 0;
 }
 
-llama_sbatch llama_kv_cache_unified::sbatch_init(
-        const llama_batch & batch,
-        bool logits_all) {
+llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
     return llama_sbatch(batch, hparams.n_embd, true, logits_all);
 }
 
-llama_ubatch llama_kv_cache_unified::ubatch_next(
-        llama_sbatch & sbatch,
-        uint32_t n_ubatch,
-        bool embd_pooled) const {
+llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
     GGML_UNUSED(embd_pooled);
     return sbatch.split_simple(n_ubatch);
 }
 
-bool llama_kv_cache_unified::find_slot(
-       const llama_ubatch & ubatch) {
+bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
     const uint32_t n_tokens = ubatch.n_tokens;
-    const uint32_t n_seqs   = ubatch.n_seqs;
-    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
     // if we have enough unused cells before the current head ->
     //   better to start searching from the beginning of the cache, hoping to fill it
-    if (head > used + 2*ubatch.n_tokens) {
+    if (head > cells.get_used() + 2*ubatch.n_tokens) {
         head = 0;
     }
 
     // otherwise, one cell per token.
 
-    if (n_tokens > size) {
-        LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
+    if (n_tokens > cells.size()) {
+        LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
         return false;
     }
 
+//#define FIND_SLOT_DEBUG 1
+#if FIND_SLOT_DEBUG
+    LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
+
+    // for debugging
+    {
+        std::string ss;
+        if (n_swa > 0) {
+            for (uint32_t i = 0; i < size; ++i) {
+                if (cells.is_empty(i)) {
+                    ss += '.';
+                } else {
+                    ss += 'x';
+                }
+                if (i%256 == 255) {
+                    ss += '\n';
+                }
+            }
+        }
+        LLAMA_LOG_WARN("\n%s\n", ss.c_str());
+    }
+#endif
+
     uint32_t n_tested = 0;
 
     while (true) {
-        if (head + n_tokens > size) {
-            n_tested += size - head;
+        if (head + n_tokens > cells.size()) {
+            n_tested += cells.size() - head;
             head = 0;
             continue;
         }
 
         bool found = true;
         for (uint32_t i = 0; i < n_tokens; i++) {
-            if (cells[head + i].pos >= 0) {
+            // TODO: improve to accept cells that are masked by the SWA
+            if (!cells.is_empty(head + i)) {
                 found = false;
                 head     += i + 1;
                 n_tested += i + 1;
@@ -506,66 +464,257 @@ bool llama_kv_cache_unified::find_slot(
             break;
         }
 
-        if (n_tested >= size) {
+        if (n_tested >= cells.size()) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
             return false;
         }
     }
 
-    for (uint32_t s = 0; s < n_seqs; s++) {
-        for (uint32_t i = 0; i < n_seq_tokens; ++i) {
-            uint32_t k = s*n_seq_tokens + i;
-            cells[head + k].pos = ubatch.pos[k];
+    // store the old state of the cells in the recovery stack
+    recovery.states.push_back({head, cells.cp(head, n_tokens)});
 
-            for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
-                cells[head + k].seq_id.insert(ubatch.seq_id[s][j]);
-            }
+    for (uint32_t i = 0; i < n_tokens; ++i) {
+        cells.pos_set(head + i, ubatch.pos[i]);
+
+        for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
+            cells.seq_add(head + i, ubatch.seq_id[i][j]);
         }
     }
 
-    used += n_tokens;
-
-    pending.ranges.push_back({head, head + n_tokens});
-
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
     // if we start defragmenting the cache, the benefit from this will be more important
-    n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding)));
+    n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
+
+#ifdef FIND_SLOT_DEBUG
+    LLAMA_LOG_WARN("end:   n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
+#endif
 
-    //printf("n = %5d, used = %5d, head = %5d\n", n, used, head);
+    return true;
+}
 
+bool llama_kv_cache_unified::get_can_shift() const {
     return true;
 }
 
-int32_t llama_kv_cache_unified::get_n_tokens() const {
-    int32_t result = 0;
+uint32_t llama_kv_cache_unified::get_n() const {
+    return n;
+}
+
+uint32_t llama_kv_cache_unified::get_size() const {
+    return cells.size();
+}
+
+ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
+    const int32_t ikv = map_layer_ids.at(il);
+
+    auto * k = layers[ikv].k;
+
+    return ggml_view_3d(ctx, k,
+            hparams.n_embd_head_k, hparams.n_head_kv(il), n,
+            ggml_row_size(k->type, hparams.n_embd_head_k),
+            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
+            0);
+}
+
+ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) const {
+    const int32_t ikv = map_layer_ids.at(il);
+
+    auto * v = layers[ikv].v;
 
-    for (uint32_t i = 0; i < size; i++) {
-        result += cells[i].seq_id.size();
+    if (!v_trans) {
+        // note: v->nb[1] <= v->nb[2]
+        return ggml_view_3d(ctx, v,
+                hparams.n_embd_head_v, hparams.n_head_kv(il), n,
+                ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
+                0);
     }
 
-    return result;
+    // note: v->nb[1] > v->nb[2]
+    return ggml_view_3d(ctx, v,
+            n, hparams.n_head_kv(il), hparams.n_embd_head_v,
+            ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
+            ggml_row_size(v->type, v->ne[1]),                       // v->nb[2]
+            0);
 }
 
-int32_t llama_kv_cache_unified::get_used_cells() const {
-    return used;
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
+    const int32_t ikv = map_layer_ids.at(il);
+
+    auto * k = layers[ikv].k;
+
+    const int64_t n_tokens = k_cur->ne[2];
+
+    ggml_tensor * k_view = ggml_view_1d(ctx, k,
+            n_tokens*hparams.n_embd_k_gqa(il),
+            ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head);
+
+    return ggml_cpy(ctx, k_cur, k_view);
 }
 
-bool llama_kv_cache_unified::get_can_shift() const {
-    return can_shift;
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
+    const int32_t ikv = map_layer_ids.at(il);
+
+    auto * v = layers[ikv].v;
+
+    const int64_t n_tokens = v_cur->ne[2];
+
+    v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
+
+    ggml_tensor * v_view = nullptr;
+
+    if (!v_trans) {
+        v_view = ggml_view_1d(ctx, v,
+                n_tokens*hparams.n_embd_v_gqa(il),
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head);
+    } else {
+        // note: the V cache is transposed when not using flash attention
+        v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
+                (v->ne[1])*ggml_element_size(v),
+                (    head)*ggml_element_size(v));
+
+        v_cur = ggml_transpose(ctx, v_cur);
+    }
+
+    return ggml_cpy(ctx, v_cur, v_view);
+}
+
+void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax) {
+    // no pruning is needed when the cache does not use SWA
+    GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache");
+
+    int n_attended = 0;
+
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.seq_has(i, seq_id)) {
+            continue;
+        }
+
+        const llama_pos p0 = cells.pos_get(i);
+
+        if (p0 <= pmin && !is_masked_swa(p0, pmin)) {
+            n_attended++;
+        }
+
+        if (is_masked_swa(p0, pmax)) {
+            cells.seq_rm(i, seq_id);
+        }
+    }
+
+    if (n_attended < std::min<int>(n_swa, pmin)) {
+        LLAMA_LOG_WARN("%s: partial SWA cache detected - possible loss of information, pmin = %d, n_attended = %d, n_swa = %d\n", __func__, pmin, n_attended, n_swa);
+    }
+}
+
+void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
+    const int64_t n_tokens     = ubatch->n_tokens;
+    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+    const int64_t n_seqs       = ubatch->n_seqs;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    float * data = (float *) dst->data;
+
+    const int64_t n_kv = n;
+
+    // Use only the previous KV cells of the correct sequence for each token of the ubatch.
+    // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+    // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
+    //   Causal mask:
+    //      xxx-------
+    //      xxxx------
+    //      xxxxx-----
+    //   Non-causal mask:
+    //      xxxxx-----
+    //      xxxxx-----
+    //      xxxxx-----
+    // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
+    for (int h = 0; h < 1; ++h) {
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+            for (int j = 0; j < n_seq_tokens; ++j) {
+                const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    float f = 0.0f;
+
+                    bool masked = false;
+
+                    if (cells.is_empty(i)) {
+                        masked = true;
+                    } else {
+                        const llama_pos p0 = cells.pos_get(i);
+
+                        // mask the token if not the same sequence
+                        masked = masked || (!cells.seq_has(i, seq_id));
+
+                        // mask future tokens
+                        masked = masked || (causal_attn && p0 > p1);
+
+                        // apply SWA if any
+                        masked = masked || (is_masked_swa(p0, p1));
+
+                        if (!masked && hparams.use_alibi) {
+                            f = -std::abs(p0 - p1);
+                        }
+                    }
+
+                    if (masked) {
+                        f = -INFINITY;
+                    }
+
+                    data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                }
+            }
+        }
+
+        // mask padded tokens
+        if (data) {
+            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                for (int j = 0; j < n_kv; ++j) {
+                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                }
+            }
+        }
+    }
 }
 
-llama_pos llama_kv_cache_unified::get_pos_max() const {
-    llama_pos pos_max = -1;
-    for (const auto & cell : cells) {
-        pos_max = std::max(pos_max, cell.pos);
+void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+
+    int32_t * data = (int32_t *) dst->data;
+
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
     }
+}
+
+void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    const int64_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+
+    int32_t * data = (int32_t *) dst->data;
+
+    const int64_t n_kv = n;
 
-    return pos_max;
+    for (int h = 0; h < 1; ++h) {
+        for (int j = 0; j < n_tokens; ++j) {
+            for (int i = 0; i < n_kv; ++i) {
+                // the position when the cells is empty is irrelevant - it will be masked out later in the attention
+                const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
+
+                data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
+            }
+        }
+    }
 }
 
 size_t llama_kv_cache_unified::total_size() const {
     size_t size = 0;
+
     for (const auto & buf : bufs) {
         size += ggml_backend_buffer_get_size(buf.get());
     }
@@ -576,8 +725,8 @@ size_t llama_kv_cache_unified::total_size() const {
 size_t llama_kv_cache_unified::size_k_bytes() const {
     size_t size_k_bytes = 0;
 
-    for (const auto & k : k_l) {
-        size_k_bytes += ggml_nbytes(k);
+    for (const auto & layer : layers) {
+        size_k_bytes += ggml_nbytes(layer.k);
     }
 
     return size_k_bytes;
@@ -586,8 +735,8 @@ size_t llama_kv_cache_unified::size_k_bytes() const {
 size_t llama_kv_cache_unified::size_v_bytes() const {
     size_t size_v_bytes = 0;
 
-    for (const auto & v : v_l) {
-        size_v_bytes += ggml_nbytes(v);
+    for (const auto & layer : layers) {
+        size_v_bytes += ggml_nbytes(layer.v);
     }
 
     return size_v_bytes;
@@ -651,13 +800,7 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
     if (k_shift) {
-        assert(ggml_backend_buffer_is_host(k_shift->buffer));
-
-        int32_t * data = (int32_t *) k_shift->data;
-
-        for (uint32_t i = 0; i < kv_self->size; ++i) {
-            data[i] = kv_self->cells[i].delta;
-        }
+        kv_self->set_input_k_shift(k_shift);
     }
 }
 
@@ -667,13 +810,9 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
                 ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();
 
-    const auto & n_layer = hparams.n_layer;
-
     const auto & n_embd_head_k = hparams.n_embd_head_k;
   //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
-
     //GGML_ASSERT(kv_self->size == n_ctx);
 
     auto inp = std::make_unique<llm_graph_input_k_shift>(this);
@@ -681,24 +820,22 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
     inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
     ggml_set_input(inp->k_shift);
 
-    for (uint32_t il = 0; il < n_layer; ++il) {
+    for (const auto & layer : layers) {
+        const uint32_t il = layer.il;
+
         const int64_t n_head_kv    = hparams.n_head_kv(il);
         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
-        const bool is_swa = hparams.is_swa(il);
+        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
 
-        // note: the swa rope params could become part of the cparams in the future
-        //       if we decide to make them configurable, like the non-sliding ones
-        const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
-        const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
-
-        ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+        ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
         ggml_tensor * k =
-            ggml_view_3d(ctx, k_l[il],
-                n_embd_head_k, n_head_kv, size,
-                ggml_row_size(k_l[il]->type, n_embd_head_k),
-                ggml_row_size(k_l[il]->type, n_embd_k_gqa),
+            ggml_view_3d(ctx, layer.k,
+                n_embd_head_k, n_head_kv, cells.size(),
+                ggml_row_size(layer.k->type, n_embd_head_k),
+                ggml_row_size(layer.k->type, n_embd_k_gqa),
                 0);
 
         ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
@@ -803,44 +940,46 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
             nm++;
         }
 
-        for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;
+
             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
             const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
-            ggml_tensor * view_k_src = ggml_view_2d(ctx, k_l[il],
+            ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k,
                     n_embd_k_gqa, nm,
-                    ggml_row_size(k_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*i));
+                    ggml_row_size(layer.k->type, n_embd_k_gqa),
+                    ggml_row_size(layer.k->type, n_embd_k_gqa*i));
 
-            ggml_tensor * view_k_dst = ggml_view_2d(ctx, k_l[il],
+            ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k,
                     n_embd_k_gqa, nm,
-                    ggml_row_size(k_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*id));
+                    ggml_row_size(layer.k->type, n_embd_k_gqa),
+                    ggml_row_size(layer.k->type, n_embd_k_gqa*id));
 
             ggml_tensor * view_v_src;
             ggml_tensor * view_v_dst;
 
             if (cparams.flash_attn) {
                 // NOTE: the V cache is not transposed when using flash attention
-                view_v_src = ggml_view_2d(ctx, v_l[il],
+                view_v_src = ggml_view_2d(ctx, layer.v,
                         n_embd_v_gqa, nm,
-                        ggml_row_size(v_l[il]->type, n_embd_v_gqa),
-                        ggml_row_size(v_l[il]->type, n_embd_v_gqa*i));
+                        ggml_row_size(layer.v->type, n_embd_v_gqa),
+                        ggml_row_size(layer.v->type, n_embd_v_gqa*i));
 
-                view_v_dst = ggml_view_2d(ctx, v_l[il],
+                view_v_dst = ggml_view_2d(ctx, layer.v,
                         n_embd_v_gqa, nm,
-                        ggml_row_size(v_l[il]->type, n_embd_v_gqa),
-                        ggml_row_size(v_l[il]->type, n_embd_v_gqa*id));
+                        ggml_row_size(layer.v->type, n_embd_v_gqa),
+                        ggml_row_size(layer.v->type, n_embd_v_gqa*id));
             } else {
-                view_v_src = ggml_view_2d(ctx, v_l[il],
+                view_v_src = ggml_view_2d(ctx, layer.v,
                         nm, n_embd_v_gqa,
-                        ggml_row_size(v_l[il]->type, size),
-                        ggml_row_size(v_l[il]->type, i));
+                        ggml_row_size(layer.v->type, cells.size()),
+                        ggml_row_size(layer.v->type, i));
 
-                view_v_dst = ggml_view_2d(ctx, v_l[il],
+                view_v_dst = ggml_view_2d(ctx, layer.v,
                         nm, n_embd_v_gqa,
-                        ggml_row_size(v_l[il]->type, size),
-                        ggml_row_size(v_l[il]->type, id));
+                        ggml_row_size(layer.v->type, cells.size()),
+                        ggml_row_size(layer.v->type, id));
             }
 
             ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
@@ -857,10 +996,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
 }
 
 bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
-    const uint32_t n_layer = hparams.n_layer;
+    const uint32_t n_layer = layers.size();
 
-    const uint32_t n_kv   = cell_max();
-    const uint32_t n_used = used;
+    const uint32_t n_kv   = cells.used_max_p1();
+    const uint32_t n_used = cells.get_used();
 
     assert(n_used <= n_kv);
 
@@ -888,9 +1027,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     ids.resize(n_kv, n_kv);
 
     for (uint32_t i0 = 0; i0 < n_used; ++i0) {
-        const auto & cell0 = cells[i0];
-
-        if (!cell0.is_empty()) {
+        if (!cells.is_empty(i0)) {
             ids[i0] = i0;
 
             continue;
@@ -901,7 +1038,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
         uint32_t nh = 1;
 
         // determine the size of the hole
-        while (i0 + nh < n_used && cells[i0 + nh].is_empty()) {
+        while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
             nh++;
         }
 
@@ -910,9 +1047,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
 
         // starting from the end, find nh non-empty cells
         for (; is > i0; --is) {
-            const auto & cell1 = cells[is];
-
-            if (cell1.is_empty() || ids[is] != n_kv) {
+            if (cells.is_empty(is) || ids[is] != n_kv) {
                 continue;
             }
 
@@ -939,9 +1074,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
 
         // go back and move the nf cells to the hole
         for (; i1 < n_kv; ++i1) {
-            auto & cell1 = cells[i1];
-
-            if (cell1.is_empty() || ids[i1] != n_kv) {
+            if (cells.is_empty(i1) || ids[i1] != n_kv) {
                 if (n_moves == max_moves) {
                     stop = true;
                     break;
@@ -955,10 +1088,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
             ids[i1] = i0 + nf;
 
             // move the cell meta data
-            cells[i0 + nf] = cell1;
+            cells.mv(i1, i0 + nf);
 
-            // clear the old cell and move the head there
-            cell1 = kv_cell();
             head = n_used;
 
             if (!cont) {
@@ -993,16 +1124,30 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     return true;
 }
 
-uint32_t llama_kv_cache_unified::cell_max() const {
-    for (uint32_t i = size; i > 0; --i) {
-        const kv_cell & cell = cells[i - 1];
+bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
+    assert(p0 >= 0 && p1 >= 0);
 
-        if (cell.pos >= 0 && !cell.is_empty()) {
-            return i;
-        }
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:
+            {
+            } break;
+        case LLAMA_SWA_TYPE_STANDARD:
+            {
+                if (p1 - p0 >= (int32_t) n_swa) {
+                    return true;
+                }
+            } break;
+        case LLAMA_SWA_TYPE_CHUNKED:
+            {
+                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
+
+                if (p0 < pos_chunk_start) {
+                    return true;
+                }
+            } break;
     }
 
-    return 0;
+    return false;
 }
 
 void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
@@ -1011,23 +1156,24 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
 
     // Count the number of cells with the specified seq_id
     // Find all the ranges of cells with this seq id (or all, when -1)
-    uint32_t cell_range_begin = size;
-    for (uint32_t i = 0; i < size; ++i) {
-        const auto & cell = cells[i];
-        if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
+    uint32_t cell_range_begin = cells.size();
+
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
             ++cell_count;
-            if (cell_range_begin == size) {
+            if (cell_range_begin == cells.size()) {
                 cell_range_begin = i;
             }
         } else {
-            if (cell_range_begin != size) {
+            if (cell_range_begin != cells.size()) {
                 cell_ranges.emplace_back(cell_range_begin, i);
-                cell_range_begin = size;
+                cell_range_begin = cells.size();
             }
         }
     }
-    if (cell_range_begin != size) {
-        cell_ranges.emplace_back(cell_range_begin, size);
+
+    if (cell_range_begin != cells.size()) {
+        cell_ranges.emplace_back(cell_range_begin, cells.size());
     }
 
     // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
@@ -1064,17 +1210,24 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
 void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
     for (const auto & range : cell_ranges) {
         for (uint32_t i = range.first; i < range.second; ++i) {
-            const auto & cell = cells[i];
-            const llama_pos pos      = cell.pos;
-            const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
+            std::vector<llama_seq_id> seq_ids;
+
+            for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
+                if (cur == seq_id || seq_id == -1) {
+                    if (cells.seq_has(i, cur)) {
+                        seq_ids.push_back(cur);
+                    }
+                }
+            }
+
+            const llama_pos pos     = cells.pos_get(i);
+            const uint32_t n_seq_id = seq_ids.size();
 
             io.write(&pos,      sizeof(pos));
             io.write(&n_seq_id, sizeof(n_seq_id));
 
-            if (n_seq_id) {
-                for (auto seq_id : cell.seq_id) {
-                    io.write(&seq_id, sizeof(seq_id));
-                }
+            for (const auto & seq_id : seq_ids) {
+                io.write(&seq_id, sizeof(seq_id));
             }
         }
     }
@@ -1082,7 +1235,7 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::
 
 void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
     const uint32_t v_trans = this->v_trans ? 1 : 0;
-    const uint32_t n_layer = hparams.n_layer;
+    const uint32_t n_layer = layers.size();
 
     io.write(&v_trans, sizeof(v_trans));
     io.write(&n_layer, sizeof(n_layer));
@@ -1091,56 +1244,63 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
 
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
-    for (uint32_t il = 0; il < n_layer; ++il) {
+    for (const auto & layer : layers) {
+        const uint32_t il = layer.il;
+
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
         // Write key type
-        const int32_t k_type_i = (int32_t)k_l[il]->type;
+        const int32_t k_type_i = (int32_t)layer.k->type;
         io.write(&k_type_i, sizeof(k_type_i));
 
         // Write row size of key
-        const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
         io.write(&k_size_row, sizeof(k_size_row));
 
         // Read each range of cells of k_size length each into tmp_buf and write out
         for (const auto & range : cell_ranges) {
             const size_t range_size = range.second - range.first;
             const size_t buf_size = range_size * k_size_row;
-            io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
+            io.write_tensor(layer.k, range.first * k_size_row, buf_size);
         }
     }
 
     if (!v_trans) {
-        for (uint32_t il = 0; il < n_layer; ++il) {
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;
+
             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
             // Write value type
-            const int32_t v_type_i = (int32_t)v_l[il]->type;
+            const int32_t v_type_i = (int32_t)layer.v->type;
             io.write(&v_type_i, sizeof(v_type_i));
 
             // Write row size of value
-            const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+            const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
             io.write(&v_size_row, sizeof(v_size_row));
 
             // Read each range of cells of v_size length each into tmp_buf and write out
             for (const auto & range : cell_ranges) {
                 const size_t range_size = range.second - range.first;
                 const size_t buf_size = range_size * v_size_row;
-                io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
+                io.write_tensor(layer.v, range.first * v_size_row, buf_size);
             }
         }
     } else {
         // When v is transposed, we also need the element size and get the element ranges from each row
-        const uint32_t kv_size = size;
-        for (uint32_t il = 0; il < n_layer; ++il) {
+        const uint32_t kv_size = cells.size();
+
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;
+
             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
             // Write value type
-            const int32_t v_type_i = (int32_t)v_l[il]->type;
+            const int32_t v_type_i = (int32_t)layer.v->type;
             io.write(&v_type_i, sizeof(v_type_i));
 
             // Write element size
-            const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
+            const uint32_t v_size_el = ggml_type_size(layer.v->type);
             io.write(&v_size_el, sizeof(v_size_el));
 
             // Write GQA embedding size
@@ -1153,7 +1313,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
                     const size_t range_size = range.second - range.first;
                     const size_t src_offset = (range.first + j * kv_size) * v_size_el;
                     const size_t buf_size = range_size * v_size_el;
-                    io.write_tensor(v_l[il], src_offset, buf_size);
+                    io.write_tensor(layer.v, src_offset, buf_size);
                 }
             }
         }
@@ -1170,8 +1330,6 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
 
         batch.n_tokens = cell_count;
-        batch.n_seq_tokens = cell_count;
-        batch.n_seqs = 1;
 
         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
@@ -1180,32 +1338,40 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
             io.read_to(&pos,      sizeof(pos));
             io.read_to(&n_seq_id, sizeof(n_seq_id));
 
-            if (n_seq_id != 0) {
+            if (n_seq_id != 1) {
                 LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
                 return false;
             }
 
-            batch.pos[i] = pos;
+            // read the sequence id, but directly discard it - we will use dest_seq_id instead
+            {
+                llama_seq_id seq_id;
+                io.read_to(&seq_id, sizeof(seq_id));
+            }
+
+            batch.pos[i]      = pos;
+            batch.n_seq_id[i] = n_seq_id;
+            batch.seq_id[i]   = &dest_seq_id;
         }
-        batch.n_seq_id[0] = 1;
-        batch.seq_id[0] = &dest_seq_id;
+
         if (!find_slot(batch)) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
+
         commit();
 
         // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
-        GGML_ASSERT(head + cell_count <= size);
-        GGML_ASSERT(cells[head].pos == batch.pos[0]);
-        GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
-        GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
-        GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
+        GGML_ASSERT(head + cell_count <= cells.size());
+        GGML_ASSERT(cells.pos_get(head)                  == batch.pos[0]);
+        GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]);
+        GGML_ASSERT(cells.seq_has(head,                  dest_seq_id));
+        GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id));
     } else {
         // whole KV cache restore
 
-        if (cell_count > size) {
+        if (cell_count > cells.size()) {
             LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
             return false;
         }
@@ -1213,34 +1379,28 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         clear();
 
         for (uint32_t i = 0; i < cell_count; ++i) {
-            kv_cell & cell = cells[i];
-
             llama_pos pos;
             uint32_t  n_seq_id;
 
             io.read_to(&pos,      sizeof(pos));
             io.read_to(&n_seq_id, sizeof(n_seq_id));
 
-            cell.pos = pos;
+            cells.pos_set(i, pos);
 
             for (uint32_t j = 0; j < n_seq_id; ++j) {
                 llama_seq_id seq_id;
                 io.read_to(&seq_id, sizeof(seq_id));
 
-                // TODO: llama_kv_cache_unified should have a notion of max sequences
-                //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
-                if (seq_id < 0) {
-                    //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
-                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
+                if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
                     return false;
                 }
 
-                cell.seq_id.insert(seq_id);
+                cells.seq_add(i, seq_id);
             }
         }
 
         head = 0;
-        used = cell_count;
     }
 
     return true;
@@ -1249,15 +1409,16 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
 bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
     uint32_t v_trans;
     uint32_t n_layer;
+
     io.read_to(&v_trans, sizeof(v_trans));
     io.read_to(&n_layer, sizeof(n_layer));
 
-    if (n_layer != hparams.n_layer) {
-        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
+    if (n_layer != layers.size()) {
+        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
         return false;
     }
-    if (cell_count > size) {
-        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
+    if (cell_count > cells.size()) {
+        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
         return false;
     }
     if (this->v_trans != (bool) v_trans) {
@@ -1266,13 +1427,15 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     }
 
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
-    for (uint32_t il = 0; il < n_layer; ++il) {
+    for (const auto & layer : layers) {
+        const uint32_t il = layer.il;
+
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
         // Read type of key
         int32_t k_type_i_ref;
         io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
-        const int32_t k_type_i = (int32_t) k_l[il]->type;
+        const int32_t k_type_i = (int32_t) layer.k->type;
         if (k_type_i != k_type_i_ref) {
             LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
             return false;
@@ -1281,7 +1444,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
         // Read row size of key
         uint64_t k_size_row_ref;
         io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
-        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
         if (k_size_row != k_size_row_ref) {
             LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
             return false;
@@ -1289,18 +1452,20 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 
         if (cell_count) {
             // Read and set the keys for the whole cell range
-            ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
+            ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
         }
     }
 
     if (!this->v_trans) {
-        for (uint32_t il = 0; il < n_layer; ++il) {
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;
+
             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
             // Read type of value
             int32_t v_type_i_ref;
             io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-            const int32_t v_type_i = (int32_t)v_l[il]->type;
+            const int32_t v_type_i = (int32_t)layer.v->type;
             if (v_type_i != v_type_i_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                 return false;
@@ -1309,7 +1474,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
             // Read row size of value
             uint64_t v_size_row_ref;
             io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
-            const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+            const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
             if (v_size_row != v_size_row_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
                 return false;
@@ -1317,18 +1482,20 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 
             if (cell_count) {
                 // Read and set the values for the whole cell range
-                ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
+                ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
             }
         }
     } else {
         // For each layer, read the values for each cell (transposed)
-        for (uint32_t il = 0; il < n_layer; ++il) {
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;
+
             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
             // Read type of value
             int32_t v_type_i_ref;
             io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-            const int32_t v_type_i = (int32_t)v_l[il]->type;
+            const int32_t v_type_i = (int32_t)layer.v->type;
             if (v_type_i != v_type_i_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                 return false;
@@ -1337,7 +1504,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
             // Read element size of value
             uint32_t v_size_el_ref;
             io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
-            const size_t v_size_el = ggml_type_size(v_l[il]->type);
+            const size_t v_size_el = ggml_type_size(layer.v->type);
             if (v_size_el != v_size_el_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
                 return false;
@@ -1354,8 +1521,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
             if (cell_count) {
                 // For each row in the transposed matrix, read the values for the whole cell range
                 for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    const size_t dst_offset = (head + j * size) * v_size_el;
-                    ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                    const size_t dst_offset = (head + j * cells.size()) * v_size_el;
+                    ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
                 }
             }
         }
@@ -1364,6 +1531,193 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     return true;
 }
 
+//
+// llama_kv_cache_unified_iswa
+//
+
+llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
+        const llama_model & model,
+                ggml_type   type_k,
+                ggml_type   type_v,
+                     bool   v_trans,
+                     bool   offload,
+                     bool   swa_full,
+                 uint32_t   kv_size,
+                 uint32_t   n_seq_max,
+                 uint32_t   n_batch,
+                 uint32_t   n_pad) : hparams(model.hparams) {
+    llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
+    llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
+
+    const uint32_t size_base = kv_size;
+
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad));
+
+    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning
+    if (swa_full) {
+        LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
+                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+
+        size_swa = size_base;
+        do_prune = false;
+    }
+
+    LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
+
+    kv_base = std::make_unique<llama_kv_cache_unified>(
+            model, std::move(filter_base), type_k, type_v,
+            v_trans, offload, size_base, n_seq_max, n_pad,
+            0, LLAMA_SWA_TYPE_NONE);
+
+    LLAMA_LOG_INFO("%s: creating     SWA KV cache, size = %u cells\n", __func__, size_swa);
+
+    kv_swa = std::make_unique<llama_kv_cache_unified>(
+            model, std::move(filter_swa), type_k, type_v,
+            v_trans, offload, size_swa, n_seq_max, n_pad,
+            hparams.n_swa, hparams.swa_type);
+}
+
+void llama_kv_cache_unified_iswa::clear() {
+    kv_base->clear();
+    kv_swa ->clear();
+}
+
+bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    bool res = true;
+
+    res = res & kv_base->seq_rm(seq_id, p0, p1);
+    res = res & kv_swa ->seq_rm(seq_id, p0, p1);
+
+    return res;
+}
+
+void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
+    kv_base->seq_keep(seq_id);
+    kv_swa ->seq_keep(seq_id);
+}
+
+void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    kv_base->seq_add(seq_id, p0, p1, shift);
+    kv_swa ->seq_add(seq_id, p0, p1, shift);
+}
+
+void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    kv_base->seq_div(seq_id, p0, p1, d);
+    kv_swa ->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
+    // the base cache is a superset of the SWA cache, so we can just check the SWA cache
+    return kv_swa->seq_pos_min(seq_id);
+}
+
+llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
+    return kv_swa->seq_pos_max(seq_id);
+}
+
+void llama_kv_cache_unified_iswa::restore() {
+    kv_base->restore();
+    kv_swa ->restore();
+}
+
+void llama_kv_cache_unified_iswa::commit() {
+    kv_base->commit();
+    kv_swa ->commit();
+
+    // slide the attention window, forgetting/pruning old tokens that are outside the window
+    if (do_prune) {
+        for (const auto & [seq_id, entry] : pending.pos) {
+            kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax);
+        }
+
+    }
+
+    pending.clear();
+}
+
+bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
+    bool res = true;
+
+    res = res & kv_base->update(lctx);
+    res = res & kv_swa ->update(lctx);
+
+    return res;
+}
+
+void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
+    kv_base->defrag_sched(thold);
+    kv_swa ->defrag_sched(thold);
+}
+
+void llama_kv_cache_unified_iswa::set_full() {
+    kv_base->set_full();
+    kv_swa ->set_full();
+}
+
+llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
+    pending.clear();
+
+    if (do_prune) {
+        for (int i = 0; i < batch.n_tokens; ++i) {
+            for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = batch.seq_id[i][s];
+                const llama_pos    pos    = batch.pos[i];
+
+                if (pending.pos.find(seq_id) == pending.pos.end()) {
+                    pending.pos[seq_id].pmin = pos;
+                    pending.pos[seq_id].pmax = pos;
+                } else {
+                    pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos);
+                    pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos);
+                }
+            }
+        }
+    }
+
+    return llama_sbatch(batch, hparams.n_embd, true, logits_all);
+}
+
+llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
+    GGML_UNUSED(embd_pooled);
+    return sbatch.split_simple(n_ubatch);
+}
+
+bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
+    bool res = true;
+
+    res = res & kv_base->find_slot(batch);
+    res = res & kv_swa ->find_slot(batch);
+
+    return res;
+}
+
+bool llama_kv_cache_unified_iswa::get_can_shift() const {
+    return kv_base->get_size() == kv_swa->get_size();
+}
+
+void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    kv_base->state_write(io, seq_id);
+    kv_swa ->state_write(io, seq_id);
+}
+
+void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    kv_base->state_read(io, seq_id);
+    kv_swa ->state_read(io, seq_id);
+}
+
+llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_base() const {
+    return kv_base.get();
+}
+
+llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const {
+    return kv_swa.get();
+}
+
 //
 // llama_kv_cache_recurrent
 //
@@ -1373,19 +1727,17 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
                 ggml_type   type_k,
                 ggml_type   type_v,
                      bool   offload,
-                 uint32_t   kv_size) : hparams(model.hparams) {
+                 uint32_t   kv_size,
+                 uint32_t   n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;
 
-    LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
-            __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
+    LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
+            __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
 
     head = 0;
     size = kv_size;
     used = 0;
 
-    this->type_k = type_k;
-    this->type_v = type_v;
-
     cells.clear();
     cells.resize(kv_size);
 
@@ -1623,8 +1975,8 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
     }
 }
 
-void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
-    if (delta == 0) {
+void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    if (shift == 0) {
         return;
     }
 
@@ -1647,7 +1999,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
         if (tail_id >= 0) {
             kv_cell & cell = cells[tail_id];
             if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
-                cell.pos += delta;
+                cell.pos += shift;
             }
         }
     }
@@ -1683,8 +2035,24 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
     }
 }
 
+llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
+    llama_pos result = std::numeric_limits<llama_pos>::max();
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (cells[i].has_seq_id(seq_id)) {
+            result = std::min(result, cells[i].pos);
+        }
+    }
+
+    if (result == std::numeric_limits<llama_pos>::max()) {
+        result = -1;
+    }
+
+    return result;
+}
+
 llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
-    llama_pos result = 0;
+    llama_pos result = -1;
 
     for (uint32_t i = 0; i < size; ++i) {
         if (cells[i].has_seq_id(seq_id)) {
@@ -1707,8 +2075,8 @@ void llama_kv_cache_recurrent::commit() {
     pending.ranges.clear();
 }
 
-bool llama_kv_cache_recurrent::update(llama_context & lctx) {
-    GGML_UNUSED(lctx);
+bool llama_kv_cache_recurrent::update(llama_context & ctx) {
+    GGML_UNUSED(ctx);
     return false;
 }
 
@@ -1769,7 +2137,7 @@ bool llama_kv_cache_recurrent::find_slot(
             if (seq_id < 0 || (uint32_t) seq_id >= size) {
                 // too big seq_id
                 // TODO: would it be possible to resize the cache instead?
-                LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
+                LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max);
                 return false;
             }
             if (j > 0) {
@@ -1912,29 +2280,6 @@ bool llama_kv_cache_recurrent::find_slot(
     return n >= n_seqs;
 }
 
-int32_t llama_kv_cache_recurrent::get_n_tokens() const {
-    int32_t result = 0;
-
-    for (uint32_t i = 0; i < size; i++) {
-        result += cells[i].seq_id.size();
-    }
-
-    return result;
-}
-
-int32_t llama_kv_cache_recurrent::get_used_cells() const {
-    return used;
-}
-
-llama_pos llama_kv_cache_recurrent::get_pos_max() const {
-    llama_pos pos_max = -1;
-    for (const auto & cell : cells) {
-        pos_max = std::max(pos_max, cell.pos);
-    }
-
-    return pos_max;
-}
-
 bool llama_kv_cache_recurrent::get_can_shift() const {
     return false;
 }
@@ -2063,6 +2408,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
     io.read_to(&cell_count, sizeof(cell_count));
 
     bool res = true;
+
     res = res && state_read_meta(io, cell_count, seq_id);
     res = res && state_read_data(io, cell_count);
 
@@ -2391,104 +2737,3 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     return true;
 }
-
-//
-// kv cache view
-//
-
-llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max) {
-    llama_kv_cache_view result = {
-        /*.n_cells            = */ 0,
-        /*.n_seq_max          = */ n_seq_max,
-        /*.token_count        = */ 0,
-        /*.used_cells         = */ kv.get_used_cells(),
-        /*.max_contiguous     = */ 0,
-        /*.max_contiguous_idx = */ -1,
-        /*.cells              = */ nullptr,
-        /*.cells_sequences    = */ nullptr,
-    };
-
-    return result;
-}
-
-void llama_kv_cache_view_free(llama_kv_cache_view * view) {
-    if (view->cells != nullptr) {
-        free(view->cells);
-        view->cells = nullptr;
-    }
-    if (view->cells_sequences != nullptr) {
-        free(view->cells_sequences);
-        view->cells_sequences = nullptr;
-    }
-}
-
-void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv) {
-    // TODO: rework this in the future, for now quick hack
-    const llama_kv_cache_unified * kvu = dynamic_cast<const llama_kv_cache_unified *>(kv);
-    if (kvu == nullptr) {
-        LLAMA_LOG_ERROR("%s: the kv_cache_view currently works only with llama_kv_cache_unified\n", __func__);
-        return;
-    }
-
-    if (uint32_t(view->n_cells) < kvu->size || view->cells == nullptr) {
-        view->n_cells = int32_t(kvu->size);
-        void * p = realloc(view->cells, sizeof(llama_kv_cache_view_cell) * view->n_cells);
-        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
-        view->cells = (llama_kv_cache_view_cell *)p;
-        p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
-        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
-        view->cells_sequences = (llama_seq_id *)p;
-    }
-
-    const std::vector<llama_kv_cache_unified::kv_cell> & kv_cells = kvu->cells;
-    llama_kv_cache_view_cell * c_curr = view->cells;
-    llama_seq_id * cs_curr = view->cells_sequences;
-    int32_t used_cells = 0;
-    int32_t token_count = 0;
-    int32_t curr_contig_idx = -1;
-    uint32_t max_contig = 0;
-    int32_t max_contig_idx = -1;
-
-    for (int32_t i = 0; i < int32_t(kvu->size); i++, c_curr++, cs_curr += view->n_seq_max) {
-        const size_t curr_size = kv_cells[i].seq_id.size();
-        token_count += curr_size;
-        c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
-
-        if (curr_size > 0) {
-            if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
-                max_contig = i - curr_contig_idx;
-                max_contig_idx = curr_contig_idx;
-            }
-            curr_contig_idx = -1;
-        } else if (curr_contig_idx < 0) {
-            curr_contig_idx = i;
-        }
-
-        int seq_idx = 0;
-        for (const llama_seq_id it : kv_cells[i].seq_id) {
-            if (seq_idx >= view->n_seq_max) {
-                break;
-            }
-            cs_curr[seq_idx] = it;
-            seq_idx++;
-        }
-        if (seq_idx != 0) {
-            used_cells++;
-        }
-        for (; seq_idx < view->n_seq_max; seq_idx++) {
-            cs_curr[seq_idx] = -1;
-        }
-    }
-    if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
-        max_contig_idx = curr_contig_idx;
-        max_contig = kv_cells.size() - curr_contig_idx;
-    }
-    view->max_contiguous = max_contig;
-    view->max_contiguous_idx = max_contig_idx;
-    view->token_count = token_count;
-    view->used_cells = used_cells;
-    if (uint32_t(used_cells) != kvu->used) {
-        LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
-            __func__, kvu->used, used_cells);
-    }
-}
index e83e12c09f2b1d545c53b45b5c1518ca1ad8001a..ce6261e45a6e17e0052d8fa6ffaeff32b9b91641 100644 (file)
@@ -4,10 +4,12 @@
 #include "llama-io.h"
 #include "llama-graph.h"
 #include "llama-memory.h"
+#include "llama-kv-cells.h"
 
 #include "ggml-cpp.h"
 
 #include <set>
+#include <unordered_map>
 #include <vector>
 
 struct llama_cparams;
@@ -34,12 +36,16 @@ struct llama_kv_cache : public llama_memory_i {
     virtual void defrag_sched(float thold) = 0;
 
     // simulate full cache, used for allocating worst-case compute buffers
+    // TODO: remove
     virtual void set_full() = 0;
 
     //
     // batch processing
     //
 
+    // =============================================================================================================
+    // TODO: refactor and simplify this [TAG: KV_API]
+
     virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
 
     // different KV caches require different batch splitting strategies
@@ -48,11 +54,10 @@ struct llama_kv_cache : public llama_memory_i {
     // find an empty slot of size "n_tokens" in the cache
     virtual bool find_slot(const llama_ubatch & batch) = 0;
 
+    // =============================================================================================================
+
     // getters
-    virtual int32_t   get_n_tokens()   const = 0;
-    virtual int32_t   get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
-    virtual llama_pos get_pos_max()    const = 0;
-    virtual bool      get_can_shift()  const = 0;
+    virtual bool get_can_shift() const = 0;
 
     bool get_can_edit() const override { return get_can_shift(); }
 
@@ -87,38 +92,25 @@ private:
 // llama_kv_cache_unified
 //
 
-// TODO: add notion of max sequences
 class llama_kv_cache_unified : public llama_kv_cache {
 public:
-    struct kv_cell {
-        llama_pos pos   = -1;
-        llama_pos delta =  0;
-
-        std::set<llama_seq_id> seq_id;
-
-        bool has_seq_id(const llama_seq_id & id) const {
-            return seq_id.find(id) != seq_id.end();
-        }
-
-        bool is_empty() const {
-            return seq_id.empty();
-        }
-
-        bool is_same_seq(const kv_cell & other) const {
-            return seq_id == other.seq_id;
-        }
-    };
-
     static uint32_t get_padding(const llama_cparams & cparams);
 
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
     llama_kv_cache_unified(
-            const llama_model & model,
-                    ggml_type   type_k,
-                    ggml_type   type_v,
-                         bool   v_trans,
-                         bool   offload,
-                     uint32_t   kv_size,
-                     uint32_t   padding);
+            const llama_model &  model,
+              layer_filter_cb && filter,
+                    ggml_type    type_k,
+                    ggml_type    type_v,
+                         bool    v_trans,
+                         bool    offload,
+                     uint32_t    kv_size,
+                     uint32_t    n_seq_max,
+                     uint32_t    n_pad,
+                     uint32_t    n_swa,
+               llama_swa_type    swa_type);
 
     ~llama_kv_cache_unified() = default;
 
@@ -130,10 +122,11 @@ public:
 
     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
-    void seq_keep(llama_seq_id seq_id) override;
-    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
+    void seq_keep(llama_seq_id seq_id)                                                          override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
 
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
     llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
     //
@@ -150,7 +143,6 @@ public:
     void set_full() override;
 
     llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
-
     llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
 
     // updates the cache head
@@ -158,50 +150,94 @@ public:
     // to the first cell of the slot.
     bool find_slot(const llama_ubatch & batch) override;
 
-    int32_t get_n_tokens()   const override;
-    int32_t get_used_cells() const override;
-
-    // TODO: better data structures to reduce the cost of this operation
-    llama_pos get_pos_max() const override;
-
     bool get_can_shift() const override;
 
     // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1)       override;
 
-    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
-    uint32_t size = 0; // total number of cells, shared across all sequences
-    uint32_t used = 0; // used cells (i.e. at least one seq_id)
+    //
+    // llama_kv_cache_unified specific API
+    //
 
-    // computed before each graph build
-    uint32_t n = 0;
+    uint32_t get_n()    const;
+    uint32_t get_size() const;
 
-    std::vector<kv_cell> cells;
+    // get views of the current state of the cache
+    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
 
-    std::vector<ggml_tensor *> k_l; // per layer
-    std::vector<ggml_tensor *> v_l;
+    // store k_cur and v_cur in the cache based on the current head location
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
+
+    void prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax);
+
+    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
+    void set_input_k_shift   (ggml_tensor * dst) const;
+    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
 private:
     const llama_model & model;
     const llama_hparams & hparams;
 
-    bool has_shift = false;
-    bool do_defrag = false;
+    struct kv_layer {
+        // layer index in the model
+        // note: can be different from the layer index in the KV cache
+        uint32_t il;
+
+        ggml_tensor * k;
+        ggml_tensor * v;
+    };
 
+    bool do_defrag = false;
     bool v_trans   = true;  // the value tensor is transposed
-    bool can_shift = false;
+
+    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
+
+    // computed before each graph build
+    // TODO: cells should start to maintain this value dynamically based on the edits
+    uint32_t n = 0;
+
+    const uint32_t n_seq_max = 1;
 
     // required padding
-    uint32_t padding = 1;
+    const uint32_t n_pad = 1;
 
-    ggml_type type_k = GGML_TYPE_F16;
-    ggml_type type_v = GGML_TYPE_F16;
+    // SWA
+    const uint32_t n_swa = 0;
+
+    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
 
     std::vector<ggml_context_ptr>        ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
+    llama_kv_cells_unified cells;
+
+    std::vector<kv_layer> layers;
+
+    // model layer id -> KV cache layer id
+    std::unordered_map<int32_t, int32_t> map_layer_ids;
+
+    // recovery information used to restore the KV cells to their original state in case of a failure
+    // TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation
+    //       to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API]
+    struct {
+        void clear() {
+            states.clear();
+        }
+
+        struct state {
+            uint32_t i;
+
+            llama_kv_cells_unified cells;
+        };
+
+        // stack with the partial states before each ubatch
+        std::vector<state> states;
+    } recovery;
+
     // defrag
     struct {
         std::vector<uint32_t> ids;
@@ -210,25 +246,13 @@ private:
     // return true if cells have been moved
     bool defrag_prepare(int32_t n_max_nodes);
 
-    // commit/restore cache
-    struct slot_range {
-        uint32_t c0 = 0; // note: these are cell indices, not sequence positions
-        uint32_t c1 = 0;
-    };
-
-    // pending cell updates that are not yet committed
-    struct {
-        std::vector<slot_range> ranges;
-    } pending;
-
-    // find how many cells are currently in use
-    uint32_t cell_max() const;
-
     size_t total_size() const;
 
     size_t size_k_bytes() const;
     size_t size_v_bytes() const;
 
+    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
+
     ggml_tensor * build_rope_shift(
             const llama_cparams & cparams,
                    ggml_context * ctx,
@@ -255,6 +279,100 @@ private:
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };
 
+//
+// llama_kv_cache_unified_iswa
+//
+
+// utilizes two instances of llama_kv_cache_unified
+//   the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
+//   upon successful commit, the SWA cache removes old tokens outside the n_swa window
+
+class llama_kv_cache_unified_iswa : public llama_kv_cache {
+public:
+    llama_kv_cache_unified_iswa(
+            const llama_model & model,
+                    ggml_type   type_k,
+                    ggml_type   type_v,
+                         bool   v_trans,
+                         bool   offload,
+                         bool   swa_full,
+                     uint32_t   kv_size,
+                     uint32_t   n_seq_max,
+                     uint32_t   n_batch,
+                     uint32_t   n_pad);
+
+    ~llama_kv_cache_unified_iswa() = default;
+
+    //
+    // llama_memory_i
+    //
+
+    void clear() override;
+
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id)                                                          override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    //
+    // llama_kv_cache
+    //
+
+    void restore() override;
+    void commit()  override;
+
+    bool update(llama_context & ctx) override;
+
+    void defrag_sched(float thold) override;
+
+    void set_full() override;
+
+    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
+    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
+
+    bool find_slot(const llama_ubatch & batch) override;
+
+    bool get_can_shift() const override;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1)       override;
+
+    //
+    // llama_kv_cache_unified_iswa specific API
+    //
+
+    llama_kv_cache_unified * get_kv_base() const;
+    llama_kv_cache_unified * get_kv_swa () const;
+
+private:
+    const llama_hparams & hparams;
+
+    bool do_prune = true;
+
+    struct {
+        struct entry {
+            llama_pos pmin;
+            llama_pos pmax;
+        };
+
+        void clear() {
+            pos.clear();
+        }
+
+        // used to perform SWA pruning of old tokens
+        std::unordered_map<llama_seq_id, entry> pos;
+    } pending;
+
+    std::unique_ptr<llama_kv_cache_unified> kv_base;
+    std::unique_ptr<llama_kv_cache_unified> kv_swa;
+};
+
 //
 // llama_kv_cache_recurrent
 //
@@ -286,7 +404,8 @@ public:
                     ggml_type   type_k,
                     ggml_type   type_v,
                          bool   offload,
-                     uint32_t   kv_size);
+                     uint32_t   kv_size,
+                     uint32_t   n_seq_max);
 
     ~llama_kv_cache_recurrent() = default;
 
@@ -298,10 +417,11 @@ public:
 
     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
-    void seq_keep(llama_seq_id seq_id) override;
-    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
+    void seq_keep(llama_seq_id seq_id)                                                          override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
 
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
     llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
     //
@@ -311,24 +431,17 @@ public:
     void restore() override;
     void commit()  override;
 
-    bool update(llama_context & lctx) override;
+    bool update(llama_context & ctx) override;
 
     void defrag_sched(float thold) override;
 
     void set_full() override;
 
     llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
-
     llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
 
     bool find_slot(const llama_ubatch & batch) override;
 
-    int32_t get_n_tokens()   const override;
-    int32_t get_used_cells() const override;
-
-    // TODO: better data structures to reduce the cost of this operation
-    llama_pos get_pos_max() const override;
-
     bool get_can_shift() const override;
 
     // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
@@ -368,8 +481,7 @@ private:
         std::vector<slot_range> ranges;
     } pending;
 
-    ggml_type type_k = GGML_TYPE_F16;
-    ggml_type type_v = GGML_TYPE_F16;
+    const uint32_t n_seq_max = 1;
 
     std::vector<ggml_context_ptr>        ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
@@ -388,12 +500,3 @@ private:
     bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };
-
-
-//
-// kv cache view
-//
-
-llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max);
-
-void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv);
diff --git a/examples/talk-llama/llama-kv-cells.h b/examples/talk-llama/llama-kv-cells.h
new file mode 100644 (file)
index 0000000..dbbd03f
--- /dev/null
@@ -0,0 +1,379 @@
+#pragma once
+
+#include "llama.h"
+#include "llama-cparams.h"
+
+#include <bitset>
+#include <cassert>
+#include <vector>
+#include <set>
+
+// meta information about KV cells that can be part of multiple sequences at the same time
+// TODO: add unit tests
+class llama_kv_cells_unified {
+public:
+    void reset() {
+        for (uint32_t i = 0; i < pos.size(); ++i) {
+            pos[i]   = -1;
+            shift[i] =  0;
+            seq[i].reset();
+        }
+
+        has_shift = false;
+
+        used.clear();
+
+        for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            seq_pos[s].clear();
+        }
+    }
+
+    void reset_shift() {
+        has_shift = false;
+
+        for (uint32_t i = 0; i < shift.size(); ++i) {
+            shift[i] = 0;
+        }
+    }
+
+    uint32_t size() const {
+        return pos.size();
+    }
+
+    void resize(uint32_t n) {
+        pos.resize(n);
+        shift.resize(n);
+        seq.resize(n);
+
+        reset();
+    }
+
+    bool is_empty(uint32_t i) const {
+        assert(i < pos.size());
+        assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0);
+
+        return pos[i] == -1;
+    }
+
+    uint32_t get_used() const {
+        return used.size();
+    }
+
+    // the index of the first cell that is used
+    // return 0 if no cells are used
+    uint32_t used_min() const {
+        return used.empty() ? 0 : *used.begin();
+    }
+
+    // the index of the last cell that is used + 1
+    // return 0 if no cells are used
+    uint32_t used_max_p1() const {
+#if 0
+        if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin());
+        if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin());
+        if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin());
+#endif
+
+        return used.empty() ? 0 : *used.rbegin() + 1;
+    }
+
+    bool get_has_shift() const {
+        return has_shift;
+    }
+
+    // move cell isrc to idst (used during defrag)
+    void mv(uint32_t isrc, uint32_t idst) {
+        assert(isrc < pos.size());
+        assert(idst < pos.size());
+
+        pos  [idst] = pos  [isrc];
+        shift[idst] = shift[isrc];
+        seq  [idst] = seq  [isrc];
+
+        pos  [isrc] = -1;
+        shift[isrc] =  0;
+        seq  [isrc].reset();
+
+        used.erase (isrc);
+        used.insert(idst);
+    }
+
+    // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
+    llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
+        assert(i + n <= pos.size());
+
+        llama_kv_cells_unified res;
+
+        res.resize(n);
+
+        for (uint32_t j = 0; j < n; ++j) {
+            res.pos[j] = pos[i + j];
+            res.seq[j] = seq[i + j];
+
+            assert(shift[i + j] == 0);
+        }
+
+        return res;
+    }
+
+    // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
+    void set(uint32_t i, const llama_kv_cells_unified & other) {
+        assert(i + other.pos.size() <= pos.size());
+
+        for (uint32_t j = 0; j < other.pos.size(); ++j) {
+            if (pos[i + j] == -1 && other.pos[j] != -1) {
+                used.insert(i + j);
+            }
+
+            if (pos[i + j] != -1 && other.pos[j] == -1) {
+                used.erase(i + j);
+            }
+
+            if (pos[i + j] != -1) {
+                seq_pos_rm(i + j);
+            }
+
+            pos[i + j] = other.pos[j];
+            seq[i + j] = other.seq[j];
+
+            if (pos[i + j] != -1) {
+                seq_pos_add(i + j);
+            }
+
+            assert(shift[i + j] == 0);
+        }
+    }
+
+    // note: call only if the cell has seq_id
+    // return true if the cell becomes empty
+    bool seq_rm(uint32_t i, llama_seq_id seq_id) {
+        assert(i < pos.size());
+        assert(seq[i].test(seq_id));
+        assert(pos[i] != -1);
+        assert(seq_id >= 0);
+
+        seq[i].reset(seq_id);
+        seq_pos[seq_id].erase(pos[i]);
+
+        if (seq[i].none()) {
+            pos[i] = -1;
+
+            used.erase(i);
+
+            return true;
+        }
+
+        return false;
+    }
+
+    // return true if the cell becomes empty (i.e. it did not contain seq_id before the call)
+    bool seq_keep(uint32_t i, llama_seq_id seq_id) {
+        assert(i < pos.size());
+
+        if (seq[i].test(seq_id)) {
+            seq_pos_rm(i);
+            seq[i].reset();
+
+            seq[i].set(seq_id);
+            seq_pos[seq_id].insert(pos[i]);
+
+            return false;
+        }
+
+        if (seq[i].any()) {
+            seq_pos_rm(i);
+            seq[i].reset();
+
+            pos[i] = -1;
+
+            used.erase(i);
+
+            return true;
+        }
+
+        assert(pos[i] == -1);
+
+        return false;
+    }
+
+    bool seq_has(uint32_t i, llama_seq_id seq_id) const {
+        assert(i < pos.size());
+        assert(seq_id >= 0);
+
+        return seq[i].test(seq_id);
+    }
+
+    // note: call only if the cell is not empty and the seq_id is not in the cell
+    void seq_add(uint32_t i, llama_seq_id seq_id) {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+        assert(!seq[i].test(seq_id));
+
+        seq[i].set(seq_id);
+        seq_pos[seq_id].insert(pos[i]);
+    }
+
+    // the minimum position of sequence seq_id currently present in any of the cells
+    // return -1 if the sequence is not present
+    llama_pos seq_pos_min(llama_seq_id seq_id) const {
+        assert(seq_id >= 0);
+        assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
+
+        if (seq_pos[seq_id].empty()) {
+            return -1;
+        }
+
+        return *seq_pos[seq_id].begin();
+    }
+
+    // the maximum position of sequence seq_id currently present in any of the cells
+    // return -1 if the sequence is not present
+    llama_pos seq_pos_max(llama_seq_id seq_id) const {
+        assert(seq_id >= 0);
+        assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
+
+        if (seq_pos[seq_id].empty()) {
+            return -1;
+        }
+
+        return *seq_pos[seq_id].rbegin();
+    }
+
+    // note: call only if the cell is not empty
+    llama_pos pos_get(uint32_t i) const {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        return pos[i];
+    }
+
+    // note: call only if the cell is not empty
+    llama_pos get_shift(uint32_t i) const {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        return shift[i];
+    }
+
+    // check if a cell is not empty and its position is within [p0, p1)
+    bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
+        assert(i < pos.size());
+
+        return pos[i] >= p0 && pos[i] < p1;
+    }
+
+    // set the position of an empty cell
+    // does not modify "has_shift"
+    // note: call only if the cell is empty
+    void pos_set(uint32_t i, llama_pos p) {
+        assert(i < pos.size());
+        assert(pos[i] == -1);
+
+        pos[i] = p;
+
+        used.insert(i);
+    }
+
+    // pos[i] = pos[i] + d
+    // sets "has_shift" to true
+    // note: call only if the cell is not empty
+    bool pos_add(uint32_t i, llama_pos d) {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        seq_pos_rm(i);
+
+        pos[i]   += d;
+        shift[i] += d;
+
+        seq_pos_add(i);
+
+        has_shift = true;
+
+        if (pos[i] < 0) {
+            seq_pos_rm(i);
+
+            seq[i].reset();
+            pos[i] = -1;
+
+            used.erase(i);
+
+            return true;
+        }
+
+        return false;
+    }
+
+    // pos[i] = pos[i] / d
+    // sets "has_shift" to true
+    // note: call only if the cell is not empty
+    void pos_div(uint32_t i, int d) {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        const llama_pos p_old = pos[i];
+
+        seq_pos_rm(i);
+
+        pos[i]   /= d;
+        shift[i] += p_old - pos[i];
+
+        seq_pos_add(i);
+
+        has_shift = true;
+    }
+
+private:
+    bool has_shift = false;
+
+    // set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
+    std::set<uint32_t> used;
+
+    std::vector<llama_pos> pos;
+
+    // this array accumulates any applied shifts to the pos array since the last reset_shift() call
+    // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
+    //
+    //   cells.pos_add(x, shift_x);
+    //   cells.pos_div(y, shift_y);
+    //   ...
+    //
+    //   if (cells.has_shift()) {
+    //      for (int i = 0; i < n; ++i) {
+    //          auto shift_i = cells.get_shift(i);
+    //          ...
+    //      }
+    //      cells.reset_shift();
+    //   }
+    //
+    std::vector<llama_pos> shift;
+
+    using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
+
+    // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
+    std::vector<bits_t> seq;
+
+    // the set seq_pos[s] tells us which positions are currently present for sequence s
+    // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
+    std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
+
+    // helper functions for updating `seq_pos`, once cell at a time:
+
+    // remove cell i
+    void seq_pos_rm(uint32_t i) {
+        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (seq[i].test(s)) {
+                seq_pos[s].erase(pos[i]);
+            }
+        }
+    }
+
+    // add cell i
+    void seq_pos_add(uint32_t i) {
+        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (seq[i].test(s)) {
+                seq_pos[s].insert(pos[i]);
+            }
+        }
+    }
+};
index c7412d5911ed79153a5d8fefe7eb24917901630c..a2d250434affa8c58fe6cb7639279b34b5397b67 100644 (file)
@@ -7,8 +7,8 @@ struct llama_memory_params {
     ggml_type type_k;
     ggml_type type_v;
 
-    // parameters for other types of memory
-    // ...
+    // use full-size SWA cache
+    bool swa_full;
 };
 
 // general concept of LLM memory
@@ -22,9 +22,10 @@ public:
     virtual bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) = 0;
     virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
     virtual void seq_keep(llama_seq_id seq_id) = 0;
-    virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) = 0;
+    virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) = 0;
     virtual void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) = 0;
 
+    virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
     virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
 
     virtual bool get_can_edit() const = 0;
index 7fd094b63f26921ee06bae90c3c440b717f1dbee..e99f5309f99044663d190cacef2cfb1b1864f71d 100644 (file)
@@ -463,11 +463,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         GGML_ASSERT(hparams.n_expert_used == 0);
     }
 
-    // zero-out the array hparams
     std::fill(hparams.n_head_arr.begin(),    hparams.n_head_arr.end(),    0);
     std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
     std::fill(hparams.n_ff_arr.begin(),      hparams.n_ff_arr.end(),      0);
 
+    std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
+
+    std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0);
+
     ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff_arr,   hparams.n_layer, false);
     ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
 
@@ -571,9 +574,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,  hparams.n_ff_exp);
                 ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP,   hparams.n_moe_layer_step);
-                hparams.n_swa_pattern = 4;    // pattern: 3 chunked - 1 full
-                hparams.n_attn_chunk  = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
-                hparams.n_swa = 1; // TODO @ngxson : this is added to trigger the SWA branch (we store the chunked attn mask in the SWA tensor), will need to clean this up later
+
+                hparams.swa_type      = LLAMA_SWA_TYPE_CHUNKED;
+                hparams.n_swa         = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
+                hparams.set_swa_pattern(4);   // pattern: 3 chunked - 1 full
 
                 switch (hparams.n_expert) {
                     case 16:  type = LLM_TYPE_17B_16E; break;
@@ -852,22 +856,17 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
 
-                // for backward compatibility ; see: https://github.com/ggerganov/llama.cpp/pull/8931
-                if ((hparams.n_layer == 32 || hparams.n_layer == 40) && hparams.n_ctx_train == 4096) {
-                    // default value for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct
-                    hparams.n_swa = 2047;
-                } else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) {
-                    // default value for Phi-3-mini-128k-instruct
-                    // note: this seems incorrect because the window is bigger than the train context?
-                    hparams.n_swa = 262144;
-                } else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) {
-                    // default value for Phi-3-medium-128k-instruct
-                    // note: this seems incorrect because the window is equal to the train context?
-                    hparams.n_swa = 131072;
-                }
-                bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
-                if (!found_swa && hparams.n_swa == 0) {
-                    throw std::runtime_error("invalid value for sliding_window");
+                const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+
+                if (found_swa && hparams.n_swa > 0) {
+                    LLAMA_LOG_WARN("%s: Phi SWA is currently disabled - results might be suboptimal for some models (see %s)\n",
+                            __func__, "https://github.com/ggml-org/llama.cpp/pull/13676");
+
+                    // TODO: fix conversion scripts to correctly populate `n_swa` and `n_swa_pattern`
+                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+
+                    hparams.n_swa         = 0;
+                    hparams.set_swa_pattern(1);
                 }
             } break;
         case LLM_ARCH_PHIMOE:
@@ -937,8 +936,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GEMMA2:
             {
+                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                 hparams.n_swa = 4096; // default value of gemma 2
-                hparams.n_swa_pattern = 2;
+                hparams.set_swa_pattern(2);
                 hparams.attn_soft_cap = true;
 
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa, false);
@@ -955,7 +955,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GEMMA3:
             {
-                hparams.n_swa_pattern = 6;
+                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                hparams.set_swa_pattern(6);
 
                 hparams.rope_freq_base_train_swa  = 10000.0f;
                 hparams.rope_freq_scale_train_swa = 1.0f;
@@ -1039,7 +1040,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_COHERE2:
             {
-                hparams.n_swa_pattern = 4;
+                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                hparams.set_swa_pattern(4);
 
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
                 ml.get_key(LLM_KV_LOGIT_SCALE,              hparams.f_logit_scale);
@@ -2487,7 +2489,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                     // output
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
 
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
@@ -4321,7 +4327,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: n_head_kv        = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
         LLAMA_LOG_INFO("%s: n_rot            = %u\n",     __func__, hparams.n_rot);
         LLAMA_LOG_INFO("%s: n_swa            = %u\n",     __func__, hparams.n_swa);
-        LLAMA_LOG_INFO("%s: n_swa_pattern    = %u\n",     __func__, hparams.n_swa_pattern);
+        LLAMA_LOG_INFO("%s: is_swa_any       = %u\n",     __func__, hparams.is_swa_any());
         LLAMA_LOG_INFO("%s: n_embd_head_k    = %u\n",     __func__, hparams.n_embd_head_k);
         LLAMA_LOG_INFO("%s: n_embd_head_v    = %u\n",     __func__, hparams.n_embd_head_v);
         LLAMA_LOG_INFO("%s: n_gqa            = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il);        }, hparams.n_layer).c_str());
@@ -4489,7 +4495,17 @@ const ggml_tensor * llama_model::get_tensor(const char * name) const {
     return it->second;
 }
 
-ggml_tensor * llama_model::get_rope_factors(uint32_t n_ctx_per_seq, int il) const {
+float llama_model::get_rope_freq_base (const llama_cparams & cparams, int il) const {
+    return hparams.is_swa(il) ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
+}
+
+float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) const {
+    return hparams.is_swa(il) ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
+}
+
+ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
+    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
+
     // choose long/short freq factors based on the context size
     if (layers[il].rope_freqs != nullptr) {
         return layers[il].rope_freqs;
@@ -4517,21 +4533,174 @@ struct llm_build_llama : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network (non-MoE)
+            if (model.layers[il].ffn_gate_inp == nullptr) {
+
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il);
+                cb(cur, "ffn_moe_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_llama_iswa : public llm_graph_context {
+    llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
         // temperature tuning
         ggml_tensor * inp_attn_scale = nullptr;
-        if (arch == LLM_ARCH_LLAMA4) {
-            inp_attn_scale = build_inp_attn_scale();
-        }
+        inp_attn_scale = build_inp_attn_scale();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv_unified_iswa();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
-            bool use_rope = arch == LLM_ARCH_LLAMA4
-                ? (il + 1) % hparams.n_no_rope_layer_step != 0
-                : true;
+            const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0;
 
             // norm
             cur = build_norm(inpL,
@@ -4542,7 +4711,7 @@ struct llm_build_llama : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -4590,7 +4759,7 @@ struct llm_build_llama : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                if (arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) {
+                if (use_rope && hparams.use_kq_norm) {
                     // Llama4TextL2Norm
                     Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps);
                     Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps);
@@ -4616,7 +4785,6 @@ struct llm_build_llama : public llm_graph_context {
 
             // feed-forward network (non-MoE)
             if (model.layers[il].ffn_gate_inp == nullptr) {
-
                 cur = build_norm(ffn_inp,
                         model.layers[il].ffn_norm, NULL,
                         LLM_NORM_RMS, il);
@@ -4629,9 +4797,7 @@ struct llm_build_llama : public llm_graph_context {
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, il);
                 cb(cur, "ffn_out", il);
-
-            } else if (arch == LLM_ARCH_LLAMA4) {
-                // llama4 MoE
+            } else {
                 ggml_tensor * ffn_inp_normed = build_norm(ffn_inp,
                         model.layers[il].ffn_norm, NULL,
                         LLM_NORM_RMS, il);
@@ -4660,26 +4826,6 @@ struct llm_build_llama : public llm_graph_context {
 
                 cur = ggml_add(ctx0, moe_out, shexp_out);
                 cb(cur, "ffn_moe_out_merged", il);
-
-            } else {
-                // MoE branch
-                cur = build_norm(ffn_inp,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = build_moe_ffn(cur,
-                        model.layers[il].ffn_gate_inp,
-                        model.layers[il].ffn_up_exps,
-                        model.layers[il].ffn_gate_exps,
-                        model.layers[il].ffn_down_exps,
-                        nullptr,
-                        n_expert, n_expert_used,
-                        LLM_FFN_SILU, true,
-                        false, 0.0,
-                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                        il);
-                cb(cur, "ffn_moe_out", il);
             }
 
             cur = ggml_add(ctx0, cur, ffn_inp);
@@ -4753,7 +4899,7 @@ struct llm_build_deci : public llm_graph_context {
             } else if (n_head > 0) {
                 // self-attention
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -7202,6 +7348,7 @@ struct llm_build_phi2 : public llm_graph_context {
     }
 };
 
+template<bool iswa>
 struct llm_build_phi3 : public llm_graph_context {
     llm_build_phi3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -7217,7 +7364,14 @@ struct llm_build_phi3 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
+        inp_attn_type * inp_attn = nullptr;
+
+        if constexpr (iswa) {
+            inp_attn = build_attn_inp_kv_unified_iswa();
+        } else {
+            inp_attn = build_attn_inp_kv_unified();
+        }
 
         for (int il = 0; il < n_layer; ++il) {
             auto * residual = inpL;
@@ -7225,7 +7379,7 @@ struct llm_build_phi3 : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for 128k context
-                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
                 ggml_tensor* attn_norm_output = build_norm(inpL,
                         model.layers[il].attn_norm,
@@ -7977,7 +8131,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
-            ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+            ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
             // norm
             cur = build_norm(inpL,
@@ -8277,8 +8431,8 @@ struct llm_build_gemma : public llm_graph_context {
     }
 };
 
-struct llm_build_gemma2 : public llm_graph_context {
-    llm_build_gemma2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+struct llm_build_gemma2_iswa : public llm_graph_context {
+    llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_k;
 
         ggml_tensor * cur;
@@ -8292,7 +8446,7 @@ struct llm_build_gemma2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv_unified_iswa();
 
         for (int il = 0; il < n_layer; ++il) {
             // norm
@@ -8414,8 +8568,8 @@ struct llm_build_gemma2 : public llm_graph_context {
     }
 };
 
-struct llm_build_gemma3 : public llm_graph_context {
-    llm_build_gemma3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+struct llm_build_gemma3_iswa : public llm_graph_context {
+    llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_k;
 
         ggml_tensor * cur;
@@ -8433,13 +8587,11 @@ struct llm_build_gemma3 : public llm_graph_context {
         ggml_tensor * inp_pos = build_inp_pos();
 
         // TODO: is causal == true correct? might need some changes
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv_unified_iswa();
 
         for (int il = 0; il < n_layer; ++il) {
-            const bool is_swa = hparams.is_swa(il);
-
-            const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
-            const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
+            const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+            const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
 
             // norm
             cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
@@ -9016,8 +9168,8 @@ struct llm_build_command_r : public llm_graph_context {
     }
 };
 
-struct llm_build_cohere2 : public llm_graph_context {
-    llm_build_cohere2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+struct llm_build_cohere2_iswa : public llm_graph_context {
+    llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -9032,7 +9184,7 @@ struct llm_build_cohere2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv_unified_iswa();
 
         for (int il = 0; il < n_layer; ++il) {
             const bool is_swa = hparams.is_swa(il);
@@ -9045,7 +9197,7 @@ struct llm_build_cohere2 : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for 128k context
-                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -9983,7 +10135,7 @@ struct llm_build_deepseek : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -11347,7 +11499,7 @@ struct llm_build_exaone : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -12263,7 +12415,7 @@ struct llm_build_granite : public llm_graph_context {
                 Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
                 if (use_rope) {
-                    ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                    ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
                     Qcur = ggml_rope_ext(
                             ctx0, Qcur, inp_pos, rope_factors,
                             n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -12916,7 +13068,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -13044,6 +13196,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
+        case LLM_ARCH_WAVTOKENIZER_DEC:
             {
                 res = nullptr;
             } break;
@@ -13058,7 +13211,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         GGML_TYPE_F32,
                         GGML_TYPE_F32,
                         cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max));
+                        std::max((uint32_t) 1, cparams.n_seq_max),
+                        cparams.n_seq_max);
             } break;
         default:
             {
@@ -13068,14 +13222,36 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
 
                 LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
-                res = new llama_kv_cache_unified(
-                        *this,
-                        params.type_k,
-                        params.type_v,
-                        !cparams.flash_attn,
-                        cparams.offload_kqv,
-                        cparams.n_ctx,
-                        padding);
+                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                    GGML_ASSERT(hparams.is_swa_any());
+
+                    res = new llama_kv_cache_unified_iswa(
+                            *this,
+                            params.type_k,
+                            params.type_v,
+                            !cparams.flash_attn,
+                            cparams.offload_kqv,
+                            params.swa_full,
+                            cparams.n_ctx,
+                            cparams.n_seq_max,
+                            cparams.n_batch,
+                            padding);
+                } else {
+                    GGML_ASSERT(!hparams.is_swa_any());
+
+                    res = new llama_kv_cache_unified(
+                            *this,
+                            nullptr,
+                            params.type_k,
+                            params.type_v,
+                            !cparams.flash_attn,
+                            cparams.offload_kqv,
+                            cparams.n_ctx,
+                            cparams.n_seq_max,
+                            padding,
+                            hparams.n_swa,
+                            hparams.swa_type);
+                }
             }
     }
 
@@ -13090,11 +13266,14 @@ llm_graph_result_ptr llama_model::build_graph(
 
     switch (arch) {
         case LLM_ARCH_LLAMA:
-        case LLM_ARCH_LLAMA4:
         case LLM_ARCH_MINICPM:
             {
                 llm = std::make_unique<llm_build_llama>(*this, params, gf);
             } break;
+        case LLM_ARCH_LLAMA4:
+            {
+                llm = std::make_unique<llm_build_llama_iswa>(*this, params, gf);
+            } break;
         case LLM_ARCH_DECI:
             {
                 llm = std::make_unique<llm_build_deci>(*this, params, gf);
@@ -13169,7 +13348,11 @@ llm_graph_result_ptr llama_model::build_graph(
         case LLM_ARCH_PHI3:
         case LLM_ARCH_PHIMOE:
             {
-                llm = std::make_unique<llm_build_phi3>(*this, params, gf);
+                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                    llm = std::make_unique<llm_build_phi3<true>> (*this, params, gf);
+                } else {
+                    llm = std::make_unique<llm_build_phi3<false>>(*this, params, gf);
+                }
             } break;
         case LLM_ARCH_PLAMO:
             {
@@ -13201,11 +13384,11 @@ llm_graph_result_ptr llama_model::build_graph(
             } break;
         case LLM_ARCH_GEMMA2:
             {
-                llm = std::make_unique<llm_build_gemma2>(*this, params, gf);
+                llm = std::make_unique<llm_build_gemma2_iswa>(*this, params, gf);
             } break;
         case LLM_ARCH_GEMMA3:
             {
-                llm = std::make_unique<llm_build_gemma3>(*this, params, gf);
+                llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
             } break;
         case LLM_ARCH_STARCODER2:
             {
@@ -13225,7 +13408,7 @@ llm_graph_result_ptr llama_model::build_graph(
             } break;
         case LLM_ARCH_COHERE2:
             {
-                llm = std::make_unique<llm_build_cohere2>(*this, params, gf);
+                llm = std::make_unique<llm_build_cohere2_iswa>(*this, params, gf);
             } break;
         case LLM_ARCH_DBRX:
             {
index 6bdec263b709b2b027db73799aaa71b5f7326225..cbea2cb331b626f6ca2f829a186ec0822b20ce76 100644 (file)
@@ -398,7 +398,10 @@ struct llama_model {
 
     const struct ggml_tensor * get_tensor(const char * name) const;
 
-    ggml_tensor * get_rope_factors(uint32_t n_ctx_per_seq, int il) const;
+    float get_rope_freq_base (const llama_cparams & cparams, int il) const;
+    float get_rope_freq_scale(const llama_cparams & cparams, int il) const;
+
+    ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const;
 
     // note: can mutate `cparams`
     // TODO: move this to new llm_arch_model_i interface
index 804b11e0a943e9625c78516c5da629ec91261968..bfbf5fa23011240c0dec57b390670ef1ff47079b 100644 (file)
@@ -798,7 +798,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
         }
 
         // if we have enough values the operation was a success
-        if (filtered_tokens.size() >= ctx->min_keep) {
+        if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) {
             memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
             cur_p->size = filtered_tokens.size();
             min_p_applied = true;
@@ -909,7 +909,7 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token
         cum_sum += cur_p->data[idx].p;
 
         // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
-        if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
+        if (cum_sum > ctx->p && (ctx->min_keep == 0 || i >= ctx->min_keep - 1)) {
             last_idx = i + 1;
             break;
         }
index 9389ca805a584fadcb9110b0841f95234907b8b7..d5a036a8c4413cb91c6d242ec2f049ced6b3cc62 100644 (file)
@@ -835,7 +835,7 @@ struct llm_tokenizer_ugm_session {
         }
 
         // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
-        std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -FLT_MAX});
+        std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -DBL_MAX});
         // at the beginning tokenization score is zero
         tokenization_results[0] = { vocab.token_unk(), 0, 0 };
 
@@ -867,7 +867,7 @@ struct llm_tokenizer_ugm_session {
                     const double challenger_score = current_best.score_sum + token_score;
                     struct best_tokenization & current_champ = tokenization_results[prefix_offset];
                     if (challenger_score > current_champ.score_sum) {
-                        struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score };
+                        struct best_tokenization challenger = { token_id, input_offset, challenger_score };
                         current_champ = challenger;
                     }
                 }
@@ -881,7 +881,7 @@ struct llm_tokenizer_ugm_session {
                 prefix_offset = input_offset + n_utf8_code_units;
                 struct best_tokenization & current_champ = tokenization_results[prefix_offset];
                 if (challenger_score > current_champ.score_sum) {
-                    struct best_tokenization challenger = { vocab.token_unk(), input_offset, (float) challenger_score };
+                    struct best_tokenization challenger = { vocab.token_unk(), input_offset, challenger_score };
                     current_champ = challenger;
                 }
             }
@@ -1007,7 +1007,7 @@ private:
     struct best_tokenization {
         llama_token token_id;
         size_t input_offset;
-        float score_sum;
+        double score_sum;
     };
 
     struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
index 99e5fba244fcc2902fa9e97a085392820ee3975e..01762bea2bf962c33d30c4b72da6c05ce4bbf26f 100644 (file)
@@ -361,10 +361,11 @@ extern "C" {
 
         // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
         bool embeddings;  // if true, extract embeddings (together with logits)
-        bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
-        bool no_perf;     // whether to measure performance timings
-        bool op_offload;  // whether to offload host tensor operations to device
+        bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
+        bool flash_attn;  // use flash attention [EXPERIMENTAL]
+        bool no_perf;     // measure performance timings
+        bool op_offload;  // offload host tensor operations to device
+        bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
     };
 
     // model quantization parameters
@@ -470,6 +471,7 @@ extern "C" {
     LLAMA_API int64_t llama_time_us(void);
 
     LLAMA_API size_t llama_max_devices(void);
+    LLAMA_API size_t llama_max_parallel_sequences(void);
 
     LLAMA_API bool llama_supports_mmap       (void);
     LLAMA_API bool llama_supports_mlock      (void);
@@ -607,71 +609,14 @@ extern "C" {
     // KV cache
     //
 
-    // TODO: start using struct llama_kv_cache
-
-    // Information associated with an individual cell in the KV cache view.
-    struct llama_kv_cache_view_cell {
-        // The position for this cell. Takes KV cache shifts into account.
-        // May be negative if the cell is not populated.
-        llama_pos pos;
-    };
-
-    // An updateable view of the KV cache.
-    struct llama_kv_cache_view {
-        // Number of KV cache cells. This will be the same as the context size.
-        int32_t n_cells;
-
-        // Maximum number of sequences that can exist in a cell. It's not an error
-        // if there are more sequences in a cell than this value, however they will
-        // not be visible in the view cells_sequences.
-        int32_t n_seq_max;
-
-        // Number of tokens in the cache. For example, if there are two populated
-        // cells, the first with 1 sequence id in it and the second with 2 sequence
-        // ids then you'll have 3 tokens.
-        int32_t token_count;
-
-        // Number of populated cache cells.
-        int32_t used_cells;
-
-        // Maximum contiguous empty slots in the cache.
-        int32_t max_contiguous;
-
-        // Index to the start of the max_contiguous slot range. Can be negative
-        // when cache is full.
-        int32_t max_contiguous_idx;
-
-        // Information for an individual cell.
-        struct llama_kv_cache_view_cell * cells;
-
-        // The sequences for each cell. There will be n_seq_max items per cell.
-        llama_seq_id * cells_sequences;
-    };
-
-    // Create an empty KV cache view. (use only for debugging purposes)
-    LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max);
-
-    // Free a KV cache view. (use only for debugging purposes)
-    LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
-
-    // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
-    // TODO: change signature to llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_context * ctx)
-    LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
-
-    ///
-
     // Returns the number of tokens in the KV cache (slow, use only for debug)
     // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
-    LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
-
-    DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
-            "use llama_kv_self_n_tokens instead");
+    DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx),
+               "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
 
     // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
-    LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
-
-    DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
-            "use llama_kv_self_used_cells instead");
+    DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx),
+               "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
 
     // Clear the KV cache - both cell info is erased and KV data is zeroed
     LLAMA_API void llama_kv_self_clear(
@@ -730,10 +675,18 @@ extern "C" {
                        llama_pos   p1,
                              int   d);
 
+    // Returns the smallest position present in the KV cache for the specified sequence
+    // This is typically non-zero only for SWA caches
+    // Return -1 if the sequence is empty
+    LLAMA_API llama_pos llama_kv_self_seq_pos_min(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id);
+
     // Returns the largest position present in the KV cache for the specified sequence
+    // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_max(
             struct llama_context * ctx,
-                     llama_seq_id   seq_id);
+                    llama_seq_id   seq_id);
 
     // Defragment the KV cache
     // This will be applied:
@@ -747,61 +700,6 @@ extern "C" {
     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
     LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
 
-    DEPRECATED(LLAMA_API void llama_kv_cache_clear(
-            struct llama_context * ctx),
-            "use llama_kv_self_clear instead");
-
-    DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(
-            struct llama_context * ctx,
-                    llama_seq_id   seq_id,
-                       llama_pos   p0,
-                       llama_pos   p1),
-            "use llama_kv_self_seq_rm instead");
-
-    DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(
-            struct llama_context * ctx,
-                    llama_seq_id   seq_id_src,
-                    llama_seq_id   seq_id_dst,
-                       llama_pos   p0,
-                       llama_pos   p1),
-            "use llama_kv_self_seq_cp instead");
-
-    DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(
-            struct llama_context * ctx,
-                    llama_seq_id   seq_id),
-            "use llama_kv_self_seq_keep instead");
-
-    DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(
-            struct llama_context * ctx,
-                    llama_seq_id   seq_id,
-                       llama_pos   p0,
-                       llama_pos   p1,
-                       llama_pos   delta),
-            "use llama_kv_self_seq_add instead");
-
-    DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(
-            struct llama_context * ctx,
-                    llama_seq_id   seq_id,
-                       llama_pos   p0,
-                       llama_pos   p1,
-                             int   d),
-            "use llama_kv_self_seq_div instead");
-
-    DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
-            struct llama_context * ctx,
-                    llama_seq_id   seq_id),
-            "use llama_kv_self_seq_pos_max instead");
-
-    DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx),
-            "use llama_kv_self_defrag instead");
-
-    DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx),
-            "use llama_kv_self_can_shift instead");
-
-    DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx),
-            "use llama_kv_self_update instead");
-
-
     //
     // State / sessions
     //
@@ -943,9 +841,12 @@ extern "C" {
     // Requires KV cache.
     // For encode-decoder contexts, processes the batch using the decoder.
     // Positive return values does not mean a fatal error, but rather a warning.
-    //   0 - success
-    //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
-    // < 0 - error. the KV cache state is restored to the state before this call
+    // Upon non-zero return values, the KV cache state is restored to the state before this call
+    //    0 - success
+    //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+    //    2 - aborted
+    //   -1 - invalid input batch
+    // < -1 - error
     LLAMA_API int32_t llama_decode(
             struct llama_context * ctx,
               struct llama_batch   batch);