]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : reuse compute graphs (#14482)
authorGeorgi Gerganov <redacted>
Thu, 17 Jul 2025 16:08:33 +0000 (19:08 +0300)
committerGitHub <redacted>
Thu, 17 Jul 2025 16:08:33 +0000 (19:08 +0300)
* llama : reuse compute graphs

ggml-ci

* llama-bench : add graph reuse parameter

ggml-ci

* cont : remove the parameter and the sched resets

ggml-ci

* graph : rename update() to can_reuse()

ggml-ci

* params : remove is_same()

ggml-ci

* graph : set res->params in llm_graph_context constructor

ggml-ci

* graph : avoid set_max_nodes in llm_graph_result

ggml-ci

* kv-cache : reuse llama_context's graph result instance

ggml-ci

* context : reset the previous graph result upon memory updates

ggml-ci

* batch : llama_ubatch now carries its data instead of pointing to balloc

ggml-ci

* merge : fix build

ggml-ci

* graph : fix can_reuse() checks when flash-attention is disabled

* graph : move llm_graph_result impl in source file + debug env

ggml-ci

12 files changed:
include/llama.h
src/llama-batch.cpp
src/llama-batch.h
src/llama-context.cpp
src/llama-context.h
src/llama-graph.cpp
src/llama-graph.h
src/llama-kv-cache-unified.cpp
src/llama-kv-cache-unified.h
src/llama-memory-recurrent.cpp
src/llama-model.cpp
src/llama-model.h

index db6a5337b02a71f8b1804badbe31232a03093a0f..1c3a1cd1b4e7de305bcba2d9a7fab1bedd5ed699 100644 (file)
@@ -1394,6 +1394,7 @@ extern "C" {
 
         int32_t n_p_eval;
         int32_t n_eval;
+        int32_t n_reused; // number of times a ggml compute graph had been reused
     };
 
     struct llama_perf_sampler_data {
index eb15d2c41f1e26b5728da89630709ed6d7c0fe3d..a546063c0a7c8c64808c1e87a8db4b3f14dfab1e 100644 (file)
@@ -210,7 +210,7 @@ bool llama_batch_allocr::init(
         LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
 
         llama_ubatch ubatch {
-            /*.equal_seqs   =*/ false,
+            /*.b_equal_seqs =*/ false,
             /*.n_tokens     =*/ (uint32_t) batch.n_tokens,
             /*.n_seq_tokens =*/ (uint32_t) 1,
             /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
@@ -223,6 +223,7 @@ bool llama_batch_allocr::init(
             /*.seq_id_unq   =*/ this->seq_id_unq.data(),
             /*.seq_idx      =*/ this->seq_idx.data(),
             /*.output       =*/ batch.logits,
+            /*.data         =*/ {},
         };
 
         ubatch_print(ubatch, debug);
@@ -366,39 +367,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
     clear();
     split_reset();
 
-    ubatches.emplace_back();
+    auto udata = std::make_shared<llama_ubatch::data_t>();
 
-    auto & ubatch = ubatches.back();
-
-    ubatch.token     .resize(n_tokens);
-    ubatch.embd      .clear();
-    ubatch.pos       .resize(n_tokens);
-    ubatch.n_seq_id  .resize(n_tokens);
-    ubatch.seq_id    .resize(n_tokens);
-    ubatch.seq_id_unq.resize(0);
-    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    ubatch.output    .resize(n_tokens);
+    udata->token     .resize(n_tokens);
+    udata->embd      .clear();
+    udata->pos       .resize(n_tokens);
+    udata->n_seq_id  .resize(n_tokens);
+    udata->seq_id    .resize(n_tokens);
+    udata->seq_id_unq.resize(0);
+    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    udata->output    .resize(n_tokens);
 
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        ubatch.seq_idx[s] = s;
-        ubatch.seq_id_unq.push_back(s);
+        udata->seq_idx[s] = s;
+        udata->seq_id_unq.push_back(s);
     }
 
     llama_ubatch res {
-        /*.equal_seqs   =*/ true,
+        /*.b_equal_seqs =*/ true,
         /*.n_tokens     =*/ n_tokens,
         /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs       =*/ n_seqs,
         /*.n_seqs_unq   =*/ n_seqs,
 
-        /*.token        =*/ ubatch.token.data(),
+        /*.token        =*/ udata->token.data(),
         /*.embd         =*/ nullptr,
-        /*.pos          =*/ ubatch.pos.data(),
-        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
-        /*.seq_id       =*/ ubatch.seq_id.data(),
-        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
-        /*.seq_idx      =*/ ubatch.seq_idx.data(),
-        /*.output       =*/ ubatch.output.data(),
+        /*.pos          =*/ udata->pos.data(),
+        /*.n_seq_id     =*/ udata->n_seq_id.data(),
+        /*.seq_id       =*/ udata->seq_id.data(),
+        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
+        /*.seq_idx      =*/ udata->seq_idx.data(),
+        /*.output       =*/ udata->output.data(),
+        /*.data         =*/ std::move(udata),
     };
 
     return res;
@@ -439,8 +439,6 @@ void llama_batch_allocr::split_reset() {
 
     used.clear();
     used.resize(get_n_tokens(), false);
-
-    ubatches.clear();
 }
 
 llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
@@ -655,78 +653,77 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
 
     assert(n_tokens%n_seqs == 0);
 
-    ubatches.emplace_back();
-
-    auto & ubatch = ubatches.back();
+    auto udata = std::make_shared<llama_ubatch::data_t>();
 
     const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
 
     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
     const int64_t n_pos_all  =              (int64_t) n_tokens*n_pos_cur;
 
-    ubatch.token     .resize(n_tokens);
-    ubatch.embd      .resize(n_embd_all);
-    ubatch.pos       .resize(n_pos_all);
-    ubatch.n_seq_id  .resize(n_tokens);
-    ubatch.seq_id    .resize(n_tokens);
-    ubatch.seq_id_unq.resize(0);
-    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    ubatch.output    .resize(n_tokens);
+    udata->token     .resize(n_tokens);
+    udata->embd      .resize(n_embd_all);
+    udata->pos       .resize(n_pos_all);
+    udata->n_seq_id  .resize(n_tokens);
+    udata->seq_id    .resize(n_tokens);
+    udata->seq_id_unq.resize(0);
+    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    udata->output    .resize(n_tokens);
 
     seq_set_t seq_set_unq;
 
     for (size_t i = 0; i < idxs.size(); ++i) {
         if (batch.token) {
-            ubatch.token[i] = batch.token[idxs[i]];
+            udata->token[i] = batch.token[idxs[i]];
         }
 
         if (batch.embd) {
-            memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
+            memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
         }
 
         for (int j = 0; j < n_pos_cur; ++j) {
-            ubatch.pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
+            udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
         }
 
-        ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
-        ubatch.seq_id[i]   = batch.seq_id[idxs[i]];
-        ubatch.output[i]   = batch.logits[idxs[i]];
+        udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
+        udata->seq_id[i]   = batch.seq_id[idxs[i]];
+        udata->output[i]   = batch.logits[idxs[i]];
 
-        for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
-            seq_set_unq.set(ubatch.seq_id[i][s]);
+        for (int s = 0; s < udata->n_seq_id[i]; ++s) {
+            seq_set_unq.set(udata->seq_id[i][s]);
         }
 
-        if (ubatch.output[i]) {
+        if (udata->output[i]) {
             out_ids.push_back(idxs[i]);
         }
     }
 
     for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
-            ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
-            ubatch.seq_id_unq.push_back(s);
+            udata->seq_idx[s] = udata->seq_id_unq.size();
+            udata->seq_id_unq.push_back(s);
         }
     }
 
     llama_ubatch res {
-        /*.equal_seqs   =*/ equal_seqs,
+        /*.b_equal_seqs =*/ equal_seqs,
         /*.n_tokens     =*/ n_tokens,
         /*.n_seq_tokens =*/ n_tokens/n_seqs,
         /*.n_seqs       =*/ n_seqs,
-        /*.n_seqs_unq   =*/ (uint32_t) ubatch.seq_id_unq.size(),
-
-        /*.token        =*/ batch.token ? ubatch.token.data() : nullptr,
-        /*.embd         =*/ batch.embd ? ubatch.embd.data() : nullptr,
-        /*.pos          =*/ ubatch.pos.data(),
-        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
-        /*.seq_id       =*/ ubatch.seq_id.data(),
-        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
-        /*.seq_idx      =*/ ubatch.seq_idx.data(),
-        /*.output       =*/ ubatch.output.data(),
+        /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),
+
+        /*.token        =*/ batch.token ? udata->token.data() : nullptr,
+        /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,
+        /*.pos          =*/ udata->pos.data(),
+        /*.n_seq_id     =*/ udata->n_seq_id.data(),
+        /*.seq_id       =*/ udata->seq_id.data(),
+        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
+        /*.seq_idx      =*/ udata->seq_idx.data(),
+        /*.output       =*/ udata->output.data(),
+        /*.data         =*/ std::move(udata),
     };
 
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: added ubatch %d to split:\n", __func__, (int) ubatches.size() - 1);
+        LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__);
 
         ubatch_print(res, debug);
     }
