llama_token id_last = inp.back();
// all tokens currently in the target context
- auto prompt_tgt = std::vector<llama_token>(inp.begin(), inp.end() - 1);
+ llama_tokens prompt_tgt(inp.begin(), inp.end() - 1);
+ prompt_tgt.reserve(llama_n_ctx(ctx_tgt));
int n_past = inp.size() - 1;
GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
n_past += ids.size() - 1;
- n_drafted += batch_tgt.n_tokens - 1;
+ n_drafted += draft.size(); // note: we ignore the discarded small drafts
n_accept += ids.size() - 1;
+ n_predict += ids.size();
// process the accepted tokens and update contexts
//
// this is the standard token post-processing that we normally do
// in this case, we do it for a group of accepted tokens at once
//
- {
- llama_token id;
- std::string token_str;
-
- for (size_t i = 0; i < ids.size(); ++i) {
- id = ids[i];
-
- ++n_predict;
-
- if (llama_token_is_eog(model_tgt, id)) {
- has_eos = true;
- break;
- }
-
- token_str = common_token_to_piece(ctx_tgt, id);
+ for (size_t i = 0; i < ids.size(); ++i) {
+ prompt_tgt.push_back(id_last);
- if (params.use_color && i + 1 < ids.size()) {
- LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
- } else {
- LOG("%s", token_str.c_str());
- }
- }
+ id_last = ids[i];
- if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+ if (llama_token_is_eog(model_tgt, id_last)) {
+ has_eos = true;
break;
}
- LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str());
+ const std::string token_str = common_token_to_piece(ctx_tgt, id_last);
- {
- LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
-
- llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+ if (params.use_color && i + 1 < ids.size()) {
+ LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
+ } else {
+ LOG("%s", token_str.c_str());
}
+ }
- prompt_tgt.push_back(id_last);
- prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
+ LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
+
+ {
+ LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
+
+ llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+ }
- // remember the last accepted token for the next iteration
- id_last = id;
+ if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+ break;
}
}