llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
GGML_UNUSED(embd_pooled);
-    // TODO: if we fail with split_simple, we should attempt different splitting strategies
-    // but to do that properly, we first have to refactor the batches to be more flexible
-
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
-
-    std::vector<llama_ubatch> ubatches;
-    while (sbatch.n_tokens > 0) {
-        auto ubatch = sbatch.split_simple(n_ubatch);
-
-        ubatches.push_back(ubatch);
-    }
-
-    auto heads_base = kv_base->prepare(ubatches);
-    if (heads_base.empty()) {
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
-
-    auto heads_swa = kv_swa->prepare(ubatches);
-    if (heads_swa.empty()) {
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
-
-    assert(heads_base.size() == heads_swa.size());
-
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(
-            this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+    // first try simple split
+    do {
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+
+        std::vector<llama_ubatch> ubatches;
+        while (sbatch.n_tokens > 0) {
+            auto ubatch = sbatch.split_simple(n_ubatch);
+
+            ubatches.push_back(ubatch);
+        }
+
+        auto heads_base = kv_base->prepare(ubatches);
+        if (heads_base.empty()) {
+            break;
+        }
+
+        auto heads_swa = kv_swa->prepare(ubatches);
+        if (heads_swa.empty()) {
+            break;
+        }
+
+        assert(heads_base.size() == heads_swa.size());
+
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+    } while (false);
+
+    // if it fails, try equal split
+    do {
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+
+        std::vector<llama_ubatch> ubatches;
+        while (sbatch.n_tokens > 0) {
+            auto ubatch = sbatch.split_equal(n_ubatch);
+
+            ubatches.push_back(ubatch);
+        }
+
+        auto heads_base = kv_base->prepare(ubatches);
+        if (heads_base.empty()) {
+            break;
+        }
+
+        auto heads_swa = kv_swa->prepare(ubatches);
+        if (heads_swa.empty()) {
+            break;
+        }
+
+        assert(heads_base.size() == heads_swa.size());
+
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+    } while (false);
+
+    // TODO: if we fail again, we should attempt different splitting strategies
+    // but to do that properly, we first have to refactor the batches to be more flexible
+
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
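The rewritten init_batch above relies on the do { ... } while (false) early-exit idiom: when a prepare() call fails, break skips the return and execution falls through to the next splitting strategy, and the FAILED_PREPARE state is returned only after every strategy has been tried. A minimal, self-contained sketch of that control flow, using hypothetical prepare_simple/prepare_equal stand-ins rather than the real kv_base/kv_swa calls:

#include <cstdio>
#include <optional>

// hypothetical stand-ins for the simple/equal split paths (not part of this PR)
static std::optional<int> prepare_simple() { return std::nullopt; } // pretend it failed
static std::optional<int> prepare_equal()  { return 42; }           // pretend it succeeded

static std::optional<int> init_with_fallback() {
    // first strategy: on failure, break jumps past the return
    do {
        auto res = prepare_simple();
        if (!res) {
            break;
        }
        return res;
    } while (false);

    // second strategy, reached only if the first one bailed out
    do {
        auto res = prepare_equal();
        if (!res) {
            break;
        }
        return res;
    } while (false);

    return std::nullopt; // every strategy failed
}

int main() {
    printf("result: %d\n", init_with_fallback().value_or(-1)); // prints 42
}

Compared to nested if/else, the pattern keeps each strategy's locals (sbatch, ubatches, heads) scoped to its own block.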
llama_memory_state_ptr llama_kv_cache_unified::init_batch(
        const llama_batch & batch,
        uint32_t n_ubatch,
        bool embd_pooled,
        bool logits_all) {
GGML_UNUSED(embd_pooled);
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
-
-    std::vector<llama_ubatch> ubatches;
-    while (sbatch.n_tokens > 0) {
-        ubatches.push_back(sbatch.split_simple(n_ubatch));
-    }
-
-    auto heads = prepare(ubatches);
-    if (heads.empty()) {
-        return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
-
-    return std::make_unique<llama_kv_cache_unified_state>(
-            this, std::move(sbatch), std::move(heads), std::move(ubatches));
+    do {
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+
+        std::vector<llama_ubatch> ubatches;
+        while (sbatch.n_tokens > 0) {
+            ubatches.push_back(sbatch.split_simple(n_ubatch));
+        }
+
+        auto heads = prepare(ubatches);
+        if (heads.empty()) {
+            break;
+        }
+
+        return std::make_unique<llama_kv_cache_unified_state>(
+                this, std::move(sbatch), std::move(heads), std::move(ubatches));
+    } while (false);
+
+    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
llama_memory_state_ptr llama_kv_cache_unified::init_full() {
}
if (debug > 0) {
- LLAMA_LOG_CONT("\n");
LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
if ((debug == 2 && n_swa > 0) || debug > 2) {
if (cells.is_empty(i)) {
ss += '.';
} else {
- ss += std::to_string(cells.seq_get(i));
+ assert(cells.seq_count(i) >= 1);
+
+ if (cells.seq_count(i) == 1) {
+ ss += std::to_string(cells.seq_get(i));
+ } else {
+ ss += 'M';
+ }
}
if (i%256 == 255) {
ss += " *";
}
void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+ if (debug > 0) {
+ LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
+ LLAMA_LOG_DEBUG("%s: n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
+ LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
+ }
+
// keep track of the max sequence position that we would overwrite with this ubatch
// for non-SWA cache, this would always be empty
llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
    seq_pos_max_rm[s] = -1;
}
-    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        if (!cells.is_empty(head_cur + i)) {
-            assert(cells.seq_count(head_cur + i) == 1);
-
-            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
-            const llama_pos pos = cells.pos_get(head_cur + i);
-
-            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
-
-            cells.rm(head_cur + i);
-        }
-
-        cells.pos_set(head_cur + i, ubatch.pos[i]);
-
-        for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
-            cells.seq_add(head_cur + i, ubatch.seq_id[i][j]);
+    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+        for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
+            const uint32_t idx = s*ubatch.n_seq_tokens + j;
+
+            if (!cells.is_empty(head_cur + idx)) {
+                assert(cells.seq_count(head_cur + idx) == 1);
+
+                const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
+                const llama_pos pos = cells.pos_get(head_cur + idx);
+
+                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
+                cells.rm(head_cur + idx);
+            }
+
+            cells.pos_set(head_cur + idx, ubatch.pos[idx]);
+
+            for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
+                cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
+            }
}
}
seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
}
}
-
// move the head at the end of the slot
head = head_cur + ubatch.n_tokens;
}
}
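The new apply_ubatch loop walks the ubatch as n_seqs blocks of n_seq_tokens tokens: per-token data (pos, and the cells written at head_cur + idx) is addressed with the flat index idx = s*ubatch.n_seq_tokens + j, while per-sequence data (n_seq_id, seq_id) is addressed by s alone. A standalone sketch of that indexing with made-up sizes, not code from the PR:

#include <cstdio>
#include <vector>

int main() {
    // assumed equal-split layout: n_tokens = n_seqs * n_seq_tokens, and the tokens of
    // sequence s occupy the contiguous range [s*n_seq_tokens, (s+1)*n_seq_tokens)
    const int n_seqs       = 2;
    const int n_seq_tokens = 3;

    std::vector<int> pos(n_seqs*n_seq_tokens);
    for (size_t i = 0; i < pos.size(); ++i) {
        pos[i] = 100 + (int) i; // dummy per-token positions
    }

    for (int s = 0; s < n_seqs; ++s) {
        for (int j = 0; j < n_seq_tokens; ++j) {
            const int idx = s*n_seq_tokens + j; // flat index into per-token arrays
            printf("seq %d, token %d -> idx %d, pos %d\n", s, j, idx, pos[idx]);
        }
    }

    return 0;
}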
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
- const int64_t n_tokens = ubatch->n_tokens;
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
- const int64_t n_seqs = ubatch->n_seqs;
+ const uint32_t n_tokens = ubatch->n_tokens;
+ const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
+ const uint32_t n_seqs = ubatch->n_seqs;
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
float * data = (float *) dst->data;
- const auto n_kv = dst->ne[0];
+ const int64_t n_kv = dst->ne[0];
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
// xxxxx-----
// xxxxx-----
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
- for (int h = 0; h < 1; ++h) {
- for (int s = 0; s < n_seqs; ++s) {
+ for (uint32_t h = 0; h < 1; ++h) {
+ for (uint32_t s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[s][0];
- for (int j = 0; j < n_seq_tokens; ++j) {
- const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
+ for (uint32_t j = 0; j < n_seq_tokens; ++j) {
+ const uint32_t idx = s*n_seq_tokens + j;
+
+ const llama_pos p1 = ubatch->pos[idx];
for (uint32_t i = 0; i < n_kv; ++i) {
float f = 0.0f;
f = -INFINITY;
}
- data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+ data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
}
}
}
// mask padded tokens
if (data) {
- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
- for (uint32_t j = 0; j < n_kv; ++j) {
- data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+ for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
+ for (uint32_t i = 0; i < n_kv; ++i) {
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
}
}
}
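set_input_kq_mask fills a row-major buffer with one row of n_kv floats per ubatch token, and the rows between n_tokens and GGML_PAD(n_tokens, GGML_KQ_MASK_PAD) are fully masked. A rough, ggml-free sketch of that layout with toy sizes and a toy visibility rule (the real check involves sequence membership, positions, causality and SWA):

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // toy sizes; n_padded stands in for GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)
    const int n_kv     = 8;
    const int n_tokens = 3;
    const int n_padded = 4;

    std::vector<float> data((size_t) n_kv*n_padded, 0.0f);

    // one row per real ubatch token: 0.0f where the KV cell is visible, -INFINITY otherwise
    for (int idx = 0; idx < n_tokens; ++idx) {
        for (int i = 0; i < n_kv; ++i) {
            const bool visible = i <= idx + 4; // toy causal-style rule, not the real check
            data[idx*n_kv + i] = visible ? 0.0f : -INFINITY;
        }
    }

    // padded token rows are fully masked, mirroring the final loop in set_input_kq_mask
    for (int idx = n_tokens; idx < n_padded; ++idx) {
        for (int i = 0; i < n_kv; ++i) {
            data[idx*n_kv + i] = -INFINITY;
        }
    }

    printf("mask buffer holds %zu floats\n", data.size());

    return 0;
}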
seq_rm(dest_seq_id, -1, -1);
llama_sbatch sbatch;
-    llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
-    batch.n_tokens = cell_count;
+    llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+    ubatch.n_tokens = cell_count;
+    ubatch.n_seq_tokens = cell_count;
+    ubatch.n_seqs = 1;
for (uint32_t i = 0; i < cell_count; ++i) {
llama_pos pos;
io.read_to(&seq_id, sizeof(seq_id));
}
- batch.pos[i] = pos;
- batch.n_seq_id[i] = n_seq_id;
- batch.seq_id[i] = &dest_seq_id;
+ ubatch.pos[i] = pos;
+ ubatch.n_seq_id[i] = n_seq_id;
+ ubatch.seq_id[i] = &dest_seq_id;
}
- const auto head_cur = find_slot(batch);
+ const auto head_cur = find_slot(ubatch);
if (head_cur < 0) {
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
return false;
}
- apply_ubatch(head_cur, batch);
+ apply_ubatch(head_cur, ubatch);
// keep the head at the old position because we will read the KV data into it in state_read_data()
head = head_cur;
// DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
// Assume that this is one contiguous block of cells
GGML_ASSERT(head_cur + cell_count <= cells.size());
- GGML_ASSERT(cells.pos_get(head_cur) == batch.pos[0]);
- GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]);
+ GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
+ GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
} else {