@@ -736,7 +733,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
 
 void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s:   equal_seqs   = %d\n", __func__, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s:   equal_seqs   = %d\n", __func__, ubatch.equal_seqs());
         LLAMA_LOG_DEBUG("%s:   n_tokens     = %d\n", __func__, ubatch.n_tokens);
         LLAMA_LOG_DEBUG("%s:   n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
         LLAMA_LOG_DEBUG("%s:   n_seqs       = %d\n", __func__, ubatch.n_seqs);
index c811aef439588455ed92ab5dcbf3845075792b0a..d563adc66aaf561eb037f9650cb7c07574cac3a7 100644 (file)
@@ -8,12 +8,17 @@
 #include <vector>
 #include <set>
 #include <bitset>
+#include <memory>
 #include <unordered_map>
 
 // keep this struct lightweight
-// it points to data in `llama_batch_allocr`
 struct llama_ubatch {
-    bool equal_seqs;
+    bool equal_seqs() const {
+        return b_equal_seqs != 0;
+    }
+
+    uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
+                           //       otherwise address sanitizer complains
     // TODO: whole_seqs for embeddings?
 
     uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
@@ -34,6 +39,20 @@ struct llama_ubatch {
     llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
     int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
     int8_t       *  output;     // [n_tokens]         | i   | -
+
+    struct data_t {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id>   seq_id_unq;
+        std::vector<int32_t>        seq_idx;
+        std::vector<int8_t>         output;
+    };
+
+    // the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
+    std::shared_ptr<data_t> data;
 };
 
 // a helper for sanitizing, fulfilling and splitting a batch
@@ -137,20 +156,5 @@ private:
     // used[i] indicates if token i has already been used in a previous ubatch
     std::vector<bool> used;
 
-    // llama_ubatch points to this data:
-    struct ubatch {
-        std::vector<llama_token>    token;
-        std::vector<float>          embd;
-        std::vector<llama_pos>      pos;
-        std::vector<int32_t>        n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<llama_seq_id>   seq_id_unq;
-        std::vector<int32_t>        seq_idx;
-        std::vector<int8_t>         output;
-    };
-
-    // current splitting state:
-    std::vector<ubatch> ubatches;
-
     int debug;
 };
index 840ec9a9aaca1da466ba558cd068afd8f94d8ab2..4e1d911593decd2346384a1f507f142ef4954c82 100644 (file)
@@ -105,7 +105,7 @@ llama_context::llama_context(
 
     {
         const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
-        const bool supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
+        const bool supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : false;
 
         if (!supports_set_rows && !cparams.kv_unified) {
             LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
@@ -238,8 +238,8 @@ llama_context::llama_context(
 
         LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
 
-        // buffer used to store the computation graph and the tensor meta data
-        buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+        gf_res_prev.reset(new llm_graph_result(max_nodes));
+        gf_res_reserve.reset(new llm_graph_result(max_nodes));
 
         // TODO: move these checks to ggml_backend_sched
         // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
@@ -403,10 +403,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
     return sched.get();
 }
 
-ggml_context * llama_context::get_ctx_compute() const {
-    return ctx_compute.get();
-}
-
 uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
@@ -478,6 +474,11 @@ bool llama_context::kv_self_update(bool optimize) {
                 }
         }
 
+        // reset the previous graph result to make sure that it won't be reused
+        // TODO: change the mctx->apply() to return information if a graph reserve is needed
+        //       reset the graph result only if the memory module did reset the scheduler
+        gf_res_prev->reset();
+
         if (!mctx->apply()) {
             LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
         }
@@ -693,38 +694,59 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+llm_graph_result_i * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
     if (mctx && !mctx->apply()) {
         LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;
         return nullptr;
     }
 
-    auto * gf = graph_init();
-    if (!gf) {
-        LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
-        ret = GGML_STATUS_FAILED;
-        return nullptr;
-    }
+    auto * res = gf_res_prev.get();
+    auto * gf  = res->get_gf();
 
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
-    if (!res) {
-        LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
-        ret = GGML_STATUS_FAILED;
-        return nullptr;
-    }
+    // the new graph parameters
+    // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters
+    const auto gparams = graph_params(res, ubatch, mctx, gtype);
 
-    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+    if (res->can_reuse(gparams)) {
+        //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
 
-    if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
-        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
-        ret = GGML_STATUS_ALLOC_FAILED;
-        return nullptr;
+        n_reused++;
+    } else {
+        res->reset();
+
+        ggml_backend_sched_reset(sched.get());
+        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
+
+        //const auto t_start_us = ggml_time_us();
+
+        gf = model.build_graph(gparams);
+
+        //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+
+        if (!gf) {
+            LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+            ret = GGML_STATUS_FAILED;
+            return nullptr;
+        }
+
+        if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+            ret = GGML_STATUS_ALLOC_FAILED;
+            return nullptr;
+        }
     }
 
-    res->set_inputs(&ubatch);
+    // set the input data for the input tensors
+    {
+        //const auto t_start_us = ggml_time_us();
+
+        res->set_inputs(&ubatch);
+
+        //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+    }
 
-    const auto status = graph_compute(gf, ubatch.n_tokens > 1);
+    const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
     if (status != GGML_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
         ret = status;
@@ -785,9 +807,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     n_outputs = n_tokens;
 
-    ggml_backend_sched_reset(sched.get());
-    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
     const auto causal_attn_org = cparams.causal_attn;
 
     // always use non-causal attention for encoder graphs
@@ -796,7 +815,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     cparams.causal_attn = false;
 
     ggml_status status;
-    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
+    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
 
     cparams.causal_attn = causal_attn_org;
 
@@ -872,10 +891,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
         }
     }
 
-    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-    // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
-
     // TODO: hacky solution
     if (model.arch == LLM_ARCH_T5 && t_embd) {
         //cross.t_embd = t_embd;
@@ -1033,11 +1048,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
             n_outputs = n_outputs_new;
         }
 
-        ggml_backend_sched_reset(sched.get());
-        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
         ggml_status status;
-        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1218,10 +1230,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();
 
-    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-    // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
-
     return 0;
 }
 
@@ -1303,20 +1311,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
 // graph
 //
 
-int32_t llama_context::graph_max_nodes() const {
-    return std::max<int32_t>(65536, 5*model.n_tensors());
+uint32_t llama_context::graph_max_nodes() const {
+    return std::max<uint32_t>(65536u, 5u*model.n_tensors());
 }
 
-ggml_cgraph * llama_context::graph_init() {
-    ggml_init_params params = {
-        /*.mem_size   =*/ buf_compute_meta.size(),
-        /*.mem_buffer =*/ buf_compute_meta.data(),
-        /*.no_alloc   =*/ true,
-    };
-
-    ctx_compute.reset(ggml_init(params));
-
-    return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
+llm_graph_result * llama_context::get_gf_res_reserve() const {
+    return static_cast<llm_graph_result *>(gf_res_reserve.get());
 }
 
 ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
@@ -1329,6 +1329,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
         LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
     }
 
+    ggml_backend_sched_reset(sched.get());
+
+    // when the scheduler is reset, we cannnot reuse the old graph, so we reset the previous graph result to prevent that
+    gf_res_prev->reset();
+
     // store the n_outputs as it is, and restore it afterwards
     // TODO: not sure if needed, might simplify in the future by removing this
     const auto save_n_outputs = this->n_outputs;
@@ -1338,17 +1343,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
     llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
 
-    auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
+    auto * res = gf_res_reserve.get();
 
-    this->n_outputs = save_n_outputs;
+    const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
 
-    if (!res) {
-        LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
-        return nullptr;
-    }
+    res->reset();
 
-    ggml_backend_sched_reset(sched.get());
+    auto * gf = model.build_graph(gparams);
+
+    this->n_outputs = save_n_outputs;
 
     // initialize scheduler with the specified graph
     if (!ggml_backend_sched_reserve(sched.get(), gf)) {
@@ -1359,28 +1362,27 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     return gf;
 }
 
-llm_graph_result_ptr llama_context::graph_build(
-                      ggml_context * ctx,
-                       ggml_cgraph * gf,
-                const llama_ubatch & ubatch,
-                    llm_graph_type   gtype,
-      const llama_memory_context_i * mctx) {
-    return model.build_graph(
-            {
-                /*.ctx         =*/ ctx,
-                /*.arch        =*/ model.arch,
-                /*.hparams     =*/ model.hparams,
-                /*.cparams     =*/ cparams,
-                /*.ubatch      =*/ ubatch,
-                /*.sched       =*/ sched.get(),
-                /*.backend_cpu =*/ backend_cpu,
-                /*.cvec        =*/ &cvec,
-                /*.loras       =*/ &loras,
-                /*.mctx        =*/ mctx,
-                /*.cross       =*/ &cross,
-                /*.n_outputs   =*/ n_outputs,
-                /*.cb          =*/ graph_get_cb(),
-            }, gf, gtype);
+llm_graph_params llama_context::graph_params(
+                      llm_graph_result_i * res,
+                      const llama_ubatch & ubatch,
+            const llama_memory_context_i * mctx,
+            llm_graph_type   gtype) const {
+    return {
+        /*.arch        =*/ model.arch,
+        /*.hparams     =*/ model.hparams,
+        /*.cparams     =*/ cparams,
+        /*.ubatch      =*/ ubatch,
+        /*.gtype       =*/ gtype,
+        /*.sched       =*/ sched.get(),
+        /*.backend_cpu =*/ backend_cpu,
+        /*.cvec        =*/ &cvec,
+        /*.loras       =*/ &loras,
+        /*.mctx        =*/ mctx,
+        /*.cross       =*/ &cross,
+        /*.n_outputs   =*/ n_outputs,
+        /*.cb          =*/ graph_get_cb(),
+        /*.res         =*/ res,
+    };
 }
 
 ggml_status llama_context::graph_compute(
@@ -1958,6 +1960,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
     data.t_eval_ms   = 1e-3 * t_eval_us;
     data.n_p_eval    = std::max(1, n_p_eval);
     data.n_eval      = std::max(1, n_eval);
+    data.n_reused    = std::max(0, n_reused);
 
     return data;
 }
@@ -1966,6 +1969,7 @@ void llama_context::perf_reset() {
     t_start_us  = ggml_time_us();
     t_eval_us   = n_eval = 0;
     t_p_eval_us = n_p_eval = 0;
+    n_reused    = 0;
 }
 
 //
@@ -2092,8 +2096,13 @@ void llama_context::opt_epoch_iter(
                 break;
             }
 
-            auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
+            auto * res = gf_res_prev.get();
+
+            const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
+
+            res->reset();
+
+            auto * gf = model.build_graph(gparams);
 
             struct ggml_context * ctx_compute_opt;
             {
@@ -2836,6 +2845,7 @@ void llama_perf_context_print(const llama_context * ctx) {
     LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
             __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
     LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
+    LLAMA_LOG_INFO("%s:    graphs reused = %10d\n", __func__, data.n_reused);
 }
 
 void llama_perf_context_reset(llama_context * ctx) {
index 9ce05715a8c0306312bd03f516ae38f1f79cb2f6..fd480af6ec8754e89cf1435741754992dc346dec 100644 (file)
@@ -35,8 +35,6 @@ struct llama_context {
 
     ggml_backend_sched_t get_sched() const;
 
-    ggml_context * get_ctx_compute() const;
-
     uint32_t n_ctx()         const;
     uint32_t n_ctx_per_seq() const;
     uint32_t n_batch()       const;
@@ -96,7 +94,7 @@ struct llama_context {
     // if memory_context is provided, it will be applied first to the context's memory
     // ret contains the status of the graph computation
     // returns nullptr only if ret != GGML_STATUS_SUCCESS
-    llm_graph_result_ptr process_ubatch(
+    llm_graph_result_i * process_ubatch(
                 const llama_ubatch & ubatch,
                     llm_graph_type   gtype,
             llama_memory_context_i * mctx,
@@ -188,10 +186,10 @@ private:
     //
 
 public:
-    int32_t graph_max_nodes() const;
+    uint32_t graph_max_nodes() const;
 
-    // zero-out inputs and create the ctx_compute for the compute graph
-    ggml_cgraph * graph_init();
+    // can reuse the llm_graph_result instance of the context (for example to update a memory module)
+    llm_graph_result * get_gf_res_reserve() const;
 
     // returns the result of ggml_backend_sched_graph_compute_async execution
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);
@@ -200,12 +198,11 @@ public:
     ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
 
 private:
-    llm_graph_result_ptr graph_build(
-                      ggml_context * ctx,
-                       ggml_cgraph * gf,
-                const llama_ubatch & ubatch,
-                    llm_graph_type   gtype,
-      const llama_memory_context_i * mctx);
+    llm_graph_params graph_params(
+                      llm_graph_result_i * res,
+                      const llama_ubatch & ubatch,
+            const llama_memory_context_i * mctx,
+                          llm_graph_type   gtype) const;
 
     llm_graph_cb graph_get_cb() const;
 
@@ -258,8 +255,6 @@ private:
     ggml_backend_t backend_cpu = nullptr;
     std::vector<ggml_backend_ptr> backends;
 
-    ggml_context_ptr ctx_compute;
-
     // training
     ggml_opt_context_t opt_ctx = nullptr;
 
@@ -275,8 +270,8 @@ private:
     std::vector<ggml_backend_t>             backend_ptrs;
     std::vector<ggml_backend_buffer_type_t> backend_buft;
 
-    // memory buffers used to evaluate the model
-    std::vector<uint8_t> buf_compute_meta;
+    llm_graph_result_ptr gf_res_prev;
+    llm_graph_result_ptr gf_res_reserve;
 
     // host buffer for the model output (logits and embeddings)
     ggml_backend_buffer_ptr buf_output;
@@ -294,4 +289,6 @@ private:
 
     mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
     mutable int32_t n_eval   = 0; // number of eval calls
+
+    mutable int32_t n_reused = 0; // number of times the previous graph was reused
 };
index 1a6355e85d11ea651255c8b46ca3afa872e64249..f47538ef0737a60e17120fa17971f67d12c301ab 100644 (file)
@@ -28,6 +28,15 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
+    bool res = true;
+
+    res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
+    res &= (!embd   && !params.ubatch.embd)  || (embd   &&   embd->ne[0] == params.ubatch.n_tokens);
+
+    return res;
+}
+
 void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && pos) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -50,6 +59,14 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
+    bool res = true;
+
+    res &= pos->ne[0] == params.ubatch.n_tokens;
+
+    return res;
+}
+
 void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && attn_scale) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -71,7 +88,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
         const int64_t n_tokens = ubatch->n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
-        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+        GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
 
         int32_t * data = (int32_t *) pos_bucket->data;
 
@@ -118,6 +135,14 @@ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
+    bool res = true;
+
+    res &= n_outputs == params.n_outputs;
+
+    return res;
+}
+
 void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens     = ubatch->n_tokens;
@@ -287,6 +312,24 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
 
+bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
+  //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= self_kq_mask->ne[0] == mctx->get_n_kv();
+    res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    res &= mctx->get_supports_set_rows(); // TODO: tmp
+
+    return res;
+}
+
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
@@ -299,6 +342,30 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
     mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }
 
+bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
+  //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
+  //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
+    res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
+    res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp
+
+    return res;
+}
+
 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     GGML_ASSERT(cross_kq_mask);
 
@@ -306,7 +373,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     const int64_t n_tokens = ubatch->n_tokens;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
-    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
 
     float * data = (float *) cross_kq_mask->data;
 
@@ -340,6 +407,83 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     inp_rs->set_input(ubatch);
 }
 
+//
+// llm_graph_result
+//
+
+llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
+    reset();
+
+    const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
+    debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
+}
+
+int64_t llm_graph_result::get_max_nodes() const {
+    return max_nodes;
+}
+
+void llm_graph_result::reset() {
+    t_tokens      = nullptr;
+    t_logits      = nullptr;
+    t_embd        = nullptr;
+    t_embd_pooled = nullptr;
+
+    inputs.clear();
+
+    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+
+    ggml_init_params params = {
+        /*.mem_size   =*/ buf_compute_meta.size(),
+        /*.mem_buffer =*/ buf_compute_meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    ctx_compute.reset(ggml_init(params));
+
+    gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
+}
+
+void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
+    for (auto & input : inputs) {
+        input->set_input(ubatch);
+    }
+}
+
+bool llm_graph_result::can_reuse(const llm_graph_params & params) {
+    if (!this->params.allow_reuse(params)) {
+        if (debug > 1) {
+            LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
+        }
+
+        return false;
+    }
+
+    if (debug > 1) {
+        LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
+    }
+
+    bool res = true;
+
+    for (auto & input : inputs) {
+        const bool cur = input->can_reuse(params);
+
+        LLAMA_LOG_DEBUG("  %s: can_reuse = %d\n", "placeholder", cur);
+
+        res = res && cur;
+    }
+
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
+    }
+
+    return res;
+}
+
+llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
+    inputs.emplace_back(std::move(input));
+    return inputs.back().get();
+}
+
 //
 // llm_graph_context
 //
