// TODO: something cleaner
const auto n_outputs_save = n_outputs;
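+
+ // n_outputs is clobbered by the reservations below; n_outputs_save keeps the
+ // original value so that it can be restored afterwards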
- // max number of outputs
- n_outputs = n_tokens;
-
- LLAMA_LOG_DEBUG("%s: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+ LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
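+
+ // pp = prompt processing (one worst-case ubatch of n_tokens), tg = token
+ // generation (single-token ubatch); each block below sets n_outputs to its
+ // ubatch size, i.e. the worst case where every token produces an output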
int n_splits_pp = -1;
int n_nodes_pp  = -1;

int n_splits_tg = -1;
int n_nodes_tg  = -1;
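+
+ // the scheduler reports how many splits and graph nodes each reserved graph
+ // produced; these are recorded for both the pp and tg shapes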
// reserve pp graph first so that buffers are only allocated once
{
    llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+     // max number of outputs
+     n_outputs = ubatch_pp.n_tokens;
+
+     LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
+
    auto * gf = graph_init();
    graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+
    if (!ggml_backend_sched_reserve(sched.get(), gf)) {
        throw std::runtime_error("failed to allocate compute pp buffers");
    }

    n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
    n_nodes_pp  = ggml_graph_n_nodes(gf);
}

// reserve with tg graph to get the number of splits and nodes
{
    llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+     n_outputs = ubatch_tg.n_tokens;
+
+     LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
+
    auto * gf = graph_init();
    graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
+
    if (!ggml_backend_sched_reserve(sched.get(), gf)) {
        throw std::runtime_error("failed to allocate compute tg buffers");
    }
+
    n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
    n_nodes_tg  = ggml_graph_n_nodes(gf);
}
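+
+ // at this point the allocator is planned for the tg graph (the last one
+ // reserved), which is why pp is reserved one more time below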
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
{
    llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+     n_outputs = ubatch_pp.n_tokens;
+
+     LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
+
    auto * gf = graph_init();
    graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+
    if (!ggml_backend_sched_reserve(sched.get(), gf)) {
        throw std::runtime_error("failed to allocate compute pp buffers");
    }