struct ggml_tensor * ffn_down_scale;
};
+// very similar to llama_batch,
+// but has more metadata about sequences
+struct llama_ubatch {
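+ // true when all sequences in the ubatch have the same number of new tokens
+ // (split_equal/split_seq); false for simple splits, where the pointers
+ // below may alias the original llama_batch arrays directly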
+ bool equal_seqs;
+ // TODO: whole_seqs for embeddings?
+
+ uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
+ uint32_t n_seq_tokens; // tokens per sequence
+ uint32_t n_seqs;
+
+ llama_token * token; // [n_tokens]
+ float * embd; // [n_embd, n_tokens]
+ llama_pos * pos; // [n_tokens]
+ int32_t * n_seq_id; // [n_seqs]
+ llama_seq_id ** seq_id; // [n_seqs]
+ int8_t * output; // [n_tokens]
+};
+
struct llama_kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
- int32_t src = 0; // used by recurrent state models to copy states
+ int32_t src = -1; // used by recurrent state models to copy states
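+ // with recurrent models, cells[seq_id].tail is the index of the cell
+ // currently holding the last state of seq_id (-1 if none)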
+ int32_t tail = -1;
std::set<llama_seq_id> seq_id;
struct llama_kv_cache {
bool has_shift = false;
bool do_defrag = false;
- bool do_copy = false;
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
bool v_trans = true; // the value tensor is transposed
}
};
+struct llama_sbatch_seq {
+ int32_t n_seq_id;
+ llama_seq_id * seq_id;
+ size_t offset;
+ size_t length;
+
+ // helper for smoother batch API transition -- can be deprecated in the future
+ llama_seq_id all_seq_id; // used if seq_id == NULL
+};
+
+// sequence-length-aware batch splitting
+struct llama_sbatch {
+ // tokens left in this batch
+ size_t n_tokens;
+
+ size_t n_embd;
+
+ bool logits_all; // TODO: remove once lctx.logits_all is removed too
+
+ // sorted indices into the batch
+ std::vector<size_t> ids;
+ // batch indices of the output
+ std::vector<size_t> out_ids;
+ std::vector<llama_sbatch_seq> seq;
+ const llama_batch * batch = nullptr;
+
+ // buffers for the ubatch
+ std::vector<llama_token> ubatch_token;
+ std::vector<float> ubatch_embd;
+ std::vector<llama_pos> ubatch_pos;
+ std::vector<int32_t> ubatch_n_seq_id;
+ std::vector<llama_seq_id *> ubatch_seq_id;
+ std::vector<int8_t> ubatch_output;
+
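+ // prepare an empty ubatch (n_tokens == 0) backed by this sbatch's buffers,
+ // with room for up to n_ubatch tokens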
+ llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false) {
+ // clear empty sequences
+ // the previous ubatch is assumed to be gone,
+ // so nothing should refer to values in these sequences anymore.
+ for (size_t i = seq.size(); i-- > 0;) {
+ if (seq[i].length == 0) {
+ seq.pop_back();
+ } else {
+ break;
+ }
+ }
+ ubatch_token.resize(!has_embd ? n_ubatch : 0);
+ ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
+ ubatch_pos.resize(n_ubatch);
+ ubatch_n_seq_id.resize(n_ubatch);
+ ubatch_seq_id.resize(n_ubatch);
+ ubatch_output.resize(n_ubatch);
+ llama_ubatch ubatch = {
+ /*equal_seqs =*/ true,
+ /*n_tokens =*/ 0,
+ /*n_seq_tokens =*/ 0,
+ /*n_seqs =*/ 0,
+ /*token =*/ !has_embd ? ubatch_token.data() : nullptr,
+ /*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
+ /*pos =*/ ubatch_pos.data(),
+ /*n_seq_id =*/ ubatch_n_seq_id.data(),
+ /*seq_id =*/ ubatch_seq_id.data(),
+ /*output =*/ ubatch_output.data(),
+ };
+ return ubatch;
+ }
+
+ void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
+ GGML_ASSERT(batch != nullptr);
+ GGML_ASSERT(length <= seq.length);
+ // Can only add sequences of equal lengths to a batch,
+ // otherwise it isn't clear to which sequence a token belongs
+ GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
+ GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
+ // NOTE: loops are separated for cache-friendliness
+ if (batch->token) {
+ if (ubatch.equal_seqs) {
+ for (size_t i = 0; i < length; ++i) {
+ ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
+ }
+ } else {
+ // simple split
+ ubatch.token = batch->token + seq.offset;
+ }
+ } else {
+ ubatch.token = nullptr;
+ }
+ if (batch->embd) {
+ if (ubatch.equal_seqs) {
+ for (size_t i = 0; i < length; ++i) {
+ memcpy(
+ ubatch.embd + n_embd * (ubatch.n_tokens + i),
+ batch->embd + n_embd * ids[seq.offset + i],
+ n_embd * sizeof(float)
+ );
+ }
+ } else {
+ // simple split
+ ubatch.embd = batch->embd + (n_embd * seq.offset);
+ }
+ } else {
+ ubatch.embd = nullptr;
+ }
+ // from here on, the else branches are deprecated;
+ // they are helpers for smoother batch API transition
+ if (batch->pos) {
+ if (ubatch.equal_seqs) {
+ for (size_t i = 0; i < length; ++i) {
+ ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
+ }
+ } else {
+ // simple split
+ ubatch.pos = batch->pos + seq.offset;
+ }
+ } else {
+ for (size_t i = 0; i < length; ++i) {
+ llama_pos bi = ids[seq.offset + i];
+ ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
+ }
+ }
+ if (ubatch.equal_seqs) {
+ ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
+ if (seq.seq_id) {
+ ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
+ } else {
+ GGML_ASSERT(seq.n_seq_id == 1);
+ ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
+ }
+ } else {
+ // simple split
+ if (batch->n_seq_id) {
+ ubatch.n_seq_id = batch->n_seq_id + seq.offset;
+ } else {
+ for (size_t i = 0; i < length; ++i) {
+ ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
+ }
+ }
+ if (batch->seq_id) {
+ ubatch.seq_id = batch->seq_id + seq.offset;
+ } else {
+ for (size_t i = 0; i < length; ++i) {
+ ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
+ }
+ }
+ }
+ if (logits_all) {
+ for (size_t i = 0; i < length; ++i) {
+ ubatch.output[ubatch.n_tokens + i] = 1;
+ out_ids.push_back(ids[seq.offset + i]);
+ }
+ } else if (batch->logits) {
+ if (ubatch.equal_seqs) {
+ for (size_t i = 0; i < length; ++i) {
+ size_t id = ids[seq.offset + i];
+ int8_t is_output = batch->logits[id];
+ ubatch.output[ubatch.n_tokens + i] = is_output;
+ if (is_output) { out_ids.push_back(id); }
+ }
+ } else {
+ // simple split
+ ubatch.output = batch->logits + seq.offset;
+ for (size_t i = 0; i < length; ++i) {
+ if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
+ }
+ }
+ } else {
+ // only get last output
+ for (size_t i = 0; i < length; ++i) {
+ size_t id = ids[seq.offset + i];
+ int8_t is_last = id == ids.size() - 1;
+ ubatch.output[ubatch.n_tokens + i] = is_last;
+ if (is_last) { out_ids.push_back(id); }
+ }
+ }
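+ // the first sequence added to an empty ubatch fixes n_seq_tokens;
+ // the assert at the top of this function keeps later additions consistent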
+ if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
+ ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
+ }
+ ubatch.n_tokens += length;
+ ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
+ seq.offset += length;
+ seq.length -= length;
+ n_tokens -= length;
+ GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
+ }
+
+ // simple split, unknown number of sequences of unequal lengths
+ llama_ubatch split_simple(size_t n_ubatch) {
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+ ubatch.equal_seqs = false;
+ if (!seq.empty()) {
+ llama_sbatch_seq & s = seq[0];
+ size_t length = s.length < n_ubatch ? s.length : n_ubatch;
+ GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
+ add_seq_to_ubatch(ubatch, s, length);
+ }
+ return ubatch;
+ }
+
+ // make batches of equal-length sequences
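+ // e.g. three seqs of lengths {7, 5, 5} with n_ubatch = 16 first yield a
+ // 3*5-token ubatch (n_seq_tokens = 5), leaving the last 2 tokens of the
+ // longest seq for a later ubatch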
+ llama_ubatch split_equal(size_t n_ubatch) {
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+ if (!seq.empty()) {
+ size_t length = 0;
+ size_t n_tokens_in_ubatch = 0;
+ GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
+ // smallest first, because it's easier to split this way;
+ // starting from the end to pop in constant time.
+ for (size_t i = seq.size(); i-- > 0;) {
+ llama_sbatch_seq & s = seq[i];
+ GGML_ASSERT(s.length > 0);
+ if (length == 0) {
+ length = s.length < n_ubatch ? s.length : n_ubatch;
+ }
+ add_seq_to_ubatch(ubatch, s, length);
+ n_tokens_in_ubatch += length;
+ // shared prompts can't be mixed with any of their sequences,
+ // so it's safer to compute them in their own ubatch
+ if (s.n_seq_id > 1) { break; }
+ // stop when there isn't enough space for another sequence
+ if (length + n_tokens_in_ubatch > n_ubatch) { break; }
+ }
+ }
+ return ubatch;
+ }
+
+ // sequence-wise split
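+ // (one sequence per ubatch; used for pooled embeddings, which must not be
+ // split across ubatches -- see the decode loop below)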
+ llama_ubatch split_seq(size_t n_ubatch) {
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+ if (!seq.empty()) {
+ llama_sbatch_seq & s = seq[seq.size() - 1];
+ size_t length = s.length < n_ubatch ? s.length : n_ubatch;
+ GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
+ add_seq_to_ubatch(ubatch, s, length);
+ }
+ return ubatch;
+ }
+
+ void from_batch(const llama_batch & batch, const size_t n_embd, const bool simple_split = false, const bool logits_all = false) {
+ GGML_ASSERT(batch.n_tokens >= 0);
+ this->batch = &batch;
+ this->n_embd = n_embd;
+ this->logits_all = logits_all;
+
+ n_tokens = batch.n_tokens;
+ ids.resize(n_tokens);
+ out_ids.clear();
+ // TODO: reserve out_ids and seq
+
+ for (size_t i = 0; i < n_tokens; ++i) {
+ ids[i] = i;
+ }
+ if (simple_split) {
+ seq.resize(1);
+ llama_sbatch_seq & s = seq[0];
+ s.n_seq_id = 0;
+ s.seq_id = nullptr;
+ s.offset = 0;
+ s.length = n_tokens;
+ s.all_seq_id = batch.all_seq_id;
+ return;
+ }
+ std::sort(ids.begin(), ids.end(),
+ [&batch](size_t a, size_t b) {
+ int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
+ int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
+ // sort by seq_id, then by pos
+ if (n_seq_a == n_seq_b) {
+ if (batch.seq_id) {
+ for (int32_t i = 0; i < n_seq_a; ++i) {
+ llama_seq_id seq_id_a = batch.seq_id[a][i];
+ llama_seq_id seq_id_b = batch.seq_id[b][i];
+ // smaller seq_ids go first
+ if (seq_id_a != seq_id_b) {
+ return seq_id_a < seq_id_b;
+ }
+ }
+ }
+ // when all else is equal, sort by pos
+ if (batch.pos) {
+ return batch.pos[a] < batch.pos[b];
+ }
+ // no pos, sort by id (assuming batch.all_pos_1 is positive)
+ return a < b;
+ }
+ // shared prompts go first
+ return n_seq_a > n_seq_b;
+ }
+ );
+ // init seq
+ llama_sbatch_seq * last_seq = nullptr;
+
+ if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
+ for (size_t i = 0; i < n_tokens; ++i) {
+ const size_t bi = ids[i];
+ const int32_t n_seqs = batch.n_seq_id[bi];
+ llama_seq_id * seq_ids = batch.seq_id[bi];
+ if (last_seq != nullptr) {
+ bool same = n_seqs == last_seq->n_seq_id;
+ for (int32_t j = 0; same && j < n_seqs; ++j) {
+ if (seq_ids[j] != last_seq->seq_id[j]) {
+ same = false;
+ }
+ }
+ if (same) {
+ last_seq->length += 1;
+ continue;
+ }
+ }
+ llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
+ seq.push_back(new_seq);
+ last_seq = &seq.back();
+ }
+ } else {
+ llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
+ seq.push_back(new_seq);
+ }
+ // shared prompts sort to the end of the vector (split_equal pops from the
+ // back, so they are processed first), then sort by length descending
+ std::sort(seq.begin(), seq.end(),
+ [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
+ if (a.n_seq_id == b.n_seq_id) {
+ return a.length > b.length;
+ }
+ return a.n_seq_id < b.n_seq_id;
+ }
+ );
+ }
+};
+
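+// A minimal sketch of how these splitters are driven (mirrors the decode
+// loop further below; kv_self and n_ubatch come from llama_context):
+//
+//   sbatch.from_batch(batch, n_embd, /* simple_split */ !kv_self.recurrent, logits_all);
+//   while (sbatch.n_tokens > 0) {
+//       llama_ubatch ubatch = kv_self.recurrent
+//           ? sbatch.split_equal(n_ubatch)   // or split_seq for pooled embeddings
+//           : sbatch.split_simple(n_ubatch);
+//       // ... find a KV slot, build and compute the graph on ubatch ...
+//   }
+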
struct llama_context {
llama_context(const llama_model & model)
: model(model)
struct llama_cparams cparams;
struct llama_sampling sampling;
+ struct llama_sbatch sbatch;
struct llama_kv_cache kv_self;
struct llama_control_vector cvec;
cache.has_shift = false;
- // TODO: find a nicer way to add other recurrent model architectures
- cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+ cache.recurrent = llama_model_is_recurrent(&model);
cache.v_trans = !cache.recurrent && !cparams.flash_attn;
cache.head = 0;
cache.cells.clear();
cache.cells.resize(kv_size);
- if (cache.recurrent) {
- // init state copy sources
- for (uint32_t i = 0; i < cache.size; ++i) {
- cache.cells[i].src = i;
- }
- }
-
// count used buffer types
std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
if (offload) {
// to the first cell of the slot.
static bool llama_kv_cache_find_slot(
struct llama_kv_cache & cache,
- const struct llama_batch & batch) {
+ const struct llama_ubatch & batch) {
const uint32_t n_tokens = batch.n_tokens;
+ const uint32_t n_seqs = batch.n_seqs;
+ const uint32_t n_seq_tokens = batch.n_seq_tokens;
if (cache.recurrent) {
// For recurrent state architectures (like Mamba),
- // each KV cache cell can store the state for a whole sequence.
-
- llama_seq_id min = cache.size - 1;
- llama_seq_id max = 0;
-
- for (uint32_t i = 0; i < n_tokens; ++i) {
- for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
- llama_seq_id seq_id = batch.seq_id[i][j];
- // make sure it's a valid seq_id
- if ((uint32_t) seq_id < cache.size) {
- if (seq_id > max) {
- max = seq_id;
- }
- if (seq_id < min) {
- min = seq_id;
+ // each cache cell can store the state for a whole sequence.
+ // A slot should always be contiguous.
+
+ // can only process batches with an equal number of new tokens in each sequence
+ GGML_ASSERT(batch.equal_seqs);
+
+ int32_t min = cache.size - 1;
+ int32_t max = 0;
+
+ // everything should fit if all seq_ids are smaller than the max
+ for (uint32_t s = 0; s < n_seqs; ++s) {
+ const uint32_t n_seq_id = batch.n_seq_id[s];
+ for (uint32_t j = 0; j < n_seq_id; ++j) {
+ const llama_seq_id seq_id = batch.seq_id[s][j];
+
+ if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
+ // too big seq_id
+ // TODO: would it be possible to resize the cache instead?
+ LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d. Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
+ return false;
+ }
+ if (j > 0) {
+ llama_kv_cell & seq = cache.cells[seq_id];
+ if (seq.tail >= 0) {
+ llama_kv_cell & cell = cache.cells[seq.tail];
+ // clear cells from seq_ids that become shared
+ // (should not normally happen, but let's handle it anyway)
+ cell.seq_id.erase(seq_id);
+ seq.tail = -1;
+ if (cell.seq_id.empty()) {
+ cell.pos = -1;
+ cell.src = -1;
+ cache.used -= 1;
+ }
}
- // Assuming the tokens are in-order
- if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
- // What should happen when the pos backtracks or skips a value?
- // Clearing the state mid-batch would require special-casing which isn't done.
- LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
- __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
+ }
+ }
+ }
+
+#ifndef NDEBUG
+ {
+ std::vector<int32_t> tails_verif;
+ tails_verif.assign(cache.size, -1);
+ for (uint32_t i = 0; i < cache.size; ++i) {
+ llama_kv_cell & cell = cache.cells[i];
+ for (llama_seq_id seq_id : cell.seq_id) {
+ if (tails_verif[seq_id] != -1) {
+ LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
}
- if (cache.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) {
- cache.used += 1;
+ tails_verif[seq_id] = i;
+ }
+ }
+ for (uint32_t i = 0; i < cache.size; ++i) {
+ if (tails_verif[i] != cache.cells[i].tail) {
+ LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]);
+ }
+ }
+ }
+#endif
+
+ // find next empty cell
+ uint32_t next_empty_cell = cache.head;
+
+ for (uint32_t i = 0; i < cache.size; ++i) {
+ if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
+ llama_kv_cell & cell = cache.cells[next_empty_cell];
+ if (cell.is_empty()) { break; }
+ next_empty_cell += 1;
+ }
+
+ // find usable cell range
+ for (uint32_t s = 0; s < n_seqs; ++s) {
+ const llama_seq_id seq_id = batch.seq_id[s][0];
+ llama_kv_cell & seq_meta = cache.cells[seq_id];
+ bool has_cell = false;
+ if (seq_meta.tail >= 0) {
+ llama_kv_cell & cell = cache.cells[seq_meta.tail];
+ GGML_ASSERT(cell.has_seq_id(seq_id));
+ // does this seq_id "own" the cell?
+ if (cell.seq_id.size() == 1) { has_cell = true; }
+ }
+ if (!has_cell) {
+ llama_kv_cell & empty_cell = cache.cells[next_empty_cell];
+ GGML_ASSERT(empty_cell.is_empty());
+ // copy old tail into the empty cell
+ if (seq_meta.tail >= 0) {
+ llama_kv_cell & orig_cell = cache.cells[seq_meta.tail];
+ empty_cell.pos = orig_cell.pos;
+ empty_cell.src = orig_cell.src;
+ orig_cell.seq_id.erase(seq_id);
+ empty_cell.seq_id.insert(seq_id); // will be overwritten
+ }
+ seq_meta.tail = next_empty_cell;
+ // find next empty cell
+ if (s + 1 < n_seqs) {
+ next_empty_cell += 1;
+ for (uint32_t i = 0; i < cache.size; ++i) {
+ if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
+ llama_kv_cell & cell = cache.cells[next_empty_cell];
+ if (cell.is_empty()) { break; }
+ next_empty_cell += 1;
}
- cache.cells[seq_id].pos = batch.pos[i];
- // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
- } else {
- // too big seq_id
- // TODO: would it be possible to resize the KV cache size instead?
- LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
- return false;
}
}
+ if (min > seq_meta.tail) { min = seq_meta.tail; }
+ if (max < seq_meta.tail) { max = seq_meta.tail; }
+ }
+
+ // gather and re-order
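+ // move the cells used by this ubatch so that they are contiguous in
+ // [min, min + n_seqs), which lets the graph treat them as a single slot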
+ for (uint32_t s = 0; s < n_seqs; ++s) {
+ int32_t dst_id = s + min;
+ int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
+ if (dst_id != src_id) {
+ llama_kv_cell & dst_cell = cache.cells[dst_id];
+ llama_kv_cell & src_cell = cache.cells[src_id];
+
+ std::swap(dst_cell.pos, src_cell.pos);
+ std::swap(dst_cell.src, src_cell.src);
+ std::swap(dst_cell.seq_id, src_cell.seq_id);
+
+ // swap tails (assuming they NEVER overlap)
+ for (const llama_seq_id seq_id : src_cell.seq_id) {
+ cache.cells[seq_id].tail = src_id;
+ }
+ for (const llama_seq_id seq_id : dst_cell.seq_id) {
+ cache.cells[seq_id].tail = dst_id;
+ }
+ }
+ }
+
+ // update the pos of the used seqs
+ for (uint32_t s = 0; s < n_seqs; ++s) {
+ const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+ int32_t cell_id = s + min;
+ llama_kv_cell & cell = cache.cells[cell_id];
+
+ if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
+ // What should happen when the pos backtracks or skips a value?
+ // Clearing the state mid-batch would require special-casing which isn't done.
+ LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
+ __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
+ }
+ cell.pos = last_pos;
+ cell.seq_id.clear();
+ for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
+ const llama_seq_id seq_id = batch.seq_id[s][j];
+ cell.seq_id.insert(seq_id);
+ cache.cells[seq_id].tail = cell_id;
+ }
}
// allow getting the range of used cells, from head to head + n
cache.n = max - min + 1;
// sanity check
- return max >= min;
+ return cache.n >= n_seqs;
}
// otherwise, one cell per token.
}
}
- for (uint32_t i = 0; i < n_tokens; i++) {
- cache.cells[cache.head + i].pos = batch.pos[i];
+ for (uint32_t s = 0; s < n_seqs; s++) {
+ for (uint32_t i = 0; i < n_seq_tokens; ++i) {
+ uint32_t k = s*n_seq_tokens + i;
+ cache.cells[cache.head + k].pos = batch.pos[k];
- for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
- cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
+ for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
+ cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
+ }
}
}
for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
+ cache.cells[i].src = -1;
+ cache.cells[i].tail = -1;
}
cache.head = 0;
cache.used = 0;
return false;
}
if (0 <= seq_id) {
- // partial intersection is invalid
- if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) {
- return false;
+ int32_t & tail_id = cache.cells[seq_id].tail;
+ if (tail_id >= 0) {
+ const llama_kv_cell & cell = cache.cells[tail_id];
+ // partial intersection is invalid
+ if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+ return false;
+ }
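+ // invalidate tails which will be cleared by this removal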
+ if (p0 <= cell.pos && cell.pos < p1) {
+ tail_id = -1;
+ }
}
} else {
// seq_id is negative, then the range should include everything or nothing
if (cache.cells[i].pos >= 0) cache.used--;
cache.cells[i].pos = -1;
+ cache.cells[i].src = -1;
if (new_head == cache.size) new_head = i;
}
}
if (cache.recurrent) {
if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
- seq_id_src = cache.cells[seq_id_src].src;
- GGML_ASSERT((uint32_t) seq_id_src < cache.size);
- // intent to "copy from"
- // supports copy chains thanks to taking the source of the source
- cache.cells[seq_id_dst].src = seq_id_src;
-
- // preserve the "keep or clear" status of the copied sequence
- if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
- cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
- } else {
- cache.cells[seq_id_dst].seq_id.erase(seq_id_dst);
+ llama_kv_cell & tail_src = cache.cells[seq_id_src];
+ llama_kv_cell & tail_dst = cache.cells[seq_id_dst];
+ if (tail_dst.tail >= 0) {
+ // clear destination seq_id if it wasn't empty
+ llama_kv_cell & cell_dst = cache.cells[tail_dst.tail];
+
+ cell_dst.seq_id.erase(seq_id_dst);
+ tail_dst.tail = -1;
+ if (cell_dst.seq_id.empty()) {
+ cell_dst.pos = -1;
+ cell_dst.delta = -1;
+ cell_dst.src = -1;
+ cache.used -= 1;
+ }
}
+ if (tail_src.tail >= 0) {
+ llama_kv_cell & cell_src = cache.cells[tail_src.tail];
- cache.do_copy = true;
-
- cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
+ cell_src.seq_id.insert(seq_id_dst);
+ tail_dst.tail = tail_src.tail;
+ }
}
+
return;
}
// otherwise, this is the KV cache of a Transformer-like model
uint32_t new_head = cache.size;
for (uint32_t i = 0; i < cache.size; ++i) {
+ if (cache.recurrent && (llama_seq_id) i != seq_id) {
+ cache.cells[i].tail = -1;
+ }
if (!cache.cells[i].has_seq_id(seq_id)) {
if (cache.cells[i].pos >= 0) cache.used--;
cache.cells[i].pos = -1;
+ cache.cells[i].src = -1;
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
} else {
if (cache.recurrent) {
// for Mamba-like models, only the pos needs to be shifted
if (0 <= seq_id && seq_id < (int64_t) cache.size) {
- llama_kv_cell & cell = cache.cells[seq_id];
- if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
- cell.pos += delta;
+ const int32_t tail_id = cache.cells[seq_id].tail;
+ if (tail_id >= 0) {
+ llama_kv_cell & cell = cache.cells[tail_id];
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+ cell.pos += delta;
+ }
}
}
return;
if (cache.recurrent) {
// for Mamba-like models, only the pos needs to be changed
if (0 <= seq_id && seq_id < (int64_t) cache.size) {
- llama_kv_cell & cell = cache.cells[seq_id];
- if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
- cell.pos /= d;
+ const int32_t tail_id = cache.cells[seq_id].tail;
+ if (tail_id >= 0) {
+ llama_kv_cell & cell = cache.cells[tail_id];
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+ cell.pos /= d;
+ }
}
}
return;
}
static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
- cache.do_defrag = true;
+ if (!cache.recurrent) {
+ cache.do_defrag = true;
+ }
}
static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
struct ggml_context * ctx,
struct llama_context & lctx,
const llama_hparams & hparams,
- const llama_batch & batch,
+ const llama_ubatch & batch,
struct ggml_tensor * tok_embd,
const llm_build_cb & cb) {
const int64_t n_embd = hparams.n_embd;
return cur;
}
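+// gathers the recurrent states of the sequences in the current ubatch to the
+// front of the cache range (via state_copy), zeroes the states of sequences
+// that start in this batch (via state_mask), writes back the states that
+// won't be modified, and returns a view of the n_seqs states to be updated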
+static struct ggml_tensor * llm_build_copy_mask_state(
+ struct ggml_context * ctx,
+ struct ggml_cgraph * graph,
+ struct ggml_tensor * s,
+ struct ggml_tensor * state_copy,
+ struct ggml_tensor * state_mask,
+ int32_t n_state,
+ int32_t kv_size,
+ int32_t kv_head,
+ int32_t n_kv,
+ int32_t n_seqs) {
+ struct ggml_tensor * states = ggml_reshape_2d(ctx, s, n_state, kv_size);
+
+ // copy states
+ // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+ // this shrinks the tensor's ne[1] to n_kv
+ states = ggml_get_rows(ctx, states, state_copy);
+
+ // clear states of sequences which are starting at the beginning of this batch
+ // FIXME: zero-out NANs?
+ states = ggml_mul(ctx, states, state_mask);
+
+ // copy states which won't be changed further (between n_seqs and n_rs)
+ ggml_build_forward_expand(graph,
+ ggml_cpy(ctx,
+ ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*ggml_element_size(states)),
+ ggml_view_1d(ctx, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
+
+ // the part of the states that will be used and modified
+ return ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0);
+}
+
+// TODO: split
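+// builds one Mamba block: input projection, 1D convolution, selective
+// state-space scan, and gated output projection, with the per-sequence
+// conv and ssm states kept in the (ab)used KV cache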
+static struct ggml_tensor * llm_build_mamba(
+ struct ggml_context * ctx,
+ struct llama_context & lctx,
+ const llama_ubatch & batch,
+ struct ggml_cgraph * graph,
+ struct ggml_tensor * cur,
+ struct ggml_tensor * state_copy,
+ struct ggml_tensor * state_mask,
+ int32_t kv_head,
+ int32_t n_kv,
+ const llm_build_cb & cb,
+ int il) {
+ const llama_model & model = lctx.model;
+ const llama_hparams & hparams = model.hparams;
+ const llama_kv_cache & kv = lctx.kv_self;
+ const int64_t d_conv = hparams.ssm_d_conv;
+ const int64_t d_inner = hparams.ssm_d_inner;
+ const int64_t d_state = hparams.ssm_d_state;
+ const int64_t dt_rank = hparams.ssm_dt_rank;
+ const int64_t n_seqs = batch.n_seqs;
+ // Some variants of the Mamba arch (e.g. FalconMamba) apply layer norm on B, C and Dt layers
+ const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
+ // Use the same RMS norm as the final layer norm
+ const float norm_rms_eps = hparams.f_norm_rms_eps;
+
+ const int64_t n_seq_tokens = batch.n_seq_tokens;
+
+ GGML_ASSERT(n_seqs != 0);
+ GGML_ASSERT(batch.equal_seqs);
+ GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
+
+ struct ggml_tensor * conv_states_all = kv.k_l[il];
+ struct ggml_tensor * ssm_states_all = kv.v_l[il];
+
+ // (ab)using the KV cache to store the states
+ struct ggml_tensor * conv = llm_build_copy_mask_state(ctx,
+ graph, conv_states_all, state_copy, state_mask,
+ hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs);
+ conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner, n_seqs);
+ struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx,
+ graph, ssm_states_all, state_copy, state_mask,
+ hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs);
+ ssm = ggml_reshape_3d(ctx, ssm, d_state, d_inner, n_seqs);
+
+ // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+ cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+ // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
+ struct ggml_tensor * xz = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur);
+ // split the above in two
+ // => {d_inner, n_seq_tokens, n_seqs}
+ struct ggml_tensor * x = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
+ struct ggml_tensor * z = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner*ggml_element_size(xz));
+
+ // conv
+ {
+ // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
+ struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_transpose(ctx, x), 0);
+
+ // copy last (d_conv - 1) columns back into the state cache
+ struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
+
+ ggml_build_forward_expand(graph,
+ ggml_cpy(ctx, last_conv,
+ ggml_view_1d(ctx, conv_states_all,
+ (d_conv - 1)*(d_inner)*(n_seqs),
+ kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
+
+ // 1D convolution
+ // The equivalent is to make a self-overlapping view of conv_x
+ // over d_conv columns at each stride in the 3rd dimension,
+ // then element-wise multiply that with the conv1d weight,
+ // then sum the elements of each row,
+ // (the last two steps are a dot product over rows (also doable with mul_mat))
+ // then permute away the ne[0] dimension,
+ // and then you're left with the resulting x tensor.
+ // For simultaneous sequences, all sequences need to have the same length.
+ x = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d);
+
+ // bias
+ x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b);
+
+ x = ggml_silu(ctx, x);
+ }
+
+ // ssm
+ {
+ // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
+ struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_x, x);
+ // split
+ struct ggml_tensor * dt = ggml_view_3d(ctx, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
+ struct ggml_tensor * B = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank);
+ struct ggml_tensor * C = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state));
+
+ // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
+ if (ssm_dt_b_c_rms) {
+ dt = ggml_rms_norm(ctx, dt, norm_rms_eps);
+ B = ggml_rms_norm(ctx, B, norm_rms_eps);
+ C = ggml_rms_norm(ctx, C, norm_rms_eps);
+ }
+
+ // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
+ dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt);
+ dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b);
+
+ // Custom operator to optimize the parallel associative scan
+ // as described in Annex D of the Mamba paper.
+ // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
+ struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C);
+
+ // store last states
+ ggml_build_forward_expand(graph,
+ ggml_cpy(ctx,
+ ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]),
+ ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
+
+ struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0);
+
+ // TODO: skip computing output earlier for unused tokens
+
+ // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
+ y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d));
+ y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z)));
+
+ // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
+ cur = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_out, y);
+ }
+
+ // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+ cur = ggml_reshape_2d(ctx, cur, cur->ne[0], n_seq_tokens * n_seqs);
+ cb(cur, "mamba_out", il);
+
+ return cur;
+}
+
struct llm_build_context {
const llama_model & model;
llama_context & lctx;
const llama_hparams & hparams;
const llama_cparams & cparams;
- const llama_batch & batch;
+ const llama_ubatch & batch;
const llama_kv_cache & kv_self;
const int64_t n_embd;
// TODO: consider making the entire interface noexcept
llm_build_context(
llama_context & lctx,
- const llama_batch & batch,
+ const llama_ubatch & batch,
const llm_build_cb & cb,
bool worst_case) :
model (lctx.model),
return gf;
}
- struct ggml_cgraph * build_s_copy() {
- struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
-
- GGML_ASSERT(kv_self.recurrent);
-
- struct ggml_tensor * state_copy = build_inp_s_copy();
-
- for (int il = 0; il < n_layer; ++il) {
- struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
- struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
-
- conv_states = ggml_get_rows(ctx0, conv_states, state_copy);
- ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy);
-
- // TODO: name the intermediate tensors with cb()
-
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
- }
-
- return gf;
- }
-
struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
}
struct ggml_tensor * build_inp_s_copy() {
- lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size);
+ lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
cb(lctx.inp_s_copy, "inp_s_copy", -1);
ggml_set_input(lctx.inp_s_copy);
return lctx.inp_s_copy;
return lctx.inp_s_mask;
}
- struct ggml_tensor * build_inp_s_seq() {
- lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
- cb(lctx.inp_s_seq, "inp_s_seq", -1);
- ggml_set_input(lctx.inp_s_seq);
- return lctx.inp_s_seq;
- }
-
struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
// find result_norm tensor for input
struct ggml_tensor * inp = nullptr;
struct ggml_cgraph * build_mamba() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
- const int64_t d_model = n_embd;
- const int64_t d_conv = hparams.ssm_d_conv;
- const int64_t d_inner = hparams.ssm_d_inner;
- GGML_ASSERT(2 * d_model == d_inner);
- const int64_t d_state = hparams.ssm_d_state;
- const int64_t dt_rank = hparams.ssm_dt_rank;
- // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
- const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
- // Use the same RMS norm as the final layer norm
- const float norm_rms_eps = hparams.f_norm_rms_eps;
-
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
// {n_embd, n_tokens}
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+ struct ggml_tensor * state_copy = build_inp_s_copy();
struct ggml_tensor * state_mask = build_inp_s_mask();
- struct ggml_tensor * state_seq = build_inp_s_seq();
for (int il = 0; il < n_layer; ++il) {
- // (ab)using the KV cache to store the states
- struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
- struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
-
- // clear states of sequences which are starting at the beginning of this batch
- {
- conv_states = ggml_mul(ctx0,
- ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]),
- state_mask);
- ssm_states = ggml_mul(ctx0,
- ggml_view_2d(ctx0, ssm_states, ssm_states->ne[0], n_kv, ssm_states->nb[1], kv_head*ssm_states->nb[1]),
- state_mask);
- }
-
- conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv);
- ssm_states = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv);
-
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
- // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens}
- struct ggml_tensor * xz = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_in, cur);
- // split the above in two
- // => {d_inner, n_tokens}
- struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
- struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
-
- // conv
- {
- // Custom operator which is needed only to ease simultaneous sequence processing.
- // For a single sequence, the equivalent is to concatenate the columns of conv_states and x,
- // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension,
- // then element-wise multiply that with the conv1d weigth,
- // then sum the elements of each row,
- // (the last two steps are a dot product over rows (also doable with mul_mat))
- // then permute away the ne[0] dimension,
- // and then you're left with the resulting x tensor.
- // The new conv_states is the last (d_conv - 1) columns
- // of the last 3rd dimensional "layer" of the self-overlapping view.
- // For simultaneous sequences, it's more complicated.
- struct ggml_tensor * x_conv = ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq);
-
- // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache
- ggml_build_forward_expand(gf,
- ggml_cpy(ctx0,
- ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)),
- ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));
-
- // extract x from x_conv
- x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0);
-
- // bias
- x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
-
- x = ggml_silu(ctx0, x);
- }
-
- // ssm
- {
- // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens}
- struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_x, x);
- // split
- struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
- struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
- struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
-
- // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
- if (ssm_dt_b_c_rms) {
- dt = ggml_rms_norm(ctx0, dt, norm_rms_eps);
- B = ggml_rms_norm(ctx0, B, norm_rms_eps);
- C = ggml_rms_norm(ctx0, C, norm_rms_eps);
- }
-
- // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
- dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt);
- dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
-
- // Custom operator to optimize the parallel associative scan
- // as described in the Annex D of the Mamba paper.
- // => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined,
- // because only a single tensor can be returned.
- struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq);
-
- // store last states (the second part of y_ssm_states)
- ggml_build_forward_expand(gf,
- ggml_cpy(ctx0,
- ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)),
- ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_head*d_state*d_inner*ggml_element_size(ssm_states))));
-
- struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);
-
- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- struct ggml_tensor * inp_out_ids = build_inp_out_ids();
- x = ggml_get_rows(ctx0, x, inp_out_ids);
- y = ggml_get_rows(ctx0, y, inp_out_ids);
- z = ggml_get_rows(ctx0, z, inp_out_ids);
- inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
- }
-
- // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens}
- y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
- y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
+ cur = llm_build_mamba(ctx0, lctx, batch, gf, cur,
+ state_copy, state_mask,
+ kv_head, n_kv, cb, il);
- // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens}
- cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_out, y);
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
// residual
};
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
- llama_batch dummy;
- dummy.n_tokens = 0;
+ llama_ubatch dummy = {};
+ dummy.equal_seqs = true;
llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
}
static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
- llama_batch dummy;
- dummy.n_tokens = 0;
+ llama_ubatch dummy = {};
+ dummy.equal_seqs = true;
llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
return result;
}
-static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
- llama_batch dummy;
- dummy.n_tokens = 0;
-
- llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
-
- struct llm_build_context llm(lctx, dummy, cb, false);
-
- llm.init();
-
- struct ggml_cgraph * result = llm.build_s_copy();
-
- llm.free();
-
- return result;
-}
-
static struct ggml_cgraph * llama_build_graph(
llama_context & lctx,
- const llama_batch & batch,
+ const llama_ubatch & batch,
bool worst_case) {
const auto & model = lctx.model;
return relative_bucket;
}
-static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
+static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
//
// set input data
//
for (int i = 0; i < n_tokens; ++i) {
data[i] = i;
}
- } else if (batch.logits) {
+ } else if (batch.output) {
int32_t n_outputs = 0;
for (int i = 0; i < n_tokens; ++i) {
- if (batch.logits[i]) {
+ if (batch.output[i]) {
data[n_outputs++] = i;
}
}
if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
if (cparams.causal_attn && !lctx.is_encoding) {
- const int64_t n_kv = kv_self.n;
- const int64_t n_tokens = batch.n_tokens;
+ const int64_t n_kv = kv_self.n;
+ const int64_t n_tokens = batch.n_tokens;
+ const int64_t n_seq_tokens = batch.n_seq_tokens;
+ const int64_t n_seqs = batch.n_seqs;
float * data = nullptr;
// of the correct sequence for each token of the batch.
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
for (int h = 0; h < 1; ++h) {
- for (int j = 0; j < n_tokens; ++j) {
- const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j][0];
+ for (int s = 0; s < n_seqs; ++s) {
+ const llama_seq_id seq_id = batch.seq_id[s][0];
- for (int i = 0; i < n_kv; ++i) {
- float f;
- if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
- f = -INFINITY;
- } else {
- if (hparams.use_alibi) {
- f = -std::abs(lctx.kv_self.cells[i].pos - pos);
+ for (int j = 0; j < n_seq_tokens; ++j) {
+ const llama_pos pos = batch.pos[s*n_seq_tokens + j];
+
+ for (int i = 0; i < n_kv; ++i) {
+ float f;
+ if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+ f = -INFINITY;
} else {
- f = 0.0f;
+ if (hparams.use_alibi) {
+ f = -std::abs(kv_self.cells[i].pos - pos);
+ } else {
+ f = 0.0f;
+ }
}
- }
- if (data) {
- data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
- }
+ if (data) {
+ data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+ }
- // may need to cut off old tokens for sliding window
- if (data_swa) {
- if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
- f = -INFINITY;
+ // may need to cut off old tokens for sliding window
+ if (data_swa) {
+ if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
+ f = -INFINITY;
+ }
+ data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
}
- data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
}
}
}
}
}
} else {
+ const int64_t n_tokens = batch.n_tokens;
+ const int64_t n_seq_tokens = batch.n_seq_tokens;
+ const int64_t n_seqs = batch.n_seqs;
// when using kv cache, the mask needs to match the kv cache size
- const int64_t n_tokens = batch.n_tokens;
const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
float * data = (float *) lctx.inp_KQ_mask->data;
for (int h = 0; h < 1; ++h) {
- for (int j = 0; j < n_tokens; ++j) {
- const llama_seq_id seq_id = batch.seq_id[j][0];
-
- for (int i = 0; i < n_tokens; ++i) {
- float f = -INFINITY;
- for (int s = 0; s < batch.n_seq_id[i]; ++s) {
- if (batch.seq_id[i][s] == seq_id) {
- if (hparams.use_alibi) {
- f = -std::abs(batch.pos[i] - batch.pos[j]);
- } else {
- f = 0.0f;
+ for (int s1 = 0; s1 < n_seqs; ++s1) {
+ const llama_seq_id seq_id = batch.seq_id[s1][0];
+
+ for (int j = 0; j < n_seq_tokens; ++j) {
+ const int32_t tj = s1*n_seq_tokens + j;
+
+ for (int s0 = 0; s0 < n_seqs; ++s0) {
+ for (int i = 0; i < n_seq_tokens; ++i) {
+ const int32_t ti = s0*n_seq_tokens + i;
+ float f = -INFINITY;
+
+ for (int s = 0; s < batch.n_seq_id[s0]; ++s) {
+ if (batch.seq_id[s0][s] == seq_id) {
+ if (hparams.use_alibi) {
+ f = -std::abs(batch.pos[ti] - batch.pos[tj]);
+ } else {
+ f = 0.0f;
+ }
+ break;
+ }
}
- break;
+
+ data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
}
}
- data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
- }
-
- for (int i = n_tokens; i < n_stride; ++i) {
- data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
+ for (int i = n_tokens; i < n_stride; ++i) {
+ data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
+ }
}
}
}
}
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
- const int64_t n_tokens = batch.n_tokens;
+ const int64_t n_tokens = batch.n_tokens;
+ const int64_t n_seq_tokens = batch.n_seq_tokens;
+ const int64_t n_seqs = batch.n_seqs;
GGML_ASSERT(lctx.inp_mean);
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
std::vector<uint64_t> sum(n_tokens, 0);
- for (int i = 0; i < n_tokens; ++i) {
- const llama_seq_id seq_id = batch.seq_id[i][0];
+ for (int s = 0; s < n_seqs; ++s) {
+ const llama_seq_id seq_id = batch.seq_id[s][0];
+
+ // TODO: adapt limits to n_seqs when batch.equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
- sum[seq_id] += 1;
+ sum[seq_id] += batch.n_seq_tokens;
}
std::vector<float> div(n_tokens, 0.0f);
}
}
- for (int i = 0; i < n_tokens; ++i) {
- const llama_seq_id seq_id = batch.seq_id[i][0];
- data[seq_id*n_tokens + i] = div[seq_id];
+ for (int s = 0; s < n_seqs; ++s) {
+ const llama_seq_id seq_id = batch.seq_id[s][0];
+
+ for (int i = 0; i < n_seq_tokens; ++i) {
+ data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
+ }
}
}
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
- const int64_t n_tokens = batch.n_tokens;
+ const int64_t n_tokens = batch.n_tokens;
+ const int64_t n_seq_tokens = batch.n_seq_tokens;
+ const int64_t n_seqs = batch.n_seqs;
GGML_ASSERT(lctx.inp_cls);
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
uint32_t * data = (uint32_t *) lctx.inp_cls->data;
memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
- for (int i = 0; i < n_tokens; ++i) {
- const llama_seq_id seq_id = batch.seq_id[i][0];
- const llama_pos pos = batch.pos[i];
+ for (int s = 0; s < n_seqs; ++s) {
+ const llama_seq_id seq_id = batch.seq_id[s][0];
+ // TODO: adapt limits to n_seqs when batch.equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
- if (pos == 0) {
- data[seq_id] = i;
+ for (int i = 0; i < n_seq_tokens; ++i) {
+ const llama_pos pos = batch.pos[s*n_seq_tokens + i];
+
+ if (pos == 0) {
+ data[seq_id] = s*n_seq_tokens + i;
+ }
}
}
}
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
- const int64_t n_tokens = batch.n_tokens;
+ const int64_t n_tokens = batch.n_tokens;
+ const int64_t n_seq_tokens = batch.n_seq_tokens;
+ const int64_t n_seqs = batch.n_seqs;
GGML_ASSERT(lctx.inp_cls);
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
std::vector<int> last_pos(n_tokens, -1);
std::vector<int> last_row(n_tokens, -1);
- for (int i = 0; i < n_tokens; ++i) {
- const llama_seq_id seq_id = batch.seq_id[i][0];
- const llama_pos pos = batch.pos[i];
+ for (int s = 0; s < n_seqs; ++s) {
+ const llama_seq_id seq_id = batch.seq_id[s][0];
+ // TODO: adapt limits to n_seqs when batch.equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
- if (pos >= last_pos[seq_id]) {
- last_pos[seq_id] = pos;
- last_row[seq_id] = i;
+ for (int i = 0; i < n_seq_tokens; ++i) {
+ const llama_pos pos = batch.pos[s*n_seq_tokens + i];
+
+ if (pos >= last_pos[seq_id]) {
+ last_pos[seq_id] = pos;
+ last_row[seq_id] = s*n_seq_tokens + i;
+ }
}
}
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
float * data = (float *) lctx.inp_s_mask->data;
- // states which are not affected by the current batch are left untouched
+ // clear unused states
for (int i = 0; i < n_kv; ++i) {
- llama_seq_id seq_id = i + lctx.kv_self.head;
- llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
- bool has_self_seq = kv_cell.has_seq_id(seq_id);
+ uint32_t cell_id = i + kv_self.head;
+ llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
- data[i] = (float) has_self_seq;
+ data[i] = (float) (kv_cell.src >= 0);
- // ensure current sequences will be kept
- if (!has_self_seq && kv_cell.pos >= 0) {
- kv_cell.seq_id.insert(seq_id);
+ // only clear once
+ if (kv_cell.src < 0) {
+ kv_cell.src = cell_id;
}
}
}
- // For Mamba (and other recurrent architectures),
- // update the correct state(s)/sequence(s) for each token of the batch.
- // Like with the KQ_mask, if a token in the batch has multiple sequences,
- // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv).
- if (lctx.inp_s_seq) {
- const int64_t n_tokens = batch.n_tokens;
- GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer));
- int32_t * data = (int32_t *) lctx.inp_s_seq->data;
+ if (lctx.inp_s_copy) {
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
+ int32_t * data = (int32_t *) lctx.inp_s_copy->data;
- for (int j = 0; j < n_tokens; ++j) {
- const int32_t n_seq = batch.n_seq_id[j];
- GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence
+ // assuming the copy destinations are ALWAYS contained in the cells between head and head+n
+ for (uint32_t i = 0; i < n_kv; ++i) {
+ const uint32_t cell_id = i + kv_self.head;
+ llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
- for (int i = 0; i < n_kv; ++i) {
- if (i < n_seq) {
- // for this type of model, the head is the minimum seq_id of the batch
- data[j*n_kv + i] = batch.seq_id[j][i] - kv_self.head;
- } else {
- data[j*n_kv + i] = -1;
- }
+ // prevent out-of-bound sources
+ if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
+ kv_cell.src = cell_id;
+ }
+
+ data[i] = kv_cell.src;
+
+ // ensure copy only happens once
+ if (kv_cell.src != (int32_t) cell_id) {
+ kv_cell.src = cell_id;
}
}
}
const int64_t n_tokens = batch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
+ GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing
int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
const int64_t n_tokens = batch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
+ GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing
float * data = (float *) lctx.inp_KQ_mask_cross->data;
return n_outputs_max;
}
+// make the outputs have the same order they had in the user-provided batch
+static void llama_output_reorder(struct llama_context * ctx) {
+ std::vector<size_t> & out_ids = ctx->sbatch.out_ids;
+ if (!out_ids.empty()) {
+ uint32_t n_vocab = ctx->model.hparams.n_vocab;
+ uint32_t n_embd = ctx->model.hparams.n_embd;
+ int32_t n_outputs = ctx->n_outputs;
+ GGML_ASSERT((size_t) n_outputs == out_ids.size());
+ // TODO: is there something more efficient which also minimizes swaps?
+ // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+ for (int32_t i = 0; i < n_outputs - 1; ++i) {
+ int32_t j_min = i;
+ for (int32_t j = i + 1; j < n_outputs; ++j) {
+ if (out_ids[j] < out_ids[j_min]) {
+ j_min = j;
+ }
+ }
+ if (j_min == i) { continue; }
+ std::swap(out_ids[i], out_ids[j_min]);
+ if (ctx->logits_size > 0) {
+ for (uint32_t k = 0; k < n_vocab; k++) {
+ std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]);
+ }
+ }
+ if (ctx->embd_size > 0) {
+ for (uint32_t k = 0; k < n_embd; k++) {
+ std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]);
+ }
+ }
+ }
+ std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1);
+ for (int32_t i = 0; i < n_outputs; ++i) {
+ ctx->output_ids[out_ids[i]] = i;
+ }
+ out_ids.clear();
+ }
+}
static void llama_graph_compute(
llama_context & lctx,
const auto n_ubatch = cparams.n_ubatch;
- // TODO: simplify or deprecate
- std::vector<llama_pos> pos;
- std::vector<int32_t> n_seq_id;
- std::vector<llama_seq_id *> seq_id_arr;
- std::vector<std::vector<llama_seq_id>> seq_id;
-
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+ lctx.embd_seq.clear();
+
// count outputs
if (batch_all.logits && !embd_pooled) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs = 1;
}
+ lctx.sbatch.from_batch(batch_all, n_embd,
+ /* simple_split */ !kv_self.recurrent,
+ /* logits_all */ n_outputs == n_tokens_all);
+
// reserve output buffer
if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
return -2;
};
- // set output mappings
- if (batch_all.logits) {
- int32_t i_logits = 0;
- for (uint32_t i = 0; i < n_tokens_all; ++i) {
- if (batch_all.logits[i]) {
- lctx.output_ids[i] = i_logits++;
+ while (lctx.sbatch.n_tokens > 0) {
+ llama_ubatch ubatch;
+ if (kv_self.recurrent) {
+ if (embd_pooled) {
+ // Pooled embeddings cannot be split across ubatches (yet)
+ ubatch = lctx.sbatch.split_seq(n_ubatch);
+ } else {
+ // recurrent model architectures are easier to implement
+ // with equal-length sequences
+ ubatch = lctx.sbatch.split_equal(n_ubatch);
}
+ } else {
+ ubatch = lctx.sbatch.split_simple(n_ubatch);
}
- } else {
- for (uint32_t i = 0; i < n_outputs; ++i) {
- lctx.output_ids[i] = i;
- }
- }
-
- for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
- const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
- llama_batch u_batch = {
- /* .n_tokens = */ (int32_t) n_tokens,
- /* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr,
- /* .embd = */ batch_all.embd ? batch_all.embd + cur_token*n_embd : nullptr,
- /* .pos = */ batch_all.pos ? batch_all.pos + cur_token : nullptr,
- /* .n_seq_id = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token : nullptr,
- /* .seq_id = */ batch_all.seq_id ? batch_all.seq_id + cur_token : nullptr,
- /* .logits = */ batch_all.logits ? batch_all.logits + cur_token : nullptr,
- /* .all_pos_0 = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1,
- /* .all_pos_1 = */ batch_all.all_pos_1,
- /* .all_seq_id = */ batch_all.all_seq_id,
- };
+ const uint32_t n_tokens = ubatch.n_tokens;
// count the outputs in this u_batch
{
int32_t n_outputs_new = 0;
- if (u_batch.logits && !embd_pooled) {
- for (uint32_t i = 0; i < n_tokens; i++) {
- n_outputs_new += u_batch.logits[i] != 0;
- }
- } else if (n_outputs == n_tokens_all) {
+ if (n_outputs == n_tokens_all) {
n_outputs_new = n_tokens;
} else {
- // keep last output only
- if (cur_token + n_tokens >= n_tokens_all) {
- n_outputs_new = 1;
+ GGML_ASSERT(ubatch.output);
+ for (uint32_t i = 0; i < n_tokens; i++) {
+ n_outputs_new += (int32_t) (ubatch.output[i] != 0);
}
}
int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
GGML_ASSERT(n_threads > 0);
- // helpers for smoother batch API transition
- // after deprecating the llama_eval calls, these will be removed
- if (u_batch.pos == nullptr) {
- pos.resize(n_tokens);
- for (uint32_t i = 0; i < n_tokens; i++) {
- pos[i] = u_batch.all_pos_0 + i*u_batch.all_pos_1;
- }
-
- u_batch.pos = pos.data();
- }
-
- if (u_batch.seq_id == nullptr) {
- n_seq_id.resize(n_tokens);
- seq_id.resize(n_tokens);
- seq_id_arr.resize(n_tokens);
- for (uint32_t i = 0; i < n_tokens; i++) {
- n_seq_id[i] = 1;
- seq_id[i].resize(1);
- seq_id[i][0] = u_batch.all_seq_id;
- seq_id_arr[i] = seq_id[i].data();
- }
-
- u_batch.n_seq_id = n_seq_id.data();
- u_batch.seq_id = seq_id_arr.data();
- }
-
// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
llama_kv_cache_update(&lctx);
kv_self.head = 0;
}
- if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
+ if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
return 1;
}
ggml_backend_sched_reset(lctx.sched);
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
- ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
+ ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
// the output is always the last tensor in the graph
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
ggml_backend_sched_alloc_graph(lctx.sched, gf);
- llama_set_inputs(lctx, u_batch);
+ llama_set_inputs(lctx, ubatch);
llama_graph_compute(lctx, gf, n_threads);
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
{
- // extract sequence embeddings
+ // extract sequence embeddings (cleared before processing each batch)
auto & embd_seq_out = lctx.embd_seq;
- embd_seq_out.clear();
- for (uint32_t i = 0; i < n_tokens; i++) {
- const llama_seq_id seq_id = u_batch.seq_id[i][0];
+ for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
n_outputs_prev += lctx.n_outputs;
}
+ // set output mappings
+ {
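+ // lctx.output_ids maps a token's index in the batch to its row in the
+ // logits/embeddings buffers
+ // (e.g. out_ids = { 2, 0, 1 } gives output_ids[2] = 0, output_ids[0] = 1,
+ //  output_ids[1] = 2, and sorted_output stays false)
+ // when the outputs already come back in batch order, out_ids is cleared,
+ // presumably so that the reorder in llama_get_logits has nothing left to do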
+ bool sorted_output = true;
+
+ GGML_ASSERT(lctx.sbatch.out_ids.size() == n_outputs);
+
+ for (size_t i = 0; i < n_outputs; ++i) {
+ size_t out_id = lctx.sbatch.out_ids[i];
+ lctx.output_ids[out_id] = i;
+ if (out_id != i) {
+ sorted_output = false;
+ }
+ }
+
+ if (sorted_output) {
+ lctx.sbatch.out_ids.clear();
+ }
+ }
+
// set to total number of outputs in the batch, for use in llama_get_logits_ith
lctx.n_outputs = n_outputs;
const int64_t n_embd = hparams.n_embd;
- // TODO: simplify or deprecate
- std::vector<llama_pos> pos;
- std::vector<int32_t> n_seq_id;
- std::vector<llama_seq_id *> seq_id_arr;
- std::vector<std::vector<llama_seq_id>> seq_id;
+ lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+
+ const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
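+ // a simple split keeps tokens in their original order; since n_tokens here
+ // covers the whole input batch, the encoder should see a single ubatch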
// reserve output buffer
if (llama_output_reserve(lctx, n_tokens) < n_tokens) {
const int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
GGML_ASSERT(n_threads > 0);
- // helpers for smoother batch API transition
- // after deprecating the llama_eval calls, these will be removed
- if (batch.pos == nullptr) {
- pos.resize(n_tokens);
- for (uint32_t i = 0; i < n_tokens; i++) {
- pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
- }
-
- batch.pos = pos.data();
- }
-
- if (batch.seq_id == nullptr) {
- n_seq_id.resize(n_tokens);
- seq_id.resize(n_tokens);
- seq_id_arr.resize(n_tokens);
- for (uint32_t i = 0; i < n_tokens; i++) {
- n_seq_id[i] = 1;
- seq_id[i].resize(1);
- seq_id[i][0] = batch.all_seq_id;
- seq_id_arr[i] = seq_id[i].data();
- }
-
- batch.n_seq_id = n_seq_id.data();
- batch.seq_id = seq_id_arr.data();
- }
-
ggml_backend_sched_reset(lctx.sched);
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
- ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
+ ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
// the output embeddings after the final encoder normalization
struct ggml_tensor * embd = nullptr;
ggml_backend_sched_alloc_graph(lctx.sched, gf);
- llama_set_inputs(lctx, batch);
+ llama_set_inputs(lctx, ubatch);
llama_graph_compute(lctx, gf, n_threads);
float * embd_out = lctx.embd_enc.data();
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
+ GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
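+ // with a simple split (equal_seqs == false), n_seq_id and seq_id are laid
+ // out per token, which the per-token indexing below relies on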
// remember the sequence ids used during the encoding - needed for cross attention later
lctx.seq_ids_enc.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
- for (int s = 0; s < batch.n_seq_id[i]; s++) {
- llama_seq_id seq_id = batch.seq_id[i][s];
+ for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
+ llama_seq_id seq_id = ubatch.seq_id[i][s];
lctx.seq_ids_enc[i].insert(seq_id);
}
}
auto & embd_seq_out = lctx.embd_seq;
embd_seq_out.clear();
+ GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
+
for (uint32_t i = 0; i < n_tokens; i++) {
- const llama_seq_id seq_id = batch.seq_id[i][0];
+ const llama_seq_id seq_id = ubatch.seq_id[i][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
}
}
- if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
- {
- ggml_backend_sched_reset(lctx.sched);
-
- ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
-
- ggml_backend_sched_alloc_graph(lctx.sched, gf);
-
- llama_set_s_copy(lctx);
-
- llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
-
- need_reserve = true;
- }
-
- {
- auto & kv_self = lctx.kv_self;
-
- kv_self.do_copy = false;
-
- for (uint32_t i = 0; i < kv_self.size; ++i) {
- kv_self.cells[i].src = i;
- }
- }
- }
-
// defragment the KV cache if needed
if (lctx.kv_self.do_defrag) {
llama_kv_cache_defrag_internal(lctx);
if (need_reserve) {
// TODO: extract to a function
// build worst-case graph
- int n_tokens = (int)std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
- int n_past = lctx.cparams.n_ctx - n_tokens;
+ uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+ uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
- ggml_cgraph * gf = llama_build_graph(lctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
+ llama_ubatch ubatch = { /*equal_seqs =*/ true, /*n_tokens =*/ n_tokens, /*n_seq_tokens =*/ n_tokens / n_seqs, /*n_seqs =*/ n_seqs, /*token =*/ &token, /*embd =*/ nullptr, /*pos =*/ nullptr, /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, /*output =*/ nullptr };
+ ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
// initialize scheduler with the worst-case graph
ggml_backend_sched_reset(lctx.sched);
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
// sanity checks
- //
- // - qs.n_attention_wv == 0 for Mamba models
- // - qs.n_attention_wv == model.hparams.n_layer for Transformer models
- // - qs.n_attention_wv == 3 * model.hparams.n_layer for Encoder-Decoder models
- //
- GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer) && "n_attention_wv is unexpected");
+ {
+ const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
+ // attention layers have a non-zero number of kv heads
+ int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
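+ // (e.g. pure Mamba models have n_head_kv == 0 on every layer, so they
+ //  contribute no attention layers here)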
+ if (llama_model_has_encoder(&model)) {
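+ // presumably encoder self-attention, decoder self-attention and
+ // decoder cross-attention per layer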
+ n_attn_layer *= 3;
+ }
+ GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
+ }
size_t total_size_org = 0;
size_t total_size_new = 0;
ggml_type type_v = params.type_v;
// Mamba only needs a constant number of KV cache cells per sequence
- if (model->arch == LLM_ARCH_MAMBA) {
+ if (llama_model_is_recurrent(model)) {
// Mamba needs at least as many KV cells as there are sequences kept at any time
kv_size = std::max((uint32_t) 1, params.n_seq_max);
// it's probably best to keep as much precision as possible for the states
}
// build worst-case graph
- int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch);
- int n_past = cparams.n_ctx - n_tokens;
+ uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+ uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
- ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
+ llama_ubatch ubatch = { /*equal_seqs =*/ true, /*n_tokens =*/ n_tokens, /*n_seq_tokens =*/ n_tokens / n_seqs, /*n_seqs =*/ n_seqs, /*token =*/ &token, /*embd =*/ nullptr, /*pos =*/ nullptr, /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, /*output =*/ nullptr };
+ ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true);
// initialize scheduler with the worst-case graph
if (!ggml_backend_sched_reserve(ctx->sched, gf)) {
return model->hparams.dec_start_token_id;
}
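+// single place to check for recurrent architectures (currently only Mamba);
+// used e.g. when sizing the KV cache, which for such models only needs one
+// cell per sequence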
+bool llama_model_is_recurrent(const struct llama_model * model) {
+ switch (model->arch) {
+ case LLM_ARCH_MAMBA: return true;
+ default: return false;
+ }
+}
+
uint32_t llama_model_quantize(
const char * fname_inp,
const char * fname_out,
write_string(rng_str);
}
- void write_output_ids(const struct llama_context * ctx) {
+ void write_output_ids(struct llama_context * ctx) {
+ llama_output_reorder(ctx);
+
const uint32_t n_outputs = ctx->n_outputs;
std::vector<int32_t> output_pos;
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
- llama_batch batch = llama_batch_init(cell_count, 0, 1);
+ llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
batch.n_tokens = cell_count;
+ batch.n_seq_tokens = cell_count;
+ batch.n_seqs = 1;
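+ // the ubatch points into the context's sbatch buffers, so unlike the old
+ // llama_batch_init path there is nothing to free when restoring is done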
+
for (uint32_t i = 0; i < cell_count; ++i) {
llama_pos pos;
uint32_t n_seq_id;
}
batch.pos[i] = pos;
- batch.n_seq_id[i] = 1;
- batch.seq_id[i][0] = dest_seq_id;
}
+ batch.n_seq_id[0] = 1;
+ batch.seq_id[0] = &dest_seq_id;
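+ // reserve_ubatch returns an equal_seqs ubatch, so n_seq_id and seq_id are
+ // indexed per sequence: a single entry covers all cell_count tokens of dest_seq_id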
if (!llama_kv_cache_find_slot(kv_self, batch)) {
- llama_batch_free(batch);
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
return false;
}
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
-
- // Cleanup
- llama_batch_free(batch);
} else {
// whole KV cache restore
}
cell.seq_id.insert(seq_id);
+
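+ // `tail` is meant to point at the cell holding a sequence's current
+ // recurrent state; a second tail for the same seq_id would mean the
+ // session data is inconsistent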
+ if (kv_self.recurrent) {
+ int32_t & tail = kv_self.cells[seq_id].tail;
+ if (tail != -1) {
+ LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
+ return false;
+ }
+ tail = i;
+ }
}
}
kv_self.used = cell_count;
}
+ if (kv_self.recurrent) {
+ for (uint32_t i = 0; i < cell_count; ++i) {
+ uint32_t cell_id = kv_self.head + i;
+ // make sure the recurrent states will keep their restored state
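+ // (setting src to the cell's own index marks its state as already in
+ //  place, so no copy is pending for it)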
+ kv_self.cells[cell_id].src = cell_id;
+ }
+ }
+
return true;
}
}
struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
- llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
+ llama_batch batch = {
+ /*n_tokens =*/ 0,
+ /*token =*/ nullptr,
+ /*embd =*/ nullptr,
+ /*pos =*/ nullptr,
+ /*n_seq_id =*/ nullptr,
+ /*seq_id =*/ nullptr,
+ /*logits =*/ nullptr,
+ /*all_pos_0 =*/ 0,
+ /*all_pos_1 =*/ 0,
+ /*all_seq_id =*/ 0,
+ };
if (embd) {
batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
float * llama_get_logits(struct llama_context * ctx) {
llama_synchronize(ctx);
+ // reorder logits for backward compatibility
+ // TODO: maybe deprecate this
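+ // (the sbatch splits may produce outputs out of batch order; reordering
+ //  restores the order that llama_get_logits_ith callers expect)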
+ llama_output_reorder(ctx);
+
return ctx->logits;
}
float * llama_get_embeddings(struct llama_context * ctx) {
llama_synchronize(ctx);
+ // reorder embeddings for backward compatibility
+ // TODO: maybe deprecate this
+ llama_output_reorder(ctx);
+
return ctx->embd;
}