@@ -374,7 +518,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_ctx_orig       (cparams.n_ctx_orig_yarn),
     pooling_type     (cparams.pooling_type),
     rope_type        (hparams.rope_type),
-    ctx0             (params.ctx),
     sched            (params.sched),
     backend_cpu      (params.backend_cpu),
     cvec             (params.cvec),
@@ -382,7 +525,9 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     mctx             (params.mctx),
     cross            (params.cross),
     cb_func          (params.cb),
-    res              (std::make_unique<llm_graph_result>()) {
+    res              (static_cast<llm_graph_result *>(params.res)),
+    ctx0             (res->get_ctx()) {
+        res->params = params;
     }
 
 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
@@ -1127,8 +1272,8 @@ ggml_tensor * llm_graph_context::build_attn(
     const auto & kq_mask = inp->get_kq_mask();
 
     // [TAG_NO_CACHE_PAD]
-    // TODO: if ubatch.equal_seqs == true, we can split the three tensors below into ubatch.n_seqs_unq streams
-    assert(ubatch.equal_seqs == false);
+    // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
+    assert(!ubatch.equal_seqs());
 
     ggml_tensor * q = q_cur;
     ggml_tensor * k = k_cur;
index 84a5b0b3f9c4033be8943301944870f814b956cf..42e636e0e3f6c31352d287ecf9057ae565d2073a 100644 (file)
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "llama-arch.h"
+#include "llama-batch.h"
 #include "llama-hparams.h"
 #include "llama-adapter.h"
 
@@ -14,7 +15,6 @@ struct ggml_cgraph;
 struct ggml_context;
 struct ggml_tensor;
 
-struct llama_ubatch;
 struct llama_cparams;
 
 struct llama_memory_context_i;
@@ -69,6 +69,8 @@ struct llama_cross {
     std::vector<std::set<llama_seq_id>> seq_ids_enc;
 };
 
+struct llm_graph_params;
+
 //
 // llm_graph_input
 //
@@ -78,11 +80,19 @@ public:
     virtual ~llm_graph_input_i() = default;
 
     virtual void set_input(const llama_ubatch * ubatch) = 0;
+
+    // return true if the resulting input tensors using the provided graph parameters would be
+    //   the same as the previous input tensors that we have currently stored in the object
+    virtual bool can_reuse(const llm_graph_params & params) {
+        // returning false here by default will prevent from reusing the graph if the check
+        //   for the input type has not been implemented yet
+        GGML_UNUSED(params);
+        return false;
+    }
 };
 
 using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
 
-
 class llm_graph_input_embd : public llm_graph_input_i {
 public:
     llm_graph_input_embd()          = default;
@@ -90,6 +100,8 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    bool can_reuse(const llm_graph_params & params) override;
+
     ggml_tensor * tokens = nullptr; // I32 [n_batch]
     ggml_tensor * embd   = nullptr; // F32 [n_embd, n_batch]
 };
@@ -101,6 +113,8 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    bool can_reuse(const llm_graph_params & params) override;
+
     ggml_tensor * pos = nullptr; // I32 [n_batch]
 
     const uint32_t n_pos_per_embd = 1;
@@ -154,17 +168,19 @@ public:
     llm_graph_input_out_ids(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
+            uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
     virtual ~llm_graph_input_out_ids() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    bool can_reuse(const llm_graph_params & params) override;
+
     ggml_tensor * out_ids; // I32 [n_outputs]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const int32_t n_outputs;
+    const uint32_t n_outputs;
 };
 
 class llm_graph_input_mean : public llm_graph_input_i {
@@ -249,6 +265,8 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    bool can_reuse(const llm_graph_params & params) override;
+
     ggml_tensor * get_k_idxs() const { return self_k_idxs; }
     ggml_tensor * get_v_idxs() const { return self_v_idxs; }
 
@@ -280,6 +298,8 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    bool can_reuse(const llm_graph_params & params) override;
+
     ggml_tensor * get_k_idxs()     const { return self_k_idxs; }
     ggml_tensor * get_v_idxs()     const { return self_v_idxs; }
     ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
@@ -351,40 +371,127 @@ public:
 // along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
 //   these are used by the llama_context to extact the relevant data, based on the compute parameters
 
+// TODO: this interface seems redundant - remove it
 class llm_graph_result_i {
 public:
     virtual ~llm_graph_result_i() = default;
 
-    virtual ggml_tensor * get_tokens()      = 0;
-    virtual ggml_tensor * get_logits()      = 0;
-    virtual ggml_tensor * get_embd()        = 0;
-    virtual ggml_tensor * get_embd_pooled() = 0;
+    virtual ggml_tensor * get_tokens()      const = 0;
+    virtual ggml_tensor * get_logits()      const = 0;
+    virtual ggml_tensor * get_embd()        const = 0;
+    virtual ggml_tensor * get_embd_pooled() const = 0;
+
+    virtual ggml_cgraph  * get_gf()  = 0;
+    virtual ggml_context * get_ctx() = 0;
+
+    virtual void reset() = 0;
 
     virtual void set_inputs(const llama_ubatch * ubatch) = 0;
+
+    virtual bool can_reuse(const llm_graph_params & params) = 0;
 };
 
 using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
 
+// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
+using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
+
+struct llm_graph_params {
+    llm_arch arch = LLM_ARCH_UNKNOWN;
+
+    llama_hparams hparams;
+    llama_cparams cparams;
+
+    llama_ubatch ubatch; // note: intentionally make a copy
+
+    llm_graph_type gtype;
+
+    ggml_backend_sched_t sched;
+    ggml_backend_t backend_cpu;
+
+    const llama_adapter_cvec     * cvec;
+    const llama_adapter_loras    * loras;
+    const llama_memory_context_i * mctx;
+    const llama_cross            * cross;
+
+    uint32_t n_outputs;
+
+    llm_graph_cb cb;
+
+    // TODO: temporary
+    llm_graph_result_i * res;
+
+    // return true if the "other" params would result in a graph with the same topology as with the current params
+    //   having the same topology allows us to reuse the graph in some cases
+    bool allow_reuse(const llm_graph_params & other) const {
+        // first check the ubatch
+        bool can_reuse_ubatch =
+            ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
+            ubatch.n_tokens     == other.ubatch.n_tokens &&
+            ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
+            ubatch.n_seqs       == other.ubatch.n_seqs &&
+            ubatch.n_seqs_unq   == other.ubatch.n_seqs_unq &&
+            (
+                (!ubatch.token && !other.ubatch.token) ||
+                (!ubatch.embd  && !other.ubatch.embd)
+            );
+
+        if (can_reuse_ubatch && !ubatch.equal_seqs()) {
+            if (!ubatch.data) {
+                // if the old ubatch does not own it's data, then we cannot guarantee that it is still alive, and
+                //   therefore we cannot perform the sequence id check. normally should never happen
+                can_reuse_ubatch = false;
+            } else {
+                for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                    can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
+                }
+            }
+        }
+
+        if (!can_reuse_ubatch) {
+            return false;
+        }
+
+        return
+            cparams.embeddings  == other.cparams.embeddings  &&
+            cparams.causal_attn == other.cparams.causal_attn &&
+            arch      == other.arch  &&
+            gtype     == other.gtype &&
+            cvec      == other.cvec  &&
+            loras     == other.loras &&
+            cross     == other.cross &&
+            n_outputs == other.n_outputs;
+    }
+};
 
 class llm_graph_result : public llm_graph_result_i {
 public:
+    llm_graph_result(int64_t max_nodes);
+
     virtual ~llm_graph_result() = default;
 
-    ggml_tensor * get_tokens()      override { return t_tokens; }
-    ggml_tensor * get_logits()      override { return t_logits; }
-    ggml_tensor * get_embd()        override { return t_embd; }
-    ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
+    ggml_tensor * get_tokens()      const override { return t_tokens; }
+    ggml_tensor * get_logits()      const override { return t_logits; }
+    ggml_tensor * get_embd()        const override { return t_embd; }
+    ggml_tensor * get_embd_pooled() const override { return t_embd_pooled; }
 
-    void set_inputs(const llama_ubatch * ubatch) override {
-        for (auto & input : inputs) {
-            input->set_input(ubatch);
-        }
-    }
+    ggml_cgraph  * get_gf()  override { return gf; }
+    ggml_context * get_ctx() override { return ctx_compute.get(); }
 
-    llm_graph_input_i * add_input(llm_graph_input_ptr input) {
-        inputs.emplace_back(std::move(input));
-        return inputs.back().get();
-    }
+    int64_t get_max_nodes() const;
+
+    void reset() override;
+
+    void set_inputs(const llama_ubatch * ubatch) override;
+
+    // try to update the existing graph result using the new graph parameters in order to reuse it
+    // this can only be done if we determine that the resulting graph using the new graph parameters
+    //   would be identical to the existing graph. in that case, we simply have to update the memory
+    //   contexts of the input tensors of the graph and we can reuse it for another computation
+    // return true if the graph was updated and can be reused
+    bool can_reuse(const llm_graph_params & params) override;
+
+    llm_graph_input_i * add_input(llm_graph_input_ptr input);
 
     // important graph nodes
     ggml_tensor * t_tokens      = nullptr;
@@ -393,37 +500,29 @@ public:
     ggml_tensor * t_embd_pooled = nullptr;
 
     std::vector<llm_graph_input_ptr> inputs;
-};
-
-//
-// llm_graph_context
-//
 
-// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
-using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
-
-struct llm_graph_params {
-    ggml_context * ctx;
-
-    const llm_arch arch;
+    ggml_context_ptr ctx_compute;
 
-    const llama_hparams & hparams;
-    const llama_cparams & cparams;
-    const llama_ubatch  & ubatch;
+    // memory buffers used to evaluate the model
+    std::vector<uint8_t> buf_compute_meta;
 
-    ggml_backend_sched_t sched;
-    ggml_backend_t backend_cpu;
+    ggml_cgraph * gf;
 
-    const llama_adapter_cvec     * cvec;
-    const llama_adapter_loras    * loras;
-    const llama_memory_context_i * mctx;
-    const llama_cross            * cross;
+    int64_t max_nodes;
 
-    uint32_t n_outputs;
+    // keep a copy of the previous graph parameters
+    // we will use this to determine whether the graph can be reused by comparing them with the new parameters
+    // note: these are updated after constructing the new graph
+    llm_graph_params params;
 
-    const llm_graph_cb & cb;
+    // env: LLAMA_GRAPH_RESULT_DEBUG
+    int debug = 0;
 };
 
+//
+// llm_graph_context
+//
+
 // used in build_rs to properly order writes and avoid unnecessary copies
 using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
 
@@ -463,8 +562,6 @@ struct llm_graph_context {
     const enum llama_pooling_type pooling_type;
     const enum llama_rope_type    rope_type;
 
-    ggml_context * ctx0 = nullptr;
-
     ggml_backend_sched_t sched;
 
     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
@@ -476,7 +573,9 @@ struct llm_graph_context {
 
     const llm_graph_cb & cb_func;
 
-    std::unique_ptr<llm_graph_result> res;
+    llm_graph_result * res;
+
+    ggml_context * ctx0 = nullptr;
 
     llm_graph_context(const llm_graph_params & params);
     virtual ~llm_graph_context() = default;
index baaa1d32dffb55a2bd83c7f5ffce872e39590a3b..98c01ea7ad15d42f0a501d7c5683a72a8a6a4255 100644 (file)
@@ -193,7 +193,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
 
     const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
-    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
+    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : 0;
 
     if (!supports_set_rows) {
         // ref: https://github.com/ggml-org/llama.cpp/pull/14363
@@ -656,14 +656,11 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
         if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
             ggml_backend_sched_reset(sched);
 
-            auto * gf = lctx->graph_init();
+            auto * res = lctx->get_gf_res_reserve();
 
-            auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
-            if (!res) {
-                LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
-                return updated;
-            }
+            res->reset();
 
+            auto * gf = build_graph_shift(res, lctx);
             if (!ggml_backend_sched_alloc_graph(sched, gf)) {
                 LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
                 return updated;
@@ -713,14 +710,11 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
 
         ggml_backend_sched_reset(sched);
 
-        auto * gf = lctx->graph_init();
+        auto * res = lctx->get_gf_res_reserve();
 
-        auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
-        if (!res) {
-            LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
-            return updated;
-        }
+        res->reset();
 
+        auto * gf = build_graph_defrag(res, lctx, dinfo);
         if (!ggml_backend_sched_alloc_graph(sched, gf)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
             return updated;
@@ -1035,6 +1029,10 @@ uint32_t llama_kv_cache_unified::get_n_kv() const {
     return result;
 }
 
+bool llama_kv_cache_unified::get_supports_set_rows() const {
+    return supports_set_rows;
+}
+
 ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
@@ -1297,6 +1295,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     //      xxxxx-----
     //      xxxxx-----
     // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
+    // TODO: optimize this section
     for (uint32_t h = 0; h < 1; ++h) {
         for (uint32_t s = 0; s < n_stream; ++s) {
             for (uint32_t ii = 0; ii < n_tps; ++ii) {
@@ -1346,7 +1345,7 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
     const auto & cells = v_cells[0];
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
 
     int32_t * data = (int32_t *) dst->data;
 
@@ -1464,11 +1463,9 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
-        const llama_cparams & cparams,
-               ggml_context * ctx,
-                ggml_cgraph * gf) const {
-    auto res = std::make_unique<llm_graph_result>();
+ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
+    auto * ctx = res->get_ctx();
+    auto * gf  = res->get_gf();
 
     const auto & n_embd_head_k = hparams.n_embd_head_k;
   //const auto & n_embd_head_v = hparams.n_embd_head_v;
@@ -1478,6 +1475,8 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
     inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
     ggml_set_input(inp->k_shift);
 
+    const auto & cparams = lctx->get_cparams();
+
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
@@ -1503,15 +1502,15 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 
     res->add_input(std::move(inp));
 
-    return res;
+    return gf;
 }
 
-llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
-                const llama_cparams & cparams,
-                       ggml_context * ctx,
-                        ggml_cgraph * gf,
-                  const defrag_info & dinfo) const {
-    auto res = std::make_unique<llm_graph_result>();
+ggml_cgraph * llama_kv_cache_unified::build_graph_defrag(
+         llm_graph_result * res,
+            llama_context * lctx,
+        const defrag_info & dinfo) const {
+    auto * ctx = res->get_ctx();
+    auto * gf  = res->get_gf();
 
     GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
 
@@ -1519,6 +1518,8 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
 
     const auto & ids = dinfo.ids;
 
+    const auto & cparams = lctx->get_cparams();
+
 #if 0
     // CPU defrag
     //
@@ -1655,7 +1656,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
     //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
 #endif
 
-    return res;
+    return gf;
 }
 
 llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
@@ -2331,6 +2332,10 @@ uint32_t llama_kv_cache_unified_context::get_n_kv() const {
     return n_kv;
 }
 
+bool llama_kv_cache_unified_context::get_supports_set_rows() const {
+    return kv->get_supports_set_rows();
+}
+
 ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
     return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
 }
index 3bfda4600d8432bac2d5ffbac53bf8a4aea9f863..3e28e346c3fcf8d1ec09fa15de1b2bd6c4dcb3b4 100644 (file)
@@ -154,6 +154,9 @@ public:
 
     uint32_t get_n_kv() const;
 
+    // TODO: temporary
+    bool get_supports_set_rows() const;
+
     // get views of the current state of the cache
     ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
     ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
@@ -227,7 +230,7 @@ private:
 
     // env: LLAMA_SET_ROWS (temporary)
     // ref: https://github.com/ggml-org/llama.cpp/pull/14285
-    int supports_set_rows = false;
+    bool supports_set_rows = false;
 
     const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
 
@@ -270,15 +273,13 @@ private:
                           float   freq_base,
                           float   freq_scale) const;
 
-    llm_graph_result_ptr build_graph_shift(
-            const llama_cparams & cparams,
-                   ggml_context * ctx,
-                    ggml_cgraph * gf) const;
+    ggml_cgraph * build_graph_shift(
+               llm_graph_result * res,
+                  llama_context * lctx) const;
 
-    llm_graph_result_ptr build_graph_defrag(
-            const llama_cparams & cparams,
-                   ggml_context * ctx,
-                    ggml_cgraph * gf,
+    ggml_cgraph * build_graph_defrag(
+               llm_graph_result * res,
+                  llama_context * lctx,
               const defrag_info & dinfo) const;
 
     struct cell_ranges_t {
@@ -340,6 +341,9 @@ public:
 
     uint32_t get_n_kv() const;
 
+    // TODO: temporary
+    bool get_supports_set_rows() const;
+
     // get views of the current state of the cache
     ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
index 2c1ae67098ca49445ca0305e6b71cda23f09f0ce..1e1a7a9b31e4638bcf027d862c70a7154a0da611 100644 (file)
@@ -446,7 +446,7 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
     // A slot should be always be contiguous.
 
     // can only process batches with an equal number of new tokens in each sequence
-    GGML_ASSERT(ubatch.equal_seqs);
+    GGML_ASSERT(ubatch.equal_seqs());
 
     int32_t min = size - 1;
     int32_t max = 0;
index cdf1e424294e58bfa6a773e136f57cf285feb1aa..46899f48ffea0700535825b62a9ef03b5c0b43b1 100644 (file)
@@ -10463,7 +10463,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
         const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
         GGML_ASSERT(n_seqs != 0);
-        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs());
         GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
         ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
@@ -10598,7 +10598,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
         const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
         GGML_ASSERT(n_seqs != 0);
-        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs());
         GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
         ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
@@ -15870,7 +15870,7 @@ private:
         const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
         GGML_ASSERT(n_seqs != 0);
-        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs());
         GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
         ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
@@ -16559,7 +16559,7 @@ struct llm_build_lfm2 : public llm_graph_context {
         const int64_t  n_seq_tokens = ubatch.n_seq_tokens;
         const int64_t  n_seqs       = ubatch.n_seqs;
         GGML_ASSERT(n_seqs != 0);
-        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs());
         GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
         GGML_ASSERT(hparams.n_shortconv_l_cache > 1);
@@ -16728,10 +16728,10 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     return res;
 }
 
-llm_graph_result_ptr llama_model::build_graph(
-        const llm_graph_params & params,
-                   ggml_cgraph * gf,
-                llm_graph_type   type) const {
+ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
+    // TODO: temporary - will refactor this to keep the "gf" instance in the llm_graph_context and avoid passing it everywhere
+    auto * gf = params.res->get_gf();
+
     std::unique_ptr<llm_graph_context> llm;
 
     switch (arch) {
@@ -16951,7 +16951,7 @@ llm_graph_result_ptr llama_model::build_graph(
             } break;
         case LLM_ARCH_T5:
             {
-                switch (type) {
+                switch (params.gtype) {
                     case LLM_GRAPH_TYPE_ENCODER:
                         llm = std::make_unique<llm_build_t5_enc>(*this, params, gf);
                         break;
@@ -17057,7 +17057,7 @@ llm_graph_result_ptr llama_model::build_graph(
     // add on pooling layer
     llm->build_pooling(gf, cls, cls_b, cls_out, cls_out_b);
 
-    return std::move(llm->res);
+    return llm->res->get_gf();
 }
 
 //
index 027a7f0c3e2c694ae55171efd1524df99488a118..01b7fe3e578ec8b85ec1e631f9fb8b769aa66704 100644 (file)
@@ -452,10 +452,7 @@ struct llama_model {
     llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
 
     // TODO: move this to new llm_arch_model_i interface
-    llm_graph_result_ptr build_graph(
-            const llm_graph_params & params,
-                       ggml_cgraph * gf,
-                    llm_graph_type   type) const;
+    ggml_cgraph * build_graph(const llm_graph_params & params) const;
 
 private:
     struct impl;