const int tg = n_tg[i_tg];
const int pl = n_pl[i_pl];
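
+ // required context: with a unified KV cache the shared prompt is stored only once,
+ // otherwise each of the pl sequences keeps its own copy of the pp prompt tokens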
- const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
+ const int n_ctx_req = is_pp_shared ? (params.kv_unified ? pp : pl*pp) + pl*tg : pl*(pp + tg);

if (n_ctx_req > n_kv_max) {
    continue;
}

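+ // stop the PP timer before the sequence copies: the copies and the dummy decode
+ // below are excluded from both the PP and the TG measurements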
+ const auto t_pp_end = ggml_time_us();
+
if (is_pp_shared) {
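+     // copy the processed prompt from sequence 0 to the other pl - 1 sequences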
    for (int32_t i = 1; i < pl; ++i) {
        llama_memory_seq_cp(mem, 0, i, -1, -1);
    }
- }
- const auto t_pp_end = ggml_time_us();
+
+     if (!params.kv_unified) {
+         // run one dummy token to apply the memory copy
+         common_batch_clear(batch);
+         common_batch_add(batch, get_token_rand(), pp + 0, { 0 }, true);
+
+         if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+             LOG_ERR("%s: llama_decode() failed\n", __func__);
+             return 1;
+         }
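+
+         // remove the dummy token at position pp so only the pp prompt tokens remain in sequence 0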
+         llama_memory_seq_rm(mem, 0, pp, -1);
+     }
+ }
const auto t_tg_start = ggml_time_us();