const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
+ slot_info sinfo;
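+ // cell indices filled in by state_read_meta() and consumed by state_read_data()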
+
bool res = true;
- res = res && state_read_meta(io, strm, cell_count, seq_id);
- res = res && state_read_data(io, strm, cell_count);
+ res = res && state_read_meta(io, strm, cell_count, sinfo, seq_id);
+ res = res && state_read_data(io, strm, cell_count, sinfo);
if (!res) {
if (seq_id == -1) {
}
}
-bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
+bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id) {
auto & cells = v_cells[strm];
auto & head = v_heads[strm];
ubatch.seq_id[i] = &dest_seq_id;
}
- const auto sinfo = find_slot(ubatch, true);
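+ // allow non-contiguous cells - a fragmented cache may not contain a contiguous block of cell_count free cells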
+ sinfo = find_slot(ubatch, false);
if (sinfo.empty()) {
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
return false;
// see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
apply_ubatch(sinfo, ubatch);
- const auto head_cur = sinfo.head();
-
- // keep the head at the old position because we will read the KV data into it in state_read_data()
- head = head_cur;
-
- LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
+ LLAMA_LOG_DEBUG("%s: cell_count = %d, dest_seq_id = %d\n", __func__, cell_count, dest_seq_id);
- // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
- // Assume that this is one contiguous block of cells
- GGML_ASSERT(head_cur + cell_count <= cells.size());
- GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
- GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
- GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
- GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
+ // DEBUG CHECK: verify that all cells were allocated and have correct seq_id and pos values
+ GGML_ASSERT(sinfo.n_stream() == 1);
+ GGML_ASSERT(sinfo.idxs[0].size() == cell_count);
+ for (uint32_t i = 0; i < cell_count; ++i) {
+ const uint32_t idx = sinfo.idxs[0][i];
+ GGML_ASSERT(cells.pos_get(idx) == ubatch.pos[i]);
+ GGML_ASSERT(cells.seq_has(idx, dest_seq_id));
+ }
} else {
// whole KV cache restore
}
}
+ // Create contiguous slot_info for whole cache restore
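+ // (identity mapping over cells [0, cell_count)) so state_read_data() can use the same sinfo-based write path for both restore modes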
+ sinfo.s0 = strm;
+ sinfo.s1 = strm;
+ sinfo.resize(1);
+ sinfo.strm[0] = strm;
+ sinfo.idxs[0].resize(cell_count);
+ for (uint32_t i = 0; i < cell_count; ++i) {
+ sinfo.idxs[0][i] = i;
+ }
+
head = 0;
}
return true;
}
-bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
+bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo) {
auto & cells = v_cells[strm];
- auto & head = v_heads[strm];
uint32_t v_trans;
uint32_t n_layer;
}
if (cell_count) {
- // Read and set the keys for the whole cell range
- ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
+ if (sinfo.is_contiguous()) {
+ // Fast path: contiguous cells - copy the whole range in one call
+ ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row);
+ } else {
+ // Slow path: scatter to non-contiguous positions
+ const void * src = io.read(cell_count * k_size_row);
+ for (uint32_t i = 0; i < cell_count; ++i) {
+ const size_t dst_offset = sinfo.idxs[0][i] * k_size_row;
+ ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row);
+ }
+ }
}
}
}
if (cell_count) {
- // Read and set the values for the whole cell range
- ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
+ if (sinfo.is_contiguous()) {
+ // Fast path: contiguous cells - copy the whole range in one call
+ ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row);
+ } else {
+ // Slow path: scatter to non-contiguous positions
+ const void * src = io.read(cell_count * v_size_row);
+ for (uint32_t i = 0; i < cell_count; ++i) {
+ const size_t dst_offset = sinfo.idxs[0][i] * v_size_row;
+ ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row);
+ }
+ }
}
}
} else {
}
if (cell_count) {
- // For each row in the transposed matrix, read the values for the whole cell range
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
- const size_t dst_offset = (head + j * cells.size()) * v_size_el;
- ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+ if (sinfo.is_contiguous()) {
+ // Fast path: contiguous cells
+ const uint32_t h = sinfo.head();
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+ const size_t dst_offset = (h + j * cells.size()) * v_size_el;
+ ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+ }
+ } else {
+ // Slow path: scatter to non-contiguous positions
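+ // in the transposed layout each row j holds one element per cell, so every element is written to its own cell index individually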
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+ const void * src = io.read(cell_count * v_size_el);
+ for (uint32_t i = 0; i < cell_count; ++i) {
+ const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el;
+ ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el);
+ }
+ }
}
}
}
void clear() {
idxs.clear();
}
+
+ // check if indices are contiguous starting from head()
+ bool is_contiguous() const {
+ if (idxs.empty() || idxs[0].empty()) {
+ return true;
+ }
+ if (idxs.size() > 1) {
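+ // multiple streams are treated as non-contiguous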
+ return false;
+ }
+ const uint32_t h = idxs[0][0];
+ for (size_t i = 0; i < idxs[0].size(); ++i) {
+ if (idxs[0][i] != h + i) {
+ return false;
+ }
+ }
+ return true;
+ }
};
using slot_info_vec_t = std::vector<slot_info>;
void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
- bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
- bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
+ bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id = -1);
+ bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo);
};
class llama_kv_cache_context : public llama_memory_context_i {
llama_build_and_test(test-model-load-cancel.cpp LABEL "model")
llama_build_and_test(test-autorelease.cpp LABEL "model")
+# Test for state restore with fragmented KV cache
+# Requires a model; uses the same model args as test-thread-safety
+if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
+ llama_build_and_test(test-state-restore-fragmented.cpp LABEL "model" ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf)
+else()
+ llama_build_and_test(test-state-restore-fragmented.cpp LABEL "model" ARGS -hf ggml-org/models -hff tinyllamas/stories15M-be.Q4_0.gguf)
+endif()
+
if (NOT GGML_BACKEND_DL)
# these tests use the backends directly and cannot be built with dynamic loading
llama_build_and_test(test-barrier.cpp)
--- /dev/null
+// Test for state restore with fragmented KV cache
+// This tests the fix for: https://github.com/ggml-org/llama.cpp/issues/17527
+// The issue was that state restore required contiguous KV cache slots,
+// which fails when the cache is fragmented.
+//
+// The fix changes find_slot(ubatch, true) to find_slot(ubatch, false)
+// in state_read_meta(), allowing non-contiguous slot allocation.
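+//
+// Test flow: decode a prompt interleaved across sequences 0, 1 and 2, save the
+// state of seq 1, remove seq 1 to punch holes into the cache, then restore the
+// saved state and decode one more token on seq 1.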
+
+#include "arg.h"
+#include "common.h"
+#include "llama.h"
+
+#include <vector>
+#include <cstdio>
+#include <cstring>
+
+int main(int argc, char ** argv) {
+ common_params params;
+
+ params.sampling.seed = 1234;
+ params.kv_unified = true;
+ params.n_parallel = 3;
+ params.n_ctx = 256;
+
+ if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
+ return 1;
+ }
+
+ common_init();
+
+ // init
+ common_init_result_ptr llama_init = common_init_from_params(params);
+
+ llama_model * model = llama_init->model();
+ llama_context * ctx = llama_init->context();
+
+ if (model == nullptr || ctx == nullptr) {
+ fprintf(stderr, "%s : failed to init\n", __func__);
+ return 1;
+ }
+
+ GGML_UNUSED(model);
+
+ // tokenize prompt
+ std::vector<llama_token> tokens(70, 1);
+
+ // interleave the 3 sequences:
+ // 012012012...
+ llama_batch batch = llama_batch_init(params.n_parallel*tokens.size(), 0, 1);
+ for (size_t i = 0; i < tokens.size(); i++) {
+ for (int s = 0; s < params.n_parallel; ++s) {
+ common_batch_add(batch, tokens[i], i, {s}, false);
+ }
+ }
+ batch.logits[batch.n_tokens - 1] = true;
+
+ if (llama_decode(ctx, batch)) {
+ fprintf(stderr, "%s : failed to decode seq 0\n", __func__);
+ return 1;
+ }
+
+ fprintf(stderr, "%s : processed prompt on seq 0, 1, 2 (%zu tokens each)\n", __func__, tokens.size());
+
+ // Save state of seq 1
+ std::vector<uint8_t> seq_state(llama_state_seq_get_size(ctx, 1));
+ const size_t ncopy = llama_state_seq_get_data(ctx, seq_state.data(), seq_state.size(), 1);
+ if (ncopy != seq_state.size()) {
+ fprintf(stderr, "%s : failed to save seq 1 state\n", __func__);
+ return 1;
+ }
+ fprintf(stderr, "%s : saved seq 1 state, %zu bytes\n", __func__, ncopy);
+
+ // clear seq 1 to create a "hole" in the KV cache (fragmentation)
+ // 0.20.20.20.2....
+ llama_memory_t mem = llama_get_memory(ctx);
+ llama_memory_seq_rm(mem, 1, -1, -1);
+ fprintf(stderr, "%s : cleared seq 1 to create fragmentation\n", __func__);
+
+ // Now the cache has holes where seq 1 was. The 70 freed cells are interleaved
+ // with cells still used by seqs 0 and 2, and only the short tail of the
+ // 256-cell cache is free, so there is no contiguous block of 70 cells
+ // available for the restore.
+
+ // Restore seq 1 state into seq 1 (should work with non-contiguous allocation)
+ // We use seq 1 since it's a valid sequence ID (0 to n_parallel-1)
+ // Before the fix, this would fail with "failed to find available cells in kv cache"
+ const size_t nset = llama_state_seq_set_data(ctx, seq_state.data(), seq_state.size(), 1);
+ if (nset != seq_state.size()) {
+ fprintf(stderr, "%s : FAILED to restore seq state into fragmented cache (got %zu, expected %zu)\n",
+ __func__, nset, seq_state.size());
+ fprintf(stderr, "%s : This is the bug - state restore fails with fragmented KV cache\n", __func__);
+ llama_batch_free(batch);
+ return 1;
+ }
+ fprintf(stderr, "%s : restored state into seq 1, %zu bytes\n", __func__, nset);
+
+ // Verify we can decode with the restored state
+ // Generate one token to verify the restored state is usable
+ auto sparams = llama_sampler_chain_default_params();
+ llama_sampler * smpl = llama_sampler_chain_init(sparams);
+ llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sampling.seed));
+
+ auto next_token = llama_sampler_sample(smpl, ctx, -1);
+ auto next_token_str = common_token_to_piece(ctx, next_token);
+
+ common_batch_clear(batch);
+ common_batch_add(batch, next_token, (int)tokens.size(), {1}, true);
+
+ if (llama_decode(ctx, batch)) {
+ fprintf(stderr, "%s : failed to decode with restored state\n", __func__);
+ llama_sampler_free(smpl);
+ llama_batch_free(batch);
+ return 1;
+ }
+
+ fprintf(stderr, "%s : successfully decoded with restored state, generated: '%s'\n", __func__, next_token_str.c_str());
+ fprintf(stderr, "%s : SUCCESS - state restore works with fragmented KV cache\n", __func__);
+
+ llama_sampler_free(smpl);
+ llama_batch_free(batch);
+
+ return 0;
+}