const int batch_size = std::min(end - batch_start, n_batch);
//fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
+ // TODO: use llama_batch.logits instead of relying on logits_all == true
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
//fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
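+        // number of tokens in this batch for which logits will be extracted (those with pos >= first)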
+ int n_outputs = 0;
+
batch.n_tokens = 0;
for (int seq = 0; seq < n_seq_batch; seq++) {
int seq_start = batch_start + seq*n_ctx;
for (int k = 0; k < batch_size; ++k) {
const int idx = seq*n_ctx + k;
- batch.token[idx] = tokens[seq_start + k];
- batch.pos[idx] = j*n_batch + k;
- batch.n_seq_id[idx] = 1;
- batch.seq_id[idx][0] = seq;
- batch.logits[idx] = batch.pos[idx] >= first ? 1 : 0;
+ batch.token [idx] = tokens[seq_start + k];
+ batch.pos [idx] = j*n_batch + k;
+ batch.n_seq_id[idx] = 1;
+ batch.seq_id [idx][0] = seq;
+ batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0;
+
+ n_outputs += batch.logits[idx] != 0;
}
batch.n_tokens += batch_size;
return {tokens, -1, logit_history, prob_history};
}
- if (num_batches > 1) {
+ if (num_batches > 1 && n_outputs > 0) {
const auto * batch_logits = llama_get_logits(ctx);
- logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+ logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
}
}
}
for (int seq = 0; seq < n_seq_batch; seq++) {
- const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx);
+ const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
+
llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
if (!params.logits_file.empty()) {
- process_logits(logits_stream, n_vocab, all_logits + first*n_vocab,
+ process_logits(logits_stream, n_vocab, all_logits,
tokens_data, n_ctx - 1 - first,
workers, log_probs, nll, nll2);
} else {
- process_logits(n_vocab, all_logits + first*n_vocab,
+ process_logits(n_vocab, all_logits,
tokens_data, n_ctx - 1 - first,
workers, nll, nll2,
logit_history.data() + start + seq*n_ctx + first,
}
static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
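+    // running count of logit rows already copied into batch_logits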
+ int prev_outputs = 0;
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
return false;
}
- memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
+ int n_outputs = 0;
+ for (int i = 0; i < n_tokens; ++i) {
+ n_outputs += batch_view.logits[i] != 0;
+ }
+
+ memcpy(batch_logits.data() + prev_outputs*n_vocab, llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float));
+
+ prev_outputs += n_outputs;
}
return true;
size_t ending_logprob_count[4];
double ending_logprob[4];
- size_t i_batch; // starting index in the llama_batch
+ size_t i_logits; // starting index of logits in the llama_batch
size_t common_prefix; // max number of initial tokens that are the same in all sentences
size_t required_tokens; // needed number of tokens to evaluate all 4 endings
std::vector<llama_token> seq_tokens[4];
const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
- llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
+ llama_batch batch = llama_batch_init(n_ctx, 0, 4);
std::vector<float> tok_logits(n_vocab);
+ // TODO: this could be made smaller; it's currently the worst-case size
std::vector<float> batch_logits(n_vocab*n_ctx);
std::vector<std::pair<size_t, llama_token>> eval_pairs;
int n_cur = 0;
size_t i1 = i0;
- size_t i_batch = 0; // this tells us where in `llama_batch` we are currently
+ size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
llama_batch_clear(batch);
// batch as much tasks as possible into the available context
- // each task has 4 unique seuqnce ids - one for each ending
+ // each task has 4 unique sequence ids - one for each ending
// the common prefix is shared among the 4 sequences to save tokens
// we extract logits only from the last common token and from all ending tokens of each sequence
while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
auto & hs_cur = hs_data[i1];
+ int n_logits = 0;
const int s0 = 4*(i1 - i0);
if (s0 + 4 > max_seq) {
}
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
- llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
+ llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
}
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+ n_logits += 1;
for (int s = 0; s < 4; ++s) {
- for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) {
- llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, true);
+ const size_t seq_tokens_size = hs_cur.seq_tokens[s].size();
+ // TODO: don't evaluate the last token of each sequence
+ for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
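+                // the last token of each ending doesn't need logits (there is nothing left to predict after it)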
+ const bool needs_logits = i < seq_tokens_size - 1;
+ llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+ n_logits += needs_logits;
}
}
- hs_cur.i_batch = i_batch;
- i_batch += hs_cur.required_tokens;
+ hs_cur.i_logits = i_logits;
+ i_logits += n_logits;
n_cur += hs_data[i1].required_tokens;
if (++i1 == hs_task_count) {
eval_pairs.clear();
for (size_t i = i0; i < i1; ++i) {
auto & hs_cur = hs_data[i];
- size_t li = hs_cur.common_prefix;
+ size_t li = 1; // skip the last logit of the common prefix (computed separately below)
for (int s = 0; s < 4; ++s) {
for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
- eval_pairs.emplace_back(hs_cur.i_batch + li++, hs_cur.seq_tokens[s][j + 1]);
+ eval_pairs.emplace_back(hs_cur.i_logits + li++, hs_cur.seq_tokens[s][j + 1]);
}
- ++li;
}
}
// Then we do the actual calculation
for (size_t i = i0; i < i1; ++i) {
auto & hs_cur = hs_data[i];
- std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + hs_cur.common_prefix - 1), n_vocab*sizeof(float));
+ // get the logits of the last token of the common prefix
+ std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*hs_cur.i_logits, n_vocab*sizeof(float));
const auto first_probs = softmax(tok_logits);
std::array<std::string, 2> choices;
int answer;
- size_t i_batch;
+ size_t i_logits;
size_t common_prefix;
size_t required_tokens;
size_t n_base1; // number of tokens for context + choice 1
task.common_prefix++;
}
+    // TODO: the last token of each of the sequences doesn't need to be evaluated
task.required_tokens = task.common_prefix +
task.seq_tokens[0].size() - task.common_prefix +
task.seq_tokens[1].size() - task.common_prefix;
const int max_tasks_per_batch = 128;
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
- llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
+ llama_batch batch = llama_batch_init(n_ctx, 0, 2);
std::vector<float> tok_logits(n_vocab);
+ // TODO: this could be made smaller; it's currently the worst-case size
std::vector<float> batch_logits(n_vocab*n_ctx);
std::vector<std::pair<size_t, llama_token>> eval_pairs;
int n_cur = 0;
size_t i1 = i0;
- size_t i_batch = 0;
+ size_t i_logits = 0;
llama_batch_clear(batch);
while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
+ int n_logits = 0;
const int s0 = 2*(i1 - i0);
if (s0 + 2 > max_seq) {
break;
}
for (size_t i = 0; i < data[i1].common_prefix; ++i) {
- llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1}, false);
+ llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
}
batch.logits[batch.n_tokens - 1] = true;
+ n_logits += 1;
for (int s = 0; s < 2; ++s) {
+ // TODO: end before the last token, no need to predict past the end of the sequences
for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
+ n_logits += 1;
}
}
- data[i1].i_batch = i_batch;
- i_batch += data[i1].required_tokens;
+ data[i1].i_logits = i_logits;
+ i_logits += n_logits;
n_cur += data[i1].required_tokens;
if (++i1 == data.size()) {
const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
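+        // within each task, logit index 0 is the last token of the common prefix;
+        // the tokens of each sequence past the common prefix follow in order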
const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
- size_t li = n_base1 - 1;
+ size_t li = n_base1 - task.common_prefix;
for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
- eval_pairs.emplace_back(task.i_batch + li++, task.seq_tokens[0][j+1]);
+ eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[0][j+1]);
}
const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
- li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1;
+ // FIXME: this uses the wrong first logits when not skipping the choice word
+ li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - task.common_prefix;
for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
- eval_pairs.emplace_back(task.i_batch + li++, task.seq_tokens[1][j+1]);
+ eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[1][j+1]);
}
}
compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
}
// For evaluation
- size_t i_batch; // starting index in the llama_batch
+ size_t i_logits; // starting index of logits in the llama_batch
size_t common_prefix; // max number of initial tokens that are the same in all sentences
size_t required_tokens; // needed number of tokens to evaluate all answers
std::vector<std::vector<llama_token>> seq_tokens;
std::vector<uint32_t> task_pos(n_task);
strstream.read((char *)task_pos.data(), task_pos.size()*sizeof(uint32_t));
if (strstream.fail()) {
- printf("%s: failed to raad task positions from prompt\n", __func__);
+ printf("%s: failed to read task positions from prompt\n", __func__);
return;
}
return;
}
} else {
- int n_dot = n_task/100;
+ int n_dot = std::max((int) n_task/100, 1);
int i_task = 0;
for (auto& task : tasks) {
++i_task;
int n_cur = 0;
size_t i1 = i0;
- size_t i_batch = 0; // this tells us where in `llama_batch` we are currently
+ size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
llama_batch_clear(batch);
// batch as much tasks as possible into the available context
- // each task has 4 unique seuqnce ids - one for each ending
+ // each task has 4 unique sequence ids - one for each ending
// the common prefix is shared among the 4 sequences to save tokens
// we extract logits only from the last common token and from all ending tokens of each sequence
int s0 = 0;
while (n_cur + (int) tasks[i1].required_tokens <= n_ctx) {
auto& cur_task = tasks[i1];
+ int n_logits = 0;
int num_answers = cur_task.seq_tokens.size();
if (s0 + num_answers > max_seq) {
llama_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
}
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+ n_logits += 1;
for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
- for (size_t i = cur_task.common_prefix; i < cur_task.seq_tokens[s].size(); ++i) {
- llama_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, true);
+ const size_t seq_tokens_size = cur_task.seq_tokens[s].size();
+ // TODO: don't evaluate the last token of each sequence
+ for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
+ const bool needs_logits = i < seq_tokens_size - 1;
+ llama_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+ n_logits += needs_logits;
}
}
s0 += num_answers;
- cur_task.i_batch = i_batch;
- i_batch += cur_task.required_tokens;
+ cur_task.i_logits = i_logits;
+ i_logits += n_logits;
n_cur += cur_task.required_tokens;
if (++i1 == tasks.size()) {
eval_pairs.clear();
for (size_t i = i0; i < i1; ++i) {
auto& cur_task = tasks[i];
- size_t li = cur_task.common_prefix;
+ size_t li = 1; // skip the last logit of the common prefix (computed separately below)
for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
- eval_pairs.emplace_back(cur_task.i_batch + li++, cur_task.seq_tokens[s][j + 1]);
+ eval_pairs.emplace_back(cur_task.i_logits + li++, cur_task.seq_tokens[s][j + 1]);
}
- ++li;
}
}
// Then we do the actual calculation
//}
//printf("\n common_prefix: %zu\n", cur_task.common_prefix);
- std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(cur_task.i_batch + cur_task.common_prefix - 1), n_vocab*sizeof(float));
+ // get the logits of the last token of the common prefix
+ std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*cur_task.i_logits, n_vocab*sizeof(float));
const auto first_probs = softmax(tok_logits);
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
}
+ // TODO: use llama_batch.logits instead of relying on logits_all == true
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
uint32_t n_ctx; // context size used during inference
uint32_t n_batch;
uint32_t n_ubatch;
+ uint32_t n_seq_max;
uint32_t n_threads; // number of threads to use for generation
uint32_t n_threads_batch; // number of threads to use for batch processing
// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_t buf_output = nullptr;
- // decode output (2-dimensional array: [n_tokens][n_vocab])
- size_t logits_size = 0;
- float * logits = nullptr;
+ // decode output (2-dimensional array: [n_outputs][n_vocab])
+ size_t logits_size = 0; // capacity (of floats) for logits
+ float * logits = nullptr;
+
+ std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
+    size_t  output_size = 0; // capacity (of token positions) for the output buffers
+ int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch
-#ifndef NDEBUG
- // guard against access to unset logits
- std::vector<bool> logits_valid;
-#endif
bool logits_all = false;
- // embeddings output (2-dimensional array: [n_tokens][n_embd])
+ // embeddings output (2-dimensional array: [n_outputs][n_embd])
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
- size_t embd_size = 0;
- float * embd = nullptr;
+ size_t embd_size = 0; // capacity (of floats) for embeddings
+ float * embd = nullptr;
// sequence embeddings output (map of [n_embd] vectors)
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch]
+ struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
- struct ggml_tensor * inp_KQ_pos; // F32 [kv_size]
+ struct ggml_tensor * inp_KQ_pos; // F32 [n_kv]
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
- struct ggml_tensor * inp_s_mask; // F32 [1, kv_size]
- struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch]
+ struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
+ struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
// control vectors
struct llama_control_vector cvec;
const float norm_rms_eps;
const int32_t n_tokens;
- const int32_t n_kv; // size of KV cache to consider (n_kv <= n_ctx)
+ const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size)
+ const int32_t n_outputs;
const int32_t kv_head; // index of where we store new KV data in the cache
const int32_t n_orig_ctx;
norm_rms_eps (hparams.f_norm_rms_eps),
n_tokens (batch.n_tokens),
n_kv (worst_case ? kv_self.size : kv_self.n),
+ n_outputs (worst_case ? n_tokens : lctx.n_outputs),
kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
n_orig_ctx (cparams.n_yarn_orig_ctx),
pooling_type (cparams.pooling_type),
lctx.inp_tokens = nullptr;
lctx.inp_embd = nullptr;
lctx.inp_pos = nullptr;
+ lctx.inp_out_ids = nullptr;
lctx.inp_KQ_mask = nullptr;
lctx.inp_KQ_pos = nullptr;
lctx.inp_K_shift = nullptr;
return lctx.inp_pos;
}
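+    // batch positions of the tokens for which outputs will be computed (used with ggml_get_rows in the last layer)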
+ struct ggml_tensor * build_inp_out_ids() {
+ lctx.inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
+ cb(lctx.inp_out_ids, "inp_out_ids", -1);
+ ggml_set_input(lctx.inp_out_ids);
+ return lctx.inp_out_ids;
+ }
+
struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
if (causal) {
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens);
struct ggml_cgraph * build_llama() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+ // mutable variable, needed during the last layer of the computation to skip unused tokens
+ int32_t n_tokens = this->n_tokens;
+
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ n_tokens = n_outputs;
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ attn_norm = ggml_get_rows(ctx0, attn_norm, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = cur;
// feed forward
struct ggml_cgraph * build_grok() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+ // mutable variable, needed during the last layer of the computation to skip unused tokens
+ int32_t n_tokens = this->n_tokens;
+
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ n_tokens = n_outputs;
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
// Grok
// if attn_out_norm is present then apply it before adding the input
if (model.layers[il].attn_out_norm) {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ residual = ggml_get_rows(ctx0, residual, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
cb(ffn_inp, "ffn_inp", il);
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
}
cb(cur, "kqv_out", il);
+ if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// re-add the layer input
cur = ggml_add(ctx0, cur, inpL);
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// Add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// Add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids);
+ }
+
// FF
{
ffn_output = llm_build_ffn(ctx0, attn_norm_output,
cur = attention_norm;
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ sa_out = ggml_get_rows(ctx0, sa_out, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// feed-forward network
{
cur = llm_build_ffn(ctx0, cur,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
// scale_res - scale the hidden states for residual connection
const float scale_res = scale_depth/sqrtf(float(n_layer));
cur = ggml_scale(ctx0, cur, scale_res);
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
cb(sa_out, "sa_out", il);
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ x = ggml_get_rows(ctx0, x, inp_out_ids);
+ y = ggml_get_rows(ctx0, y, inp_out_ids);
+ z = ggml_get_rows(ctx0, z, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens}
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
struct ggml_tensor * attn_out = cur;
// feed-forward network
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
}
+ if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        GGML_ASSERT(lctx.inp_out_ids && "every model that can skip unused outputs must do so");
+ const int64_t n_tokens = batch.n_tokens;
+
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
+ int32_t * data = (int32_t *) lctx.inp_out_ids->data;
+
+ if (lctx.n_outputs == n_tokens) {
+ for (int i = 0; i < n_tokens; ++i) {
+ data[i] = i;
+ }
+ } else if (batch.logits) {
+ int32_t n_outputs = 0;
+ for (int i = 0; i < n_tokens; ++i) {
+ if (batch.logits[i]) {
+ data[n_outputs++] = i;
+ }
+ }
+ // the graph needs to have been passed the correct number of outputs
+ GGML_ASSERT(lctx.n_outputs == n_outputs);
+ } else if (lctx.n_outputs == 1) {
+ // only keep last output
+ data[0] = n_tokens - 1;
+ } else {
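+                // no outputs are needed for this ubatch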
+ GGML_ASSERT(lctx.n_outputs == 0);
+ }
+ }
+
GGML_ASSERT(
+ // (!a || b) is a logical implication (a -> b)
+ // !hparams.causal_attn -> !cparams.causal_attn
(hparams.causal_attn || !cparams.causal_attn) &&
- "non-causal attention with generative models is not supported"
+ "causal attention with embedding models is not supported"
);
if (lctx.inp_KQ_mask) {
}
}
+// Make sure enough space is available for outputs.
+// Returns max number of outputs for which space was reserved.
+static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
+ const auto & cparams = lctx.cparams;
+ const auto & hparams = lctx.model.hparams;
+
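+    // reserve at least one output per sequence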
+ const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
+
+ const auto n_batch = cparams.n_batch;
+ const auto n_vocab = hparams.n_vocab;
+ const auto n_embd = hparams.n_embd;
+
+ // TODO: use a per-batch flag for logits presence instead
+ const bool has_logits = cparams.causal_attn;
+ const bool has_embd = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+
+ const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
+ const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
+
+ if (lctx.output_ids.empty()) {
+ // init, never resized afterwards
+ lctx.output_ids.resize(n_batch);
+ }
+
+ const size_t prev_size = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output) : 0;
+ const size_t new_size = (logits_size + embd_size) * sizeof(float);
+
+ // alloc only when more than the current capacity is required
+ // TODO: also consider shrinking the buffer
+ if (!lctx.buf_output || prev_size < new_size) {
+ if (lctx.buf_output) {
+#ifndef NDEBUG
+ // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
+ LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
+#endif
+ ggml_backend_buffer_free(lctx.buf_output);
+ lctx.buf_output = nullptr;
+ lctx.logits = nullptr;
+ lctx.embd = nullptr;
+ }
+
+ lctx.buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), new_size);
+ if (lctx.buf_output == nullptr) {
+ LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
+ return 0;
+ }
+ }
+
+ float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output);
+
+ lctx.logits = has_logits ? output_base : nullptr;
+ lctx.embd = has_embd ? output_base + logits_size : nullptr;
+
+ lctx.output_size = n_outputs_max;
+ lctx.logits_size = logits_size;
+ lctx.embd_size = embd_size;
+
+ // set all ids as invalid (negative)
+ std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1);
+
+ ggml_backend_buffer_clear(lctx.buf_output, 0);
+
+ lctx.n_outputs = 0;
+
+ return n_outputs_max;
+}
+
+
static void llama_graph_compute(
llama_context & lctx,
ggml_cgraph * gf,
const int64_t n_embd = hparams.n_embd;
const int64_t n_vocab = hparams.n_vocab;
-
- auto * logits_out = lctx.logits;
-
-#ifndef NDEBUG
- auto & logits_valid = lctx.logits_valid;
- logits_valid.clear();
- logits_valid.resize(n_tokens_all);
-
- memset(logits_out, 0, lctx.logits_size*sizeof(float));
-#endif
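+    // n_outputs is the total number of outputs over the whole batch,
+    // n_outputs_prev counts the outputs already extracted from earlier ubatches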
+ uint32_t n_outputs = 0;
+ uint32_t n_outputs_prev = 0;
const auto n_ubatch = cparams.n_ubatch;
std::vector<llama_seq_id *> seq_id_arr;
std::vector<std::vector<llama_seq_id>> seq_id;
+ // count outputs
+ if (batch_all.logits) {
+ for (uint32_t i = 0; i < n_tokens_all; ++i) {
+ n_outputs += batch_all.logits[i] != 0;
+ }
+ } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
+ n_outputs = n_tokens_all;
+ } else {
+ // keep last output only
+ n_outputs = 1;
+ }
+
+ // reserve output buffer
+ if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
+ return -2;
+    }
+
+ // set output mappings
+ if (batch_all.logits) {
+ int32_t i_logits = 0;
+ for (uint32_t i = 0; i < n_tokens_all; ++i) {
+ if (batch_all.logits[i]) {
+ lctx.output_ids[i] = i_logits++;
+ }
+ }
+ } else {
+ for (uint32_t i = 0; i < n_outputs; ++i) {
+ lctx.output_ids[i] = i;
+ }
+ }
+
for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
llama_batch u_batch = {
/* .all_seq_id = */ batch_all.all_seq_id,
};
+ // count the outputs in this u_batch
+ {
+ int32_t n_outputs_new = 0;
+
+ if (u_batch.logits) {
+ for (uint32_t i = 0; i < n_tokens; i++) {
+ n_outputs_new += u_batch.logits[i] != 0;
+ }
+ } else if (n_outputs == n_tokens_all) {
+ n_outputs_new = n_tokens;
+ } else {
+ // keep last output only
+ if (cur_token + n_tokens >= n_tokens_all) {
+ n_outputs_new = 1;
+ }
+ }
+
+ // needs to happen before the graph is built
+ lctx.n_outputs = n_outputs_new;
+ }
+
int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
GGML_ASSERT(n_threads > 0);
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
- if (!hparams.causal_attn) {
+ if (lctx.n_outputs == 0) {
+ // no output
+ res = nullptr;
+ embd = nullptr;
+ } else if (!hparams.causal_attn) {
res = nullptr; // do not extract logits for embedding models such as BERT
// token or sequence embeddings
embd = gf->nodes[gf->n_nodes - 1];
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
- } else {
- if (strcmp(res->name, "result_output") == 0) {
- // the token embeddings could be the second to last tensor, or the third to last tensor
- if (strcmp(embd->name, "result_norm") != 0) {
- embd = gf->nodes[gf->n_nodes - 3];
- GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
- }
- } else {
- GGML_ASSERT(false && "missing result_output tensor");
+ } else if (cparams.embeddings) {
+ // the embeddings could be in the second to last tensor, or any of the previous tensors
+ int i_embd = gf->n_nodes - 2;
+ for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) {
+ i_embd = gf->n_nodes - i;
+ if (i_embd < 0) { break; }
+ embd = gf->nodes[i_embd];
+ }
+ GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor");
+
+ // TODO: use a per-batch flag to know when to skip logits while keeping embeddings
+ if (!cparams.causal_attn) {
+ res = nullptr; // do not extract logits when not needed
+ // skip computing logits
+ // TODO: is this safe?
+ gf->n_nodes = i_embd + 1;
}
+ } else {
+ embd = nullptr; // do not extract embeddings when not needed
+ GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
}
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
//}
// extract logits
- // TODO: do not compute and extract logits if only embeddings are needed
- // update the graphs to skip "result_output" if logits are not needed
if (res) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
GGML_ASSERT(backend_res != nullptr);
- if (u_batch.logits) {
- int32_t i_first = -1;
- for (uint32_t i = 0; i < n_tokens; i++) {
- if (u_batch.logits[i] && i_first == -1) {
- i_first = (int32_t) i;
- }
- if (u_batch.logits[i] == 0 || i == n_tokens - 1) {
- if (i_first != -1) {
- int i_last = u_batch.logits[i] == 0 ? i : i + 1;
- // extract logits for the range [i_first, i_last)
- // group the requests to minimize the number of calls to the backend
- ggml_backend_tensor_get_async(backend_res, res,
- logits_out + n_vocab*(cur_token + i_first),
- i_first*n_vocab*sizeof(float),
- (i_last - i_first)*n_vocab*sizeof(float));
- i_first = -1;
- }
- }
-#ifndef NDEBUG
- logits_valid[cur_token + i] = u_batch.logits[i] != 0;;
-#endif
- }
- } else if (lctx.logits_all) {
- ggml_backend_tensor_get_async(backend_res, res, logits_out + n_vocab*cur_token, 0, n_vocab*n_tokens*sizeof(float));
-#ifndef NDEBUG
- std::fill(logits_valid.begin() + cur_token, logits_valid.begin() + cur_token + n_tokens, true);
-#endif
- } else {
- if (cur_token + n_tokens >= n_tokens_all) {
- ggml_backend_tensor_get_async(backend_res, res, logits_out, n_vocab*(n_tokens - 1)*sizeof(float), n_vocab*sizeof(float));
-#ifndef NDEBUG
- logits_valid[0] = true;
-#endif
- }
+ GGML_ASSERT(lctx.logits != nullptr);
+
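+        // the logits of this ubatch are appended after those of the previous ubatches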
+ float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
+ const int32_t n_outputs_new = lctx.n_outputs;
+
+ if (n_outputs_new) {
+ GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
+ GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
+ ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
}
}
// extract embeddings
- if (cparams.embeddings && embd) {
+ if (embd) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
GGML_ASSERT(backend_embd != nullptr);
case LLAMA_POOLING_TYPE_NONE:
{
// extract token embeddings
- auto & embd_out = lctx.embd;
-
- if (u_batch.logits) {
- //embd_out.resize(n_embd * n_tokens);
- for (uint32_t i = 0; i < n_tokens; i++) {
- if (u_batch.logits[i] == 0) {
- continue;
- }
- ggml_backend_tensor_get_async(backend_embd, embd, embd_out + n_embd*(i + cur_token), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
- }
+ GGML_ASSERT(lctx.embd != nullptr);
+ float * embd_out = lctx.embd + n_outputs_prev*n_embd;
+ const int32_t n_outputs_new = lctx.n_outputs;
+
+ if (n_outputs_new) {
+ GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
+ GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
+ ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_CLS:
} break;
}
}
+ n_outputs_prev += lctx.n_outputs;
}
// wait for the computation to finish (automatically done when obtaining the model output)
const auto & hparams = model->hparams;
auto & cparams = ctx->cparams;
- // TODO: maybe add n_seq_max here too
+ cparams.n_seq_max = std::max(1u, params.n_seq_max);
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
// graph outputs buffer
{
- // resized during inference, reserve maximum
- ctx->logits_size = hparams.n_vocab*cparams.n_batch;
- ctx->embd_size = params.embeddings ? hparams.n_embd*cparams.n_batch : 0;
-
- const size_t buf_output_size = (ctx->logits_size + ctx->embd_size)*sizeof(float);
-
- ctx->buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buf_output_size);
- if (ctx->buf_output == nullptr) {
- LLAMA_LOG_ERROR("%s: failed to allocate logits buffer\n", __func__);
+ // resized during inference when a batch uses more outputs
+ if (llama_output_reserve(*ctx, params.n_seq_max) < params.n_seq_max) {
+ LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
llama_free(ctx);
return nullptr;
}
- ggml_backend_buffer_clear(ctx->buf_output, 0);
-
-
- ctx->logits = (float *) ggml_backend_buffer_get_base(ctx->buf_output);
- if (params.embeddings) {
- ctx->embd = ctx->logits + ctx->logits_size;
- }
LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__,
ggml_backend_buffer_name(ctx->buf_output),
// Returns the *maximum* size of the state
size_t llama_get_state_size(const struct llama_context * ctx) {
+ const auto & cparams = ctx->cparams;
+ const auto & hparams = ctx->model.hparams;
+
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
// for reference, std::mt19937(1337) serializes to 6701 bytes.
const size_t s_rng_size = sizeof(size_t);
const size_t s_rng = LLAMA_MAX_RNG_STATE;
+ const size_t s_n_outputs = sizeof(size_t);
+ // assume worst case for outputs although only currently set ones are serialized
+ const size_t s_output_pos = ctx->cparams.n_batch * sizeof(int32_t);
const size_t s_logits_size = sizeof(size_t);
- // assume worst case for logits although only currently set ones are serialized
- const size_t s_logits = ctx->logits_size * sizeof(float);
+ const size_t s_logits = ctx->logits_size ? cparams.n_batch * hparams.n_vocab * sizeof(float) : 0;
const size_t s_embedding_size = sizeof(size_t);
- const size_t s_embedding = ctx->embd_size * sizeof(float);
+ const size_t s_embedding = ctx->embd_size ? cparams.n_batch * hparams.n_embd * sizeof(float) : 0;
const size_t s_kv_buf_size = sizeof(size_t);
const size_t s_kv_head = sizeof(uint32_t);
const size_t s_kv_size = sizeof(uint32_t);
const size_t s_kv_used = sizeof(uint32_t);
const size_t s_kv = ctx->kv_self.total_size();
- // TODO: assume the max is more than 1 seq_id per KV cell
- const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + sizeof(llama_seq_id);
+ const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id);
const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell;
const size_t s_total = (
+ s_rng_size
+ s_rng
+ + s_n_outputs
+ + s_output_pos
+ s_logits_size
+ s_logits
+ s_embedding_size
std::ostringstream rng_ss;
rng_ss << ctx->rng;
- const std::string & rng_str = rng_ss.str();
+ const std::string & rng_str = rng_ss.str();
const size_t rng_size = rng_str.size();
GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
data_ctx->write(rng_str.data(), rng_size);
}
- // copy logits
+ // copy outputs
{
- const size_t logits_size = ctx->logits_size;
+        // Can't use ctx->n_outputs here: it only covers the last ubatch,
+        // not the entire last batch, when n_ubatch is smaller than n_batch
+ size_t n_outputs = 0;
- data_ctx->write(&logits_size, sizeof(logits_size));
+ // copy output ids
+ {
+ std::vector<int32_t> output_pos;
- if (logits_size) {
- data_ctx->write(ctx->logits, logits_size * sizeof(float));
+ const size_t n_batch = ctx->cparams.n_batch;
+ const auto & output_ids = ctx->output_ids;
+
+ output_pos.resize(ctx->output_size);
+
+ // build a more compact representation of the output ids
+ for (size_t i = 0; i < n_batch; ++i) {
+ // map an output id to a position in the batch
+ int32_t pos = output_ids[i];
+ if (pos >= 0) {
+ if ((size_t) pos >= n_outputs) {
+ n_outputs = pos + 1;
+ }
+ GGML_ASSERT((size_t) pos < ctx->output_size);
+ output_pos[pos] = i;
+ }
+ }
+
+ data_ctx->write(&n_outputs, sizeof(n_outputs));
+
+ if (n_outputs) {
+ data_ctx->write(output_pos.data(), n_outputs * sizeof(int32_t));
+ }
}
- }
- // copy embeddings
- {
- const size_t embeddings_size = ctx->embd_size;
+ // copy logits
+ {
+ const size_t logits_size = std::min(ctx->logits_size, n_outputs * ctx->model.hparams.n_vocab);
- data_ctx->write(&embeddings_size, sizeof(embeddings_size));
+ data_ctx->write(&logits_size, sizeof(logits_size));
- if (embeddings_size) {
- data_ctx->write(ctx->embd, embeddings_size * sizeof(float));
+ if (logits_size) {
+ data_ctx->write(ctx->logits, logits_size * sizeof(float));
+ }
+ }
+
+ // copy embeddings
+ {
+ const size_t embeddings_size = std::min(ctx->embd_size, n_outputs * ctx->model.hparams.n_embd);
+
+ data_ctx->write(&embeddings_size, sizeof(embeddings_size));
+
+ if (embeddings_size) {
+ data_ctx->write(ctx->embd, embeddings_size * sizeof(float));
+ }
}
}
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
- const size_t kv_buf_size = kv_self.total_size();
+ // NOTE: kv_size and kv_buf_size are mostly used for sanity checks
const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
const uint32_t kv_size = kv_self.size;
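+    // only the first kv_head cells are written below, so report a proportionally smaller buffer size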
+ const size_t kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head;
const uint32_t kv_used = kv_self.used;
data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
data_ctx->write(&kv_used, sizeof(kv_used));
if (kv_buf_size) {
+ const size_t pre_kv_buf_size = data_ctx->get_size_written();
+
std::vector<uint8_t> tmp_buf;
for (int il = 0; il < (int) n_layer; ++il) {
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
data_ctx->write(tmp_buf.data(), tmp_buf.size());
}
}
+ GGML_ASSERT(kv_buf_size == data_ctx->get_size_written() - pre_kv_buf_size);
}
for (uint32_t i = 0; i < kv_head; ++i) {
GGML_ASSERT(!rng_ss.fail());
}
+ // set output ids
+ {
+ size_t n_outputs;
+ std::vector<int32_t> output_pos;
+
+ memcpy(&n_outputs, inp, sizeof(n_outputs)); inp += sizeof(n_outputs);
+
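+        // make sure the output buffer is large enough for the outputs being restored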
+ GGML_ASSERT(n_outputs <= llama_output_reserve(*ctx, n_outputs));
+
+ if (n_outputs) {
+ output_pos.resize(n_outputs);
+ memcpy(output_pos.data(), inp, n_outputs * sizeof(int32_t));
+ inp += n_outputs * sizeof(int32_t);
+
+ for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
+ int32_t id = output_pos[i];
+ GGML_ASSERT((uint32_t) id < ctx->cparams.n_batch);
+ ctx->output_ids[id] = i;
+ }
+ }
+ }
+
// set logits
{
size_t logits_size;
memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
- GGML_ASSERT(ctx->embd_size == embeddings_size);
+ GGML_ASSERT(ctx->embd_size >= embeddings_size);
if (embeddings_size) {
memcpy(ctx->embd, inp, embeddings_size * sizeof(float));
memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used);
+ if (kv_self.size != kv_size) {
+ // the KV cache needs to be big enough to load all the KV cells from the saved state
+ GGML_ASSERT(kv_self.size >= kv_head);
+
+ LLAMA_LOG_INFO("%s: state contains %d KV cells, was saved with kv_size=%d, but is loaded with kv_size=%d (fine, but different)\n",
+ __func__, kv_head, kv_size, kv_self.size);
+ }
+
if (kv_buf_size) {
- GGML_ASSERT(kv_self.total_size() == kv_buf_size);
+ const size_t pre_kv_buf_size = inp - src;
+
+ GGML_ASSERT(kv_self.total_size() >= kv_buf_size);
for (int il = 0; il < (int) n_layer; ++il) {
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
// v is not contiguous, copy row by row
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
- const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
+ const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_self.size);
for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size);
inp += v_row_size;
}
}
+ GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size);
}
- GGML_ASSERT(kv_self.size == kv_size);
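+    // wipe the existing cells before restoring the saved ones below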
+ llama_kv_cache_clear(ctx);
ctx->kv_self.head = kv_head;
- ctx->kv_self.size = kv_size;
ctx->kv_self.used = kv_used;
- ctx->kv_self.cells.resize(kv_size);
-
for (uint32_t i = 0; i < kv_head; ++i) {
llama_pos pos;
size_t seq_id_size;
ctx->kv_self.cells[i].seq_id.insert(seq_id);
}
}
-
- for (uint32_t i = kv_head; i < kv_size; ++i) {
- ctx->kv_self.cells[i].pos = -1;
- ctx->kv_self.cells[i].seq_id.clear();
- }
}
const size_t nread = inp - src;
}
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
- assert(ctx->logits_valid.at(i));
-
llama_synchronize(ctx);
- return ctx->logits + i*ctx->model.hparams.n_vocab;
+ try {
+ if (ctx->logits == nullptr) {
+ throw std::runtime_error("no logits");
+ }
+ if ((size_t) i >= ctx->output_ids.size()) {
+ throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+ }
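+        // map the token's position in the batch to its row in the output buffer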
+ const int32_t j = ctx->output_ids[i];
+
+ if (j < 0) {
+ throw std::runtime_error(format("batch.logits[%d] != true", i));
+ }
+ if ((size_t) j >= ctx->output_size) {
+ // This should not happen
+ throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
+ }
+
+ return ctx->logits + j*ctx->model.hparams.n_vocab;
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
+#ifndef NDEBUG
+ GGML_ASSERT(false);
+#endif
+ return nullptr;
+ }
}
float * llama_get_embeddings(struct llama_context * ctx) {
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
llama_synchronize(ctx);
- return ctx->embd + i*ctx->model.hparams.n_embd;
+ try {
+ if (ctx->embd == nullptr) {
+ throw std::runtime_error("no embeddings");
+ }
+ if ((size_t) i >= ctx->output_ids.size()) {
+ throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+ }
+ const int32_t j = ctx->output_ids[i];
+
+ if (j < 0) {
+ throw std::runtime_error(format("batch.logits[%d] != true", i));
+ }
+ if ((size_t) j >= ctx->output_size) {
+ // This should not happen
+ throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
+ }
+
+ return ctx->embd + j*ctx->model.hparams.n_embd;
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
+#ifndef NDEBUG
+ GGML_ASSERT(false);
+#endif
+ return nullptr;
+ }
}
float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {