From: Georgi Gerganov
Date: Tue, 25 Mar 2025 07:19:23 +0000 (+0200)
Subject: context : fix worst-case reserve outputs (#12545)
X-Git-Tag: upstream/0.0.5028~75
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=2d77d88e70d017cd82c3f1a4517e3102e2028ac4;p=pkg%2Fggml%2Fsources%2Fllama.cpp

context : fix worst-case reserve outputs (#12545)

ggml-ci
---

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 5bec63e2..aa363df6 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -294,10 +294,7 @@ llama_context::llama_context(
         // TODO: something cleaner
         const auto n_outputs_save = n_outputs;
 
-        // max number of outputs
-        n_outputs = n_tokens;
-
-        LLAMA_LOG_DEBUG("%s: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
 
         int n_splits_pp = -1;
         int n_nodes_pp  = -1;
@@ -313,8 +310,15 @@ llama_context::llama_context(
         // reserve pp graph first so that buffers are only allocated once
         {
             llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+            // max number of outputs
+            n_outputs = ubatch_pp.n_tokens;
+
+            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
+
             auto * gf = graph_init();
             graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+
             if (!ggml_backend_sched_reserve(sched.get(), gf)) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -326,11 +330,18 @@ llama_context::llama_context(
         // reserve with tg graph to get the number of splits and nodes
         {
             llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+            n_outputs = ubatch_tg.n_tokens;
+
+            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
+
             auto * gf = graph_init();
             graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
+
             if (!ggml_backend_sched_reserve(sched.get(), gf)) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
+
             n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
             n_nodes_tg  = ggml_graph_n_nodes(gf);
         }
@@ -338,8 +349,14 @@ llama_context::llama_context(
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
             llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+            n_outputs = ubatch_pp.n_tokens;
+
+            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
+
             auto * gf = graph_init();
             graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+
             if (!ggml_backend_sched_reserve(sched.get(), gf)) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
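
Note on the pattern in the diff: before this change, n_outputs was set to the worst case (n_tokens) once, before any graph was reserved; after it, each reserve block derives n_outputs from its own ubatch (ubatch_pp or ubatch_tg) immediately before graph_build(). The following is a minimal standalone sketch of that pattern only, not llama.cpp code; the stub types and the reserve_graph() helper are illustrative assumptions standing in for graph_build() plus ggml_backend_sched_reserve().

// sketch.cpp -- illustrative only, stands apart from the real llama_context
#include <cstdint>
#include <cstdio>

struct ubatch_stub {
    uint32_t n_tokens;
    uint32_t n_seqs;
};

struct context_stub {
    uint32_t n_outputs = 0;

    // placeholder for graph_build() + ggml_backend_sched_reserve()
    bool reserve_graph(const ubatch_stub & ub) {
        // worst case: every token of this ubatch produces an output,
        // so the output count is taken from the ubatch being reserved
        n_outputs = ub.n_tokens;
        std::printf("reserving graph for n_tokens = %u, n_seqs = %u, n_outputs = %u\n",
                    ub.n_tokens, ub.n_seqs, n_outputs);
        return true; // a real implementation would return the scheduler result
    }
};

int main() {
    context_stub ctx;

    const uint32_t n_tokens = 512;
    const uint32_t n_seqs   = 1;

    // prompt-processing (pp) graph: full ubatch
    const ubatch_stub ubatch_pp = { n_tokens, n_seqs };
    // token-generation (tg) graph: a single token
    const ubatch_stub ubatch_tg = { 1, n_seqs };

    // same order as the constructor: pp first, then tg, then pp again,
    // so buffers sized for the worst case are not reallocated during inference
    ctx.reserve_graph(ubatch_pp);
    ctx.reserve_graph(ubatch_tg);
    ctx.reserve_graph(ubatch_pp);

    return 0;
}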