// Requires the context to have a memory.
// For encoder-decoder contexts, processes the batch using the decoder.
// Positive return values do not mean a fatal error, but rather a warning.
- // Upon non-zero return values, the memory state is restored to the state before this call
+ // Upon fatal errors or abort, the ubatches that managed to be processed will remain in the memory state of the context
+ // To handle this correctly, query the memory state using llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
+ // Upon other return values, the memory state is restored to the state before this call
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increasing the context size)
- // 2 - aborted
+ // 2 - aborted (processed ubatches will remain in the context's memory)
// -1 - invalid input batch
- // < -1 - error
+ // < -1 - fatal error (processed ubatches will remain in the context's memory)
LLAMA_API int32_t llama_decode(
struct llama_context * ctx,
struct llama_batch batch);
}
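
// --- usage sketch (not part of this patch) ---------------------------------
// A minimal example of how a caller might react to the return values described
// above. The recovery policy is an assumption for illustration; llama_decode(),
// llama_get_memory(), llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
// are the only API calls relied upon.

#include "llama.h"

static int32_t decode_checked(struct llama_context * ctx, struct llama_batch batch, llama_seq_id seq) {
    const int32_t ret = llama_decode(ctx, batch);

    if (ret == 1) {
        // non-fatal: no KV slot found - the memory state was restored, so the
        // caller can simply retry with a smaller batch or a larger context
        return ret;
    }

    if (ret == 2 || ret < -1) {
        // aborted or fatal error: some ubatches may already be in the memory,
        // so query what actually landed there before deciding how to resume
        llama_memory_t mem = llama_get_memory(ctx);

        const llama_pos p_min = llama_memory_seq_pos_min(mem, seq);
        const llama_pos p_max = llama_memory_seq_pos_max(mem, seq);

        // positions [p_min, p_max] are stored for this sequence - either resume
        // decoding from p_max + 1 or clear the sequence and start over
        (void) p_min;
        (void) p_max;
    }

    return ret;
}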
if (memory) {
+ bool ok = true;
+
if (batch.token) {
if (seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
- LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
- return false;
+ ok = false;
}
} else {
assert(batch.embd);
// for embeddings (typically used as vision input), we allow them to have repeating positions
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
if (seq_pos_min(s) != memory->seq_pos_max(s) && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
- LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
- return false;
+ ok = false;
}
}
+
+ if (!ok) {
+ LLAMA_LOG_ERROR(
+ "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+ " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+ " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+ " it is required that the sequence positions remain consecutive: Y = X + 1\n",
+ __func__, s, s, memory->seq_pos_max(s), s, seq_pos_min(s));
+
+ return false;
+ }
}
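
// --- caller-side sketch (illustration only) --------------------------------
// To satisfy the Y = X + 1 requirement checked above, start the new tokens
// right after the last position stored in the memory for the target sequence.
// The llama_batch fields and llama_memory_seq_pos_max() are from the public
// API; the helper itself and its capacity assumptions are hypothetical.

static void batch_append_consecutive(struct llama_context * ctx, struct llama_batch & batch,
                                     const llama_token * tokens, int32_t n_tokens, llama_seq_id seq) {
    llama_memory_t mem = llama_get_memory(ctx);

    // llama_memory_seq_pos_max() returns -1 for an empty sequence,
    // so the first token of a fresh sequence lands at position 0
    const llama_pos p0 = llama_memory_seq_pos_max(mem, seq) + 1;

    for (int32_t i = 0; i < n_tokens; ++i) {
        // assumes the batch was created with enough capacity via llama_batch_init()
        batch.token   [batch.n_tokens] = tokens[i];
        batch.pos     [batch.n_tokens] = p0 + i; // consecutive: X + 1, X + 2, ...
        batch.n_seq_id[batch.n_tokens] = 1;
        batch.seq_id  [batch.n_tokens][0] = seq;
        batch.logits  [batch.n_tokens] = (i == n_tokens - 1); // logits only for the last token
        batch.n_tokens++;
    }
}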
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
#include <cassert>
#include <vector>
#include <set>
+#include <map>
// meta information about KV cells that can be part of multiple sequences at the same time
// TODO: add unit tests
assert(seq_id >= 0);
seq[i].reset(seq_id);
- seq_pos[seq_id].erase(pos[i]);
+ seq_pos_dec(seq_id, pos[i]);
if (seq[i].none()) {
pos[i] = -1;
seq[i].reset();
seq[i].set(seq_id);
- seq_pos[seq_id].insert(pos[i]);
+ seq_pos_inc(seq_id, pos[i]);
return false;
}
assert(!seq[i].test(seq_id));
seq[i].set(seq_id);
- seq_pos[seq_id].insert(pos[i]);
+ seq_pos_inc(seq_id, pos[i]);
}
// return the sequence id of this cell
return -1;
}
- return *seq_pos[seq_id].begin();
+ assert(seq_pos[seq_id].begin()->second > 0);
+
+ return seq_pos[seq_id].begin()->first;
}
// the maximum position of sequence seq_id currently present in any of the cells
return -1;
}
- return *seq_pos[seq_id].rbegin();
+ assert(seq_pos[seq_id].rbegin()->second > 0);
+
+ return seq_pos[seq_id].rbegin()->first;
}
// note: call only if the cell is not empty
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
std::vector<seq_set_t> seq;
- // the set seq_pos[s] tells us which positions are currently present for sequence s
+ // the map entry seq_pos[s][p] tells us how many times the position p is currently present for sequence s
+ // if the position p is not present, there is no entry for it in seq_pos[s]
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
- std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
+ //
+ // note that we cannot use an std::set because in some cases a position can occur more than once for the same seq:
+ // - when performing cache reuse via (rm + add)
+ // - some vision models have input embeddings with repeating positions
+ //
+ std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
// helper functions for updating `seq_pos`, one cell at a time:
+ void seq_pos_dec(llama_seq_id s, llama_pos p) {
+ auto it = seq_pos[s].find(p);
+ assert(it != seq_pos[s].end());
+
+ if (--it->second == 0) {
+ seq_pos[s].erase(it);
+ }
+ }
+
+ void seq_pos_inc(llama_seq_id s, llama_pos p) {
+ seq_pos[s][p]++;
+ }
+
// remove cell i
void seq_pos_rm(uint32_t i) {
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq[i].test(s)) {
- seq_pos[s].erase(pos[i]);
+ seq_pos_dec(s, pos[i]);
}
}
}
void seq_pos_add(uint32_t i) {
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq[i].test(s)) {
- seq_pos[s].insert(pos[i]);
+ seq_pos_inc(s, pos[i]);
}
}
}
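
// --- standalone illustration (not part of the patch) -----------------------
// Why the counted map is needed: with a plain std::set, erasing one of two
// cells that share a position would also drop the position that is still in
// use, corrupting the min/max queries above. This mirrors the seq_pos[s]
// bookkeeping with a bare std::map<int, int>.

#include <cassert>
#include <map>

int main() {
    std::map<int, int> pos_count; // mirrors seq_pos[s]: position -> number of cells

    // two cells hold position 5 for the same sequence (e.g. repeating vision embeddings)
    pos_count[5]++;
    pos_count[5]++;

    // removing one cell must not remove the position itself
    if (--pos_count[5] == 0) {
        pos_count.erase(5);
    }

    assert(pos_count.begin()->first == 5);  // min position is still 5
    assert(pos_count.rbegin()->first == 5); // max position is still 5

    // removing the second cell finally drops the position
    if (--pos_count[5] == 0) {
        pos_count.erase(5);
    }
    assert(pos_count.empty());

    return 0;
}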