}
};
+struct ctx_state {
+ int depth = 0; // in tokens
+
+ std::vector<uint8_t> buf; // the llama_context state buffer
+};
+
static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads);
llama_model * lmodel = nullptr;
const cmd_params_instance * prev_inst = nullptr;
+ // store the llama_context state at the previous depth that we performed a test
+ // ref: https://github.com/ggml-org/llama.cpp/pull/16944#issuecomment-3478151721
+ ctx_state cstate;
+
int params_idx = 0;
auto params_count = params_instances.size();
for (const auto & inst : params_instances) {
llama_memory_clear(llama_get_memory(ctx), false);
if (t.n_depth > 0) {
- if (params.progress) {
- fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
- i + 1, params.reps);
+ bool is_cached = t.n_depth == cstate.depth;
+
+ if (is_cached) {
+ // if previously we have computed at this depth, just restore the state
+ const size_t ret = llama_state_seq_set_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
+ if (ret == 0) {
+ // if the old state is incompatible with the current context - reprocess from scratch
+ is_cached = false;
+ }
}
- bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
- if (!res) {
- fprintf(stderr, "%s: error: failed to run depth\n", __func__);
- exit(1);
+
+ if (!is_cached) {
+ if (params.progress) {
+ fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
+ i + 1, params.reps);
+ }
+ bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
+ if (!res) {
+ fprintf(stderr, "%s: error: failed to run depth\n", __func__);
+ exit(1);
+ }
+
+ // store the context state for reuse in later runs
+ cstate.depth = t.n_depth;
+ cstate.buf.resize(llama_state_seq_get_size(ctx, 0));
+ llama_state_seq_get_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
+ } else {
+ if (params.progress) {
+ fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d (cached)\n", params_idx, params_count,
+ i + 1, params.reps);
+ }
}
}