]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kv-cache : fix split_equal handling in unified implementation (#14130)
authorGeorgi Gerganov <redacted>
Thu, 12 Jun 2025 07:02:15 +0000 (10:02 +0300)
committerGitHub <redacted>
Thu, 12 Jun 2025 07:02:15 +0000 (10:02 +0300)
ggml-ci

src/llama-context.cpp
src/llama-kv-cache-unified-iswa.cpp
src/llama-kv-cache-unified.cpp

index 525a00d8adb952b0f5b034b15e48b54e2e1576f8..8cea21d6989efd5a1b7aad0486ac03a35bcf5b57 100644 (file)
@@ -877,6 +877,8 @@ int llama_context::encode(llama_batch & inp_batch) {
         memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
 
         // remember the sequence ids used during the encoding - needed for cross attention later
+        // TODO: the seuqence indexing here is likely not correct in the general case
+        //       probably works only for split_simple
         cross.seq_ids_enc.resize(n_tokens);
         for (int32_t i = 0; i < n_tokens; i++) {
             cross.seq_ids_enc[i].clear();
index 28d18265476497e42b93087b8c336a360d6e0d13..caa58ea9aa3b0e6274c264fcca3cc309bc2383a1 100644 (file)
@@ -98,33 +98,66 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
     GGML_UNUSED(embd_pooled);
 
-    // TODO: if we fail with split_simple, we should attempt different splitting strategies
-    //       but to do that properly, we first have to refactor the batches to be more flexible
+    // first try simple split
+    do {
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+        std::vector<llama_ubatch> ubatches;
 
-    std::vector<llama_ubatch> ubatches;
+        while (sbatch.n_tokens > 0) {
+            auto ubatch = sbatch.split_simple(n_ubatch);
 
-    while (sbatch.n_tokens > 0) {
-        auto ubatch = sbatch.split_simple(n_ubatch);
+            ubatches.push_back(ubatch);
+        }
 
-        ubatches.push_back(ubatch);
-    }
+        auto heads_base = kv_base->prepare(ubatches);
+        if (heads_base.empty()) {
+            break;
+        }
 
-    auto heads_base = kv_base->prepare(ubatches);
-    if (heads_base.empty()) {
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        auto heads_swa = kv_swa->prepare(ubatches);
+        if (heads_swa.empty()) {
+            break;
+        }
 
-    auto heads_swa = kv_swa->prepare(ubatches);
-    if (heads_swa.empty()) {
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        assert(heads_base.size() == heads_swa.size());
+
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+    } while (false);
+
+    // if it fails, try equal split
+    do {
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+
+        std::vector<llama_ubatch> ubatches;
 
-    assert(heads_base.size() == heads_swa.size());
+        while (sbatch.n_tokens > 0) {
+            auto ubatch = sbatch.split_equal(n_ubatch);
+
+            ubatches.push_back(ubatch);
+        }
+
+        auto heads_base = kv_base->prepare(ubatches);
+        if (heads_base.empty()) {
+            break;
+        }
+
+        auto heads_swa = kv_swa->prepare(ubatches);
+        if (heads_swa.empty()) {
+            break;
+        }
+
+        assert(heads_base.size() == heads_swa.size());
+
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+    } while (false);
+
+    // TODO: if we fail again, we should attempt different splitting strategies
+    //       but to do that properly, we first have to refactor the batches to be more flexible
 
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(
-            this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
index 1a9f4e3159f94d83b03b18dd3393c875309aec97..ddeb138f38fb960533390ac0153e0362ad5a3a42 100644 (file)
@@ -314,20 +314,24 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
             bool logits_all) {
     GGML_UNUSED(embd_pooled);
 
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+    do {
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
-    std::vector<llama_ubatch> ubatches;
-    while (sbatch.n_tokens > 0) {
-        ubatches.push_back(sbatch.split_simple(n_ubatch));
-    }
+        std::vector<llama_ubatch> ubatches;
+        while (sbatch.n_tokens > 0) {
+            ubatches.push_back(sbatch.split_simple(n_ubatch));
+        }
 
-    auto heads = prepare(ubatches);
-    if (heads.empty()) {
-        return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        auto heads = prepare(ubatches);
+        if (heads.empty()) {
+            break;
+        }
 
-    return std::make_unique<llama_kv_cache_unified_state>(
-            this, std::move(sbatch), std::move(heads), std::move(ubatches));
+        return std::make_unique<llama_kv_cache_unified_state>(
+                this, std::move(sbatch), std::move(heads), std::move(ubatches));
+    } while (false);
+
+    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
 llama_memory_state_ptr llama_kv_cache_unified::init_full() {
@@ -521,7 +525,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
     }
 
     if (debug > 0) {
-        LLAMA_LOG_CONT("\n");
         LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
 
         if ((debug == 2 && n_swa > 0) || debug > 2) {
@@ -530,7 +533,13 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
                 if (cells.is_empty(i)) {
                     ss += '.';
                 } else {
-                    ss += std::to_string(cells.seq_get(i));
+                    assert(cells.seq_count(i) >= 1);
+
+                    if (cells.seq_count(i) == 1) {
+                        ss += std::to_string(cells.seq_get(i));
+                    } else {
+                        ss += 'M';
+                    }
                 }
                 if (i%256 == 255) {
                     ss += " *";
@@ -636,6 +645,12 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 }
 
 void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
+        LLAMA_LOG_DEBUG("%s:   n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s:   n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
+    }
+
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
@@ -643,22 +658,26 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
         seq_pos_max_rm[s] = -1;
     }
 
-    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        if (!cells.is_empty(head_cur + i)) {
-            assert(cells.seq_count(head_cur + i) == 1);
+    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+        for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
+            const uint32_t idx = s*ubatch.n_seq_tokens + j;
 
-            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
-            const llama_pos    pos    = cells.pos_get(head_cur + i);
+            if (!cells.is_empty(head_cur + idx)) {
+                assert(cells.seq_count(head_cur + idx) == 1);
 
-            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+                const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
+                const llama_pos    pos    = cells.pos_get(head_cur + idx);
 
-            cells.rm(head_cur + i);
-        }
+                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
+                cells.rm(head_cur + idx);
+            }
 
-        cells.pos_set(head_cur + i, ubatch.pos[i]);
+            cells.pos_set(head_cur + idx, ubatch.pos[idx]);
 
-        for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
-            cells.seq_add(head_cur + i, ubatch.seq_id[i][j]);
+            for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
+                cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
+            }
         }
     }
 
@@ -677,7 +696,6 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
             seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
         }
     }
-
     // move the head at the end of the slot
     head = head_cur + ubatch.n_tokens;
 }
@@ -774,14 +792,14 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 }
 
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
-    const int64_t n_tokens     = ubatch->n_tokens;
-    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-    const int64_t n_seqs       = ubatch->n_seqs;
+    const uint32_t n_tokens     = ubatch->n_tokens;
+    const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
+    const uint32_t n_seqs       = ubatch->n_seqs;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     float * data = (float *) dst->data;
 
-    const auto n_kv = dst->ne[0];
+    const int64_t n_kv = dst->ne[0];
 
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -795,12 +813,14 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     //      xxxxx-----
     //      xxxxx-----
     // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-    for (int h = 0; h < 1; ++h) {
-        for (int s = 0; s < n_seqs; ++s) {
+    for (uint32_t h = 0; h < 1; ++h) {
+        for (uint32_t s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
-            for (int j = 0; j < n_seq_tokens; ++j) {
-                const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
+            for (uint32_t j = 0; j < n_seq_tokens; ++j) {
+                const uint32_t idx = s*n_seq_tokens + j;
+
+                const llama_pos p1 = ubatch->pos[idx];
 
                 for (uint32_t i = 0; i < n_kv; ++i) {
                     float f = 0.0f;
@@ -830,16 +850,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
                         f = -INFINITY;
                     }
 
-                    data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                    data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
                 }
             }
         }
 
         // mask padded tokens
         if (data) {
-            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+            for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
+                for (uint32_t i = 0; i < n_kv; ++i) {
+                    data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
                 }
             }
         }
@@ -1490,9 +1510,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         seq_rm(dest_seq_id, -1, -1);
 
         llama_sbatch sbatch;
-        llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+        llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
 
-        batch.n_tokens = cell_count;
+        ubatch.n_tokens = cell_count;
+        ubatch.n_seq_tokens = cell_count;
+        ubatch.n_seqs = 1;
 
         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
@@ -1512,18 +1534,18 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
                 io.read_to(&seq_id, sizeof(seq_id));
             }
 
-            batch.pos[i]      = pos;
-            batch.n_seq_id[i] = n_seq_id;
-            batch.seq_id[i]   = &dest_seq_id;
+            ubatch.pos[i]      = pos;
+            ubatch.n_seq_id[i] = n_seq_id;
+            ubatch.seq_id[i]   = &dest_seq_id;
         }
 
-        const auto head_cur = find_slot(batch);
+        const auto head_cur = find_slot(ubatch);
         if (head_cur < 0) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
 
-        apply_ubatch(head_cur, batch);
+        apply_ubatch(head_cur, ubatch);
 
         // keep the head at the old position because we will read the KV data into it in state_read_data()
         head = head_cur;
@@ -1531,8 +1553,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
         GGML_ASSERT(head_cur + cell_count <= cells.size());
-        GGML_ASSERT(cells.pos_get(head_cur)                  == batch.pos[0]);
-        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]);
+        GGML_ASSERT(cells.pos_get(head_cur)                  == ubatch.pos[0]);
+        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
         GGML_ASSERT(cells.seq_has(head_cur,                  dest_seq_id));
         GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
     } else {