]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kv-cache : better estimate of n_kv for multi-sequence batches (#15610)
authorGeorgi Gerganov <redacted>
Wed, 27 Aug 2025 10:55:12 +0000 (13:55 +0300)
committerGitHub <redacted>
Wed, 27 Aug 2025 10:55:12 +0000 (13:55 +0300)
ggml-ci

src/llama-kv-cache.cpp
src/llama-kv-cache.h

index d7ab56ccd9aacf9f7d356067ae1b0cd51bf5c878..920c1d0dbdc745b99ef4d16301c0a63a57e77d28 100644 (file)
@@ -771,8 +771,8 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
             GGML_ASSERT(ubatch.seq_id  [s*n_tokens][0] == seq_id);
         }
 
-        res.s0 = std::min<llama_seq_id>(res.s0, seq_to_stream[seq_id]);
-        res.s1 = std::max<llama_seq_id>(res.s1, seq_to_stream[seq_id]);
+        res.s0 = std::min<uint32_t>(res.s0, seq_to_stream[seq_id]);
+        res.s1 = std::max<uint32_t>(res.s1, seq_to_stream[seq_id]);
 
         res.strm[s] = seq_to_stream[seq_id];
         res.idxs[s].reserve(n_tokens);
@@ -964,11 +964,11 @@ bool llama_kv_cache::get_has_shift() const {
     return result;
 }
 
-uint32_t llama_kv_cache::get_n_kv() const {
+uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
     uint32_t result = 0;
 
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        const auto & cells = v_cells[s];
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        const auto & cells = v_cells[sinfo.strm[s]];
 
         result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
     }
@@ -1017,18 +1017,18 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
         // note: v->nb[1] <= v->nb[2]
         return ggml_view_4d(ctx, v,
                 hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
-                ggml_row_size(v->type, hparams.n_embd_head_v),            // v->nb[1]
-                ggml_row_size(v->type, n_embd_v_gqa),         // v->nb[2]
-                ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
+                ggml_row_size(v->type, hparams.n_embd_head_v),          // v->nb[1]
+                ggml_row_size(v->type, n_embd_v_gqa),                   // v->nb[2]
+                ggml_row_size(v->type, n_embd_v_gqa*kv_size),           // v->nb[3]
                 ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
     }
 
     // note: v->nb[1] > v->nb[2]
     return ggml_view_4d(ctx, v,
             n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
-            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),    // v->nb[1]
-            ggml_row_size(v->type, kv_size),                          // v->nb[2]
-            ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),  // v->nb[1]
+            ggml_row_size(v->type, kv_size),                        // v->nb[2]
+            ggml_row_size(v->type, kv_size*n_embd_v_gqa),           // v->nb[3]
             ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
 }
 
@@ -1985,8 +1985,7 @@ bool llama_kv_cache_context::apply() {
     }
 
     kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
-
-    n_kv = kv->get_n_kv();
+    n_kv = kv->get_n_kv(sinfos[i_cur]);
 
     return true;
 }
index 76a5cb1e28e7e642b31ab7376fd7bc5305221424..3ca82917d3237d25c7a138263fb2197ffc0f9e76 100644 (file)
@@ -38,8 +38,8 @@ public:
         using idx_vec_t = std::vector<uint32_t>;
 
         // number of streams: ns = s1 - s0 + 1
-        llama_seq_id s0;
-        llama_seq_id s1;
+        uint32_t s0;
+        uint32_t s1;
 
         std::vector<llama_seq_id> strm; // [ns]
         std::vector<idx_vec_t>    idxs; // [ns]
@@ -139,7 +139,7 @@ public:
     // graph_build API
     //
 
-    uint32_t get_n_kv() const;
+    uint32_t get_n_kv(const slot_info & sinfo) const;
 
     // TODO: temporary
     bool get_supports_set_rows() const;