const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
- seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
- seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+ seq_pos.resize(LLAMA_MAX_SEQ);
+ seq_cpl.resize(LLAMA_MAX_SEQ);
for (auto & cur : seq_cpl) {
- cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+ cur.resize(LLAMA_MAX_SEQ);
}
}
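// usage note (not part of the diff): the debug flag above is driven by an
// environment variable, so extra batch logging can be enabled per run, e.g.
// assuming a standard llama.cpp build:
//
//   LLAMA_BATCH_DEBUG=1 ./llama-cli -m model.gguf -p "hello"
//
// any non-zero integer turns the logging on; higher values may print more
// detail depending on how `debug` is consumed later in this file.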
if (batch.seq_id) {
for (int32_t i = 0; i < batch.n_tokens; ++i) {
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
- if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
- LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
+ if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
+ LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d >= %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
return false;
}
}
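// example (illustrative sketch, not part of the diff): a batch passes the
// check above when every token carries seq ids in [0, LLAMA_MAX_SEQ), e.g.
// using the public llama_batch API:
//
//   llama_batch batch = llama_batch_init(/*n_tokens*/ 2, /*embd*/ 0, /*n_seq_max*/ 1);
//   batch.n_tokens = 2;
//   batch.n_seq_id[0] = 1; batch.seq_id[0][0] = 0; // token 0 -> sequence 0
//   batch.n_seq_id[1] = 1; batch.seq_id[1][0] = 0; // token 1 -> sequence 0
//   // (token/pos fields omitted for brevity)
//
// a negative id, or one >= LLAMA_MAX_SEQ, trips the error path above.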
pos.resize(batch.n_tokens);
// initialize the starting position for each sequence based on the positions in the memory
- llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
- for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+ llama_pos p0[LLAMA_MAX_SEQ];
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (!memory) {
p0[s] = 0;
} else {
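// note: the else-branch body is elided by this hunk; based on the comment
// above, a plausible sketch (assumption, not taken from the diff) continues
// each sequence from the last position already stored in memory:
//
//   p0[s] = memory->seq_pos_max(s) + 1;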
// consistency checks
//
- for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq_pos[s].empty()) {
continue;
}
}
if (memory) {
- for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
- for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
+ for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
+ for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
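// coupled sequences (seq_cpl[s0][s1]) are expected to cover exactly the same
// position range in memory; the check below verifies min/max agreement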
if (seq_cpl[s0][s1]) {
if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
const auto & hparams = model.hparams;
cparams.n_seq_max = std::max(1u, params.n_seq_max);
- if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
- throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
+ if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
+ throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
}
cparams.n_threads = params.n_threads;
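// usage sketch (not part of the diff): the number of parallel sequences is
// chosen at context creation time and is now capped by LLAMA_MAX_SEQ (64);
// assuming the current llama.cpp context API:
//
//   llama_context_params cparams = llama_context_default_params();
//   cparams.n_seq_max = 8; // must be <= LLAMA_MAX_SEQ, else runtime_error
//   llama_context * ctx = llama_init_from_model(model, cparams);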
if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
- llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+ llama_pos pos_min[LLAMA_MAX_SEQ];
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
pos_min[s] = std::numeric_limits<llama_pos>::max();
}
pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
}
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
continue;
}
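// sketch of the follow-up step (elided here; assumption based on the comment
// at the top of this hunk): every sequence whose pos_min was recorded gets
// its tail purged from the KV cache, along the lines of:
//
//   memory->seq_rm(s, pos_min[s], -1); // drop positions >= pos_min[s]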
#include "llama-cparams.h"
size_t llama_max_parallel_sequences(void) {
- return LLAMA_MAX_PARALLEL_SEQUENCES;
+ return LLAMA_MAX_SEQ;
}
#include <cstdint>
-// TODO: rename to something shorter
-#define LLAMA_MAX_PARALLEL_SEQUENCES 64
+#define LLAMA_MAX_SEQ 64
struct llama_cparams {
uint32_t n_ctx; // context size used during inference
LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
}
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (cells.seq_pos_min(s) < 0) {
continue;
}
// keep track of the max sequence position that we would overwrite with this ubatch
// for a non-SWA cache, this will always be empty
- llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+ llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
seq_pos_max_rm[s] = -1;
}
// note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
// will be present in the cache, so we have to purge any position which is less than those we would overwrite
// ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq_pos_max_rm[s] == -1) {
continue;
}
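// illustrative sketch (assumption, not part of the diff): the purge described
// in the note above removes everything up to the overwritten position, so
// [pos_min, pos_max] stays contiguous for each sequence:
//
//   if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
//       seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
//   }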
used.clear();
- for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+ for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
seq_pos[s].clear();
}
}
llama_seq_id seq_get(uint32_t i) const {
assert(seq[i].count() == 1);
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq[i].test(s)) {
return s;
}
// return -1 if the sequence is not present
llama_pos seq_pos_min(llama_seq_id seq_id) const {
assert(seq_id >= 0);
- assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
+ assert(seq_id < LLAMA_MAX_SEQ);
if (seq_pos[seq_id].empty()) {
return -1;
// return -1 if the sequence is not present
llama_pos seq_pos_max(llama_seq_id seq_id) const {
assert(seq_id >= 0);
- assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
+ assert(seq_id < LLAMA_MAX_SEQ);
if (seq_pos[seq_id].empty()) {
return -1;
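// usage sketch (not part of the diff): assuming `cells` is an instance of
// this class, callers can probe a sequence's cached range directly:
//
//   const llama_pos p_min = cells.seq_pos_min(0); // -1 if sequence 0 is absent
//   const llama_pos p_max = cells.seq_pos_max(0); // -1 if sequence 0 is absent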
//
std::vector<llama_pos> shift;
- using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
+ using bits_t = std::bitset<LLAMA_MAX_SEQ>;
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
std::vector<bits_t> seq;
// the set seq_pos[s] tells us which positions are currently present for sequence s
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
- std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
+ std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
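// example (sketch): because std::set keeps positions ordered, the current
// min/max for sequence s fall out of the iterators directly:
//
//   llama_pos p_min = seq_pos[s].empty() ? -1 : *seq_pos[s].begin();
//   llama_pos p_max = seq_pos[s].empty() ? -1 : *seq_pos[s].rbegin();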
// helper functions for updating `seq_pos`, one cell at a time:
// remove cell i
void seq_pos_rm(uint32_t i) {
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq[i].test(s)) {
seq_pos[s].erase(pos[i]);
}
// add cell i
void seq_pos_add(uint32_t i) {
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq[i].test(s)) {
seq_pos[s].insert(pos[i]);
}
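// usage sketch (assumption, not part of the diff): a position shift keeps
// seq_pos consistent by bracketing the update with the two helpers above:
//
//   seq_pos_rm(i);   // drop the old position from every owning sequence
//   pos[i] += d;     // apply the shift to the cell itself
//   seq_pos_add(i);  // re-insert the new position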