]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
memory : rename interface to llama_memory_context_i (#14296)
authorGeorgi Gerganov <redacted>
Sat, 21 Jun 2025 05:03:46 +0000 (08:03 +0300)
committerGitHub <redacted>
Sat, 21 Jun 2025 05:03:46 +0000 (08:03 +0300)
* memory : rename interface to llama_memory_context_i

ggml-ci

* cont : fix comments

* cont : use "mctx" for referencing a memory context

ggml-ci

14 files changed:
src/llama-context.cpp
src/llama-context.h
src/llama-graph.cpp
src/llama-graph.h
src/llama-kv-cache-unified-iswa.cpp
src/llama-kv-cache-unified-iswa.h
src/llama-kv-cache-unified.cpp
src/llama-kv-cache-unified.h
src/llama-memory-hybrid.cpp
src/llama-memory-hybrid.h
src/llama-memory-recurrent.cpp
src/llama-memory-recurrent.h
src/llama-memory.h
src/llama-model.cpp

index 5a18a4fb3939a10082f2ecb2a4f17ec9815a1320..e352d81e4ed7c4d0da6153f67d667d14b8306f14 100644 (file)
@@ -280,8 +280,8 @@ llama_context::llama_context(
 
         // simulate full KV cache
 
-        const auto mstate = memory->init_full();
-        if (!mstate) {
+        const auto mctx = memory->init_full();
+        if (!mctx) {
             throw std::runtime_error("failed to initialize KV cache");
         }
 
@@ -289,7 +289,7 @@ llama_context::llama_context(
 
         // reserve pp graph first so that buffers are only allocated once
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -300,7 +300,7 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1, mstate.get());
+            auto * gf = graph_reserve(1, 1, 1, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -311,7 +311,7 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -444,8 +444,8 @@ bool llama_context::kv_self_update(bool optimize) {
         optimize |= memory_force_optimize;
         memory_force_optimize = false;
 
-        const auto mstate = memory->init_update(this, optimize);
-        switch (mstate->get_status()) {
+        const auto mctx = memory->init_update(this, optimize);
+        switch (mctx->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                     // noop
@@ -463,22 +463,22 @@ bool llama_context::kv_self_update(bool optimize) {
                 }
         }
 
-        if (!mstate->apply()) {
+        if (!mctx->apply()) {
             LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
         }
     }
 
     // if the memory module did any computation, we have to reserve a new worst-case graph
     {
-        const auto mstate = memory->init_full();
-        if (!mstate) {
-            throw std::runtime_error("failed to initialize memory state");
+        const auto mctx = memory->init_full();
+        if (!mctx) {
+            throw std::runtime_error("failed to initialize memory context");
         }
 
         const uint32_t n_seqs   = cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
         if (!gf) {
             LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
         }
@@ -678,9 +678,9 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
-    if (mstate && !mstate->apply()) {
-        LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
+llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+    if (mctx && !mctx->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;
         return nullptr;
     }
@@ -692,7 +692,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
         return nullptr;
     }
 
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
     if (!res) {
         LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
         ret = GGML_STATUS_FAILED;
@@ -933,21 +933,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // handle any pending defrags/shifts
     kv_self_update(false);
 
-    llama_memory_state_ptr mstate;
+    llama_memory_context_ptr mctx;
 
     while (true) {
-        mstate = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
-        if (!mstate) {
+        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+        if (!mctx) {
             return -2;
         }
 
-        switch (mstate->get_status()) {
+        switch (mctx->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                 } break;
             case LLAMA_MEMORY_STATUS_NO_UPDATE:
                 {
-                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
+                    LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
 
                     return -2;
                 }
@@ -987,7 +987,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     int64_t n_outputs_prev = 0;
 
     do {
-        const auto & ubatch = mstate->get_ubatch();
+        const auto & ubatch = mctx->get_ubatch();
 
         // count the outputs in this ubatch
         {
@@ -1009,7 +1009,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
         ggml_status status;
-        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1126,7 +1126,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         }
 
         n_outputs_prev += n_outputs;
-    } while (mstate->next());
+    } while (mctx->next());
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
@@ -1292,7 +1292,7 @@ ggml_cgraph * llama_context::graph_init() {
     return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
 }
 
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
@@ -1312,7 +1312,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
 
     auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
 
     this->n_outputs = save_n_outputs;
 
@@ -1333,11 +1333,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 }
 
 llm_graph_result_ptr llama_context::graph_build(
-                    ggml_context * ctx,
-                     ggml_cgraph * gf,
-              const llama_ubatch & ubatch,
-                  llm_graph_type   gtype,
-      const llama_memory_state_i * mstate) {
+                      ggml_context * ctx,
+                       ggml_cgraph * gf,
+                const llama_ubatch & ubatch,
+                    llm_graph_type   gtype,
+      const llama_memory_context_i * mctx) {
     return model.build_graph(
             {
                 /*.ctx         =*/ ctx,
@@ -1349,7 +1349,7 @@ llm_graph_result_ptr llama_context::graph_build(
                 /*.backend_cpu =*/ backend_cpu,
                 /*.cvec        =*/ &cvec,
                 /*.loras       =*/ &loras,
-                /*.mstate      =*/ mstate,
+                /*.mctx        =*/ mctx,
                 /*.cross       =*/ &cross,
                 /*.n_outputs   =*/ n_outputs,
                 /*.cb          =*/ graph_get_cb(),
@@ -2042,8 +2042,8 @@ void llama_context::opt_epoch_iter(
 
         uint32_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(*balloc, cparams.n_ubatch, true);
-        if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+        auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
+        if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
         }
@@ -2056,17 +2056,17 @@ void llama_context::opt_epoch_iter(
 
         uint32_t pos_batch = 0;
         do {
-            const auto & ubatch = mstate->get_ubatch();
+            const auto & ubatch = mctx->get_ubatch();
 
             n_outputs = ubatch.n_tokens;
 
-            if (!mstate->apply()) {
-                LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
+            if (!mctx->apply()) {
+                LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
                 break;
             }
 
             auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
 
             struct ggml_context * ctx_compute_opt;
             {
@@ -2101,7 +2101,7 @@ void llama_context::opt_epoch_iter(
             ggml_free(ctx_compute_opt);
 
             pos_batch += ubatch.n_tokens;
-        } while (mstate->next());
+        } while (mctx->next());
     }
 }
 
index 7d300c14572e9ed2b2009f4a11f69c36eafc3303..9ce05715a8c0306312bd03f516ae38f1f79cb2f6 100644 (file)
@@ -18,7 +18,7 @@ class llama_io_read_i;
 class llama_io_write_i;
 
 struct llama_memory_i;
-struct llama_memory_state_i;
+struct llama_memory_context_i;
 
 struct llama_context {
     // init scheduler and compute buffers, reserve worst-case graphs
@@ -93,14 +93,14 @@ struct llama_context {
                 int32_t   il_end);
 
     // process a single ubatch with a specific graph type
-    // if memory_state is provided, it will be applied first to the context's memory
+    // if memory_context is provided, it will be applied first to the context's memory
     // ret contains the status of the graph computation
     // returns nullptr only if ret != GGML_STATUS_SUCCESS
     llm_graph_result_ptr process_ubatch(
-              const llama_ubatch & ubatch,
-                  llm_graph_type   gtype,
-            llama_memory_state_i * mstate,
-                     ggml_status & ret);
+                const llama_ubatch & ubatch,
+                    llm_graph_type   gtype,
+            llama_memory_context_i * mctx,
+                       ggml_status & ret);
 
     int encode(const llama_batch & batch_inp);
     int decode(const llama_batch & batch_inp);
@@ -197,15 +197,15 @@ public:
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);
 
     // reserve a graph with a dummy ubatch of the specified size
-    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
+    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
 
 private:
     llm_graph_result_ptr graph_build(
-                    ggml_context * ctx,
-                     ggml_cgraph * gf,
-              const llama_ubatch & ubatch,
-                  llm_graph_type   gtype,
-      const llama_memory_state_i * mstate);
+                      ggml_context * ctx,
+                       ggml_cgraph * gf,
+                const llama_ubatch & ubatch,
+                    llm_graph_type   gtype,
+      const llama_memory_context_i * mctx);
 
     llm_graph_cb graph_get_cb() const;
 
index 7e162c555220439e9d6cdf204deedca4ba104471..48589a50ab24d4f535b6c76add9061ad393c047d 100644 (file)
@@ -87,7 +87,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
-        kv_state->set_input_pos_bucket(pos_bucket, ubatch);
+        mctx->set_input_pos_bucket(pos_bucket, ubatch);
     }
 }
 
@@ -221,7 +221,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
-    const int64_t n_rs = mem_state->get_n_rs();
+    const int64_t n_rs = mctx->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@@ -229,7 +229,7 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_rs; ++i) {
-            data[i] = mem_state->s_copy(i);
+            data[i] = mctx->s_copy(i);
         }
     }
 }
@@ -282,17 +282,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 
     if (self_kq_mask_swa) {
-        kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
+        mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }
 
@@ -334,10 +334,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 
-    const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
+    const int64_t n_rs = mctx->get_recr()->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@@ -345,7 +345,7 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_rs; ++i) {
-            data[i] = mem_state->get_state_recr()->s_copy(i);
+            data[i] = mctx->get_recr()->s_copy(i);
         }
     }
 }
@@ -389,7 +389,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     backend_cpu      (params.backend_cpu),
     cvec             (params.cvec),
     loras            (params.loras),
-    mstate           (params.mstate),
+    mctx             (params.mctx),
     cross            (params.cross),
     cb_func          (params.cb),
     res              (std::make_unique<llm_graph_result>()) {
@@ -950,11 +950,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
+    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
 
-    const auto n_kv = kv_state->get_n_kv();
+    const auto n_kv = mctx_cur->get_n_kv();
 
     auto & cur = inp->pos_bucket;
 
@@ -982,14 +982,14 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
 }
 
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
-    const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
-        const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
+        const auto n_kv = inp->mctx->get_attn()->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -999,7 +999,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     }
 
     {
-        const auto n_rs = mem_state->get_state_recr()->get_n_rs();
+        const auto n_rs = mctx_cur->get_recr()->get_n_rs();
 
         inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
         ggml_set_input(inp->s_copy);
@@ -1183,14 +1183,14 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv = kv_state->get_n_kv();
+        const auto n_kv = mctx_cur->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1220,19 +1220,19 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv_state->get_k(ctx0, il);
-    ggml_tensor * v = kv_state->get_v(ctx0, il);
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ -1270,23 +1270,23 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
+    const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
     const bool is_swa = hparams.is_swa(il);
 
-    const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
+    const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv_state->get_k(ctx0, il);
-    ggml_tensor * v = kv_state->get_v(ctx0, il);
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ -1379,19 +1379,19 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv_state->get_k(ctx0, il);
-    ggml_tensor * v = kv_state->get_v(ctx0, il);
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ -1412,12 +1412,12 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
     {
-        const auto n_kv = kv_state->get_base()->get_n_kv();
+        const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1429,7 +1429,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
 
-        const auto n_kv = kv_state->get_swa()->get_n_kv();
+        const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
@@ -1485,11 +1485,11 @@ ggml_tensor * llm_graph_context::build_rs(
 }
 
 llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
+    auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
-    const auto n_rs = kv_state->get_n_rs();
+    const auto n_rs = mctx_cur->get_n_rs();
 
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
@@ -1504,9 +1504,9 @@ ggml_tensor * llm_graph_context::build_rs(
             int32_t   state_size,
             int32_t   n_seqs,
                bool   avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rs(
@@ -1516,9 +1516,9 @@ ggml_tensor * llm_graph_context::build_rs(
             int32_t   state_size,
             int32_t   n_seqs,
                bool   avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
@@ -1526,13 +1526,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
            ggml_cgraph * gf,
     const llama_ubatch & ubatch,
                  int   il) const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;
 
     const int64_t n_seqs  = ubatch.n_seqs;
 
-    ggml_tensor * token_shift_all = kv_state->get_r_l(il);
+    ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
 
     ggml_tensor * token_shift = build_rs(
             inp, gf, token_shift_all,
@@ -1547,19 +1547,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
          ggml_tensor * token_shift,
   const llama_ubatch & ubatch,
                  int   il) const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
 
     const int64_t n_seqs = ubatch.n_seqs;
 
-    const auto kv_head = kv_state->get_head();
+    const auto kv_head = mctx_cur->get_head();
 
     return ggml_cpy(
         ctx0,
         ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
-        ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
+        ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
     );
 }
 
index 9e62fa60720d78b9f3aded7ff29bf1d4212197bb..b433f266d1b295e94899780b0f87cf57a52781d1 100644 (file)
@@ -17,12 +17,12 @@ struct ggml_tensor;
 struct llama_ubatch;
 struct llama_cparams;
 
-struct llama_memory_state_i;
+struct llama_memory_context_i;
 
-class llama_kv_cache_unified_state;
-class llama_kv_cache_unified_iswa_state;
-class llama_memory_recurrent_state;
-class llama_memory_hybrid_state;
+class llama_kv_cache_unified_context;
+class llama_kv_cache_unified_iswa_context;
+class llama_memory_recurrent_context;
+class llama_memory_hybrid_context;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -136,7 +136,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
+            const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -144,7 +144,8 @@ public:
     ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
 
     const llama_hparams & hparams;
-    const llama_kv_cache_unified_state * kv_state;
+
+    const llama_kv_cache_unified_context * mctx;
 };
 
 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -191,14 +192,14 @@ public:
 
 class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
+    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
     virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_memory_recurrent_state * mem_state;
+    const llama_memory_recurrent_context * mctx;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -238,10 +239,10 @@ public:
     llm_graph_input_attn_kv_unified(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_state * kv_state) :
+            const llama_kv_cache_unified_context * mctx) :
         hparams(hparams),
         cparams(cparams),
-        kv_state(kv_state) {
+        mctx(mctx) {
     }
     ~llm_graph_input_attn_kv_unified() = default;
 
@@ -255,7 +256,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified_state * kv_state;
+    const llama_kv_cache_unified_context * mctx;
 };
 
 class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
@@ -263,10 +264,10 @@ public:
     llm_graph_input_attn_kv_unified_iswa(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_iswa_state * kv_state) :
+            const llama_kv_cache_unified_iswa_context * mctx) :
         hparams(hparams),
         cparams(cparams),
-        kv_state(kv_state) {
+        mctx(mctx) {
     }
     ~llm_graph_input_attn_kv_unified_iswa() = default;
 
@@ -283,7 +284,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified_iswa_state * kv_state;
+    const llama_kv_cache_unified_iswa_context * mctx;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -306,10 +307,10 @@ public:
     llm_graph_input_mem_hybrid(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_memory_hybrid_state * mem_state) :
+            const llama_memory_hybrid_context * mctx) :
         hparams(hparams),
         cparams(cparams),
-        mem_state(mem_state) {
+        mctx(mctx) {
     }
     virtual ~llm_graph_input_mem_hybrid() = default;
 
@@ -325,7 +326,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_memory_hybrid_state * mem_state;
+    const llama_memory_hybrid_context * mctx;
 };
 
 //
@@ -401,10 +402,10 @@ struct llm_graph_params {
     ggml_backend_sched_t sched;
     ggml_backend_t backend_cpu;
 
-    const llama_adapter_cvec   * cvec;
-    const llama_adapter_loras  * loras;
-    const llama_memory_state_i * mstate;
-    const llama_cross          * cross;
+    const llama_adapter_cvec     * cvec;
+    const llama_adapter_loras    * loras;
+    const llama_memory_context_i * mctx;
+    const llama_cross            * cross;
 
     uint32_t n_outputs;
 
@@ -453,10 +454,10 @@ struct llm_graph_context {
 
     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
-    const llama_adapter_cvec   * cvec;
-    const llama_adapter_loras  * loras;
-    const llama_memory_state_i * mstate;
-    const llama_cross          * cross;
+    const llama_adapter_cvec     * cvec;
+    const llama_adapter_loras    * loras;
+    const llama_memory_context_i * mctx;
+    const llama_cross            * cross;
 
     const llm_graph_cb & cb_func;
 
index 0ced340dec6c5c787f767163d4d3849785f31d61..b9169299c0760c7dd5160683c715ed34ccc568df 100644 (file)
@@ -95,7 +95,7 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     GGML_UNUSED(embd_all);
 
     // first try simple split
@@ -125,7 +125,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc
 
         assert(heads_base.size() == heads_swa.size());
 
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+        return std::make_unique<llama_kv_cache_unified_iswa_context>(
                 this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
@@ -156,22 +156,22 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc
 
         assert(heads_base.size() == heads_swa.size());
 
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+        return std::make_unique<llama_kv_cache_unified_iswa_context>(
                 this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
     //       but to do that properly, we first have to refactor the batches to be more flexible
 
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
+llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
+    return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
+llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
 }
 
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -197,46 +197,46 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 }
 
 //
-// llama_kv_cache_unified_iswa_state
+// llama_kv_cache_unified_iswa_context
 //
 
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
 
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv) :
-    state_base(kv->get_base()->init_full()),
-    state_swa (kv->get_swa ()->init_full()),
-    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
+    ctx_base(kv->get_base()->init_full()),
+    ctx_swa (kv->get_swa ()->init_full()),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv,
         llama_context * lctx,
         bool optimize) :
-    state_base(kv->get_base()->init_update(lctx, optimize)),
-    state_swa (kv->get_swa ()->init_update(lctx, optimize)),
-    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
+    ctx_base(kv->get_base()->init_update(lctx, optimize)),
+    ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    state_base(new llama_kv_cache_unified_state(kv->get_base(), std::move(heads_base), this->ubatches)),
-    state_swa (new llama_kv_cache_unified_state(kv->get_swa (), std::move(heads_swa),  this->ubatches)),
-    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
+    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
+    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa),  this->ubatches)),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
-llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
+llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
 
-bool llama_kv_cache_unified_iswa_state::next() {
+bool llama_kv_cache_unified_iswa_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    state_base->next();
-    state_swa ->next();
+    ctx_base->next();
+    ctx_swa ->next();
 
     if (++i_next >= ubatches.size()) {
         return false;
@@ -245,35 +245,35 @@ bool llama_kv_cache_unified_iswa_state::next() {
     return true;
 }
 
-bool llama_kv_cache_unified_iswa_state::apply() {
+bool llama_kv_cache_unified_iswa_context::apply() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     bool res = true;
 
-    res = res & state_base->apply();
-    res = res & state_swa ->apply();
+    res = res & ctx_base->apply();
+    res = res & ctx_swa ->apply();
 
     return res;
 }
 
-llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
+llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
     return status;
 }
 
-const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
+const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     return ubatches[i_next];
 }
 
-const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
+const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
+    return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
 }
 
-const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa()  const {
+const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa()  const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
+    return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
 }
index 071041585db38588217547ef556470ddbae2fdc1..46c1ed614f2f0166960d86035396d26a76520420 100644 (file)
@@ -31,14 +31,14 @@ public:
     // llama_memory_i
     //
 
-    llama_memory_state_ptr init_batch(
+    llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;
 
-    llama_memory_state_ptr init_full() override;
+    llama_memory_context_ptr init_full() override;
 
-    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool get_can_shift() const override;
 
@@ -72,32 +72,32 @@ private:
     std::unique_ptr<llama_kv_cache_unified> kv_swa;
 };
 
-class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
+class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
 public:
     // used for errors
-    llama_kv_cache_unified_iswa_state(llama_memory_status status);
+    llama_kv_cache_unified_iswa_context(llama_memory_status status);
 
-    // used to create a full-cache state
-    llama_kv_cache_unified_iswa_state(
+    // used to create a full-cache context
+    llama_kv_cache_unified_iswa_context(
             llama_kv_cache_unified_iswa * kv);
 
-    // used to create an update state
-    llama_kv_cache_unified_iswa_state(
+    // used to create an update context
+    llama_kv_cache_unified_iswa_context(
             llama_kv_cache_unified_iswa * kv,
             llama_context * lctx,
             bool optimize);
 
-    // used to create a state from a batch
-    llama_kv_cache_unified_iswa_state(
+    // used to create a batch processing context from a batch
+    llama_kv_cache_unified_iswa_context(
             llama_kv_cache_unified_iswa * kv,
             std::vector<uint32_t> heads_base,
             std::vector<uint32_t> heads_swa,
             std::vector<llama_ubatch> ubatches);
 
-    virtual ~llama_kv_cache_unified_iswa_state();
+    virtual ~llama_kv_cache_unified_iswa_context();
 
     //
-    // llama_memory_state_i
+    // llama_memory_context_i
     //
 
     bool next()  override;
@@ -107,11 +107,11 @@ public:
     const llama_ubatch & get_ubatch() const override;
 
     //
-    // llama_kv_cache_unified_iswa_state specific API
+    // llama_kv_cache_unified_iswa_context specific API
     //
 
-    const llama_kv_cache_unified_state * get_base() const;
-    const llama_kv_cache_unified_state * get_swa()  const;
+    const llama_kv_cache_unified_context * get_base() const;
+    const llama_kv_cache_unified_context * get_swa()  const;
 
 private:
     //llama_kv_cache_unified_iswa * kv;
@@ -121,8 +121,8 @@ private:
 
     std::vector<llama_ubatch> ubatches;
 
-    const llama_memory_state_ptr state_base;
-    const llama_memory_state_ptr state_swa;
+    const llama_memory_context_ptr ctx_base;
+    const llama_memory_context_ptr ctx_swa;
 
     const llama_memory_status status;
 };
index 6897b797153dbe894b8d66cf9f626463efe47a9c..b506d32ed4d06af3096656cb6b3368dad3f800da 100644 (file)
@@ -307,7 +307,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
     return cells.seq_pos_max(seq_id);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified::init_batch(
+llama_memory_context_ptr llama_kv_cache_unified::init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) {
@@ -332,18 +332,18 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
             break;
         }
 
-        return std::make_unique<llama_kv_cache_unified_state>(
+        return std::make_unique<llama_kv_cache_unified_context>(
                 this, std::move(heads), std::move(ubatches));
     } while (false);
 
-    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified::init_full() {
-    return std::make_unique<llama_kv_cache_unified_state>(this);
+llama_memory_context_ptr llama_kv_cache_unified::init_full() {
+    return std::make_unique<llama_kv_cache_unified_context>(this);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
+llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
     bool do_shift = get_has_shift();
 
     defrag_info dinfo;
@@ -373,7 +373,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx,
         }
     }
 
-    return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
+    return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
 }
 
 llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
@@ -1710,18 +1710,18 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 }
 
 //
-// llama_kv_cache_unified_state
+// llama_kv_cache_unified_context
 //
 
-llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
+llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
 
-llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
     n_kv = kv->get_size();
     head = 0;
 }
 
-llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv,
         llama_context * lctx,
         bool do_shift,
@@ -1731,15 +1731,15 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
     }
 }
 
-llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv,
         llama_kv_cache_unified::ubatch_heads heads,
         std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
 }
 
-llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
+llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
 
-bool llama_kv_cache_unified_state::next() {
+bool llama_kv_cache_unified_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     if (++i_next >= ubatches.size()) {
@@ -1749,7 +1749,7 @@ bool llama_kv_cache_unified_state::next() {
     return true;
 }
 
-bool llama_kv_cache_unified_state::apply() {
+bool llama_kv_cache_unified_context::apply() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     // no ubatches -> this is a KV cache update
@@ -1767,45 +1767,45 @@ bool llama_kv_cache_unified_state::apply() {
     return true;
 }
 
-llama_memory_status llama_kv_cache_unified_state::get_status() const {
+llama_memory_status llama_kv_cache_unified_context::get_status() const {
     return status;
 }
 
-const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const {
+const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     return ubatches[i_next];
 }
 
-uint32_t llama_kv_cache_unified_state::get_n_kv() const {
+uint32_t llama_kv_cache_unified_context::get_n_kv() const {
     return n_kv;
 }
 
-ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const {
+ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
     return kv->get_k(ctx, il, n_kv);
 }
 
-ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const {
+ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
     return kv->get_v(ctx, il, n_kv);
 }
 
-ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
     return kv->cpy_k(ctx, k_cur, il, head);
 }
 
-ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
     return kv->cpy_v(ctx, v_cur, il, head);
 }
 
-void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const {
+void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
-void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
+void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     kv->set_input_kq_mask(dst, ubatch, causal_attn);
 }
 
-void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     kv->set_input_pos_bucket(dst, ubatch);
 }
 
index 1560640045c82d9267bc59b472eaf7b25fd98c3f..4c53f1273ab88326cb21f6bc25fd999935a8b761 100644 (file)
@@ -56,14 +56,14 @@ public:
     // llama_memory_i
     //
 
-    llama_memory_state_ptr init_batch(
+    llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;
 
-    llama_memory_state_ptr init_full() override;
+    llama_memory_context_ptr init_full() override;
 
-    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool get_can_shift() const override;
 
@@ -208,36 +208,36 @@ private:
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };
 
-class llama_kv_cache_unified_state : public llama_memory_state_i {
+class llama_kv_cache_unified_context : public llama_memory_context_i {
 public:
     // some shorthands
     using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
     using defrag_info  = llama_kv_cache_unified::defrag_info;
 
     // used for errors
-    llama_kv_cache_unified_state(llama_memory_status status);
+    llama_kv_cache_unified_context(llama_memory_status status);
 
-    // used to create a full-cache state
-    llama_kv_cache_unified_state(
+    // used to create a full-cache context
+    llama_kv_cache_unified_context(
             llama_kv_cache_unified * kv);
 
-    // used to create an update state
-    llama_kv_cache_unified_state(
+    // used to create an update context
+    llama_kv_cache_unified_context(
             llama_kv_cache_unified * kv,
             llama_context * lctx,
             bool do_shift,
             defrag_info dinfo);
 
-    // used to create a decode state from a batch
-    llama_kv_cache_unified_state(
+    // used to create a batch procesing context from a batch
+    llama_kv_cache_unified_context(
             llama_kv_cache_unified * kv,
             ubatch_heads heads,
             std::vector<llama_ubatch> ubatches);
 
-    virtual ~llama_kv_cache_unified_state();
+    virtual ~llama_kv_cache_unified_context();
 
     //
-    // llama_memory_state_i
+    // llama_memory_context_i
     //
 
     bool next()  override;
@@ -247,7 +247,7 @@ public:
     const llama_ubatch & get_ubatch() const override;
 
     //
-    // llama_kv_cache_unified_state specific API
+    // llama_kv_cache_unified_context specific API
     //
 
     uint32_t get_n_kv() const;
@@ -272,7 +272,7 @@ private:
     llama_context * lctx;
 
     //
-    // update state
+    // update context
     //
 
     bool do_shift = false;
@@ -280,7 +280,7 @@ private:
     defrag_info dinfo;
 
     //
-    // batch processing state
+    // batch processing context
     //
 
     // the index of the next ubatch to process
index 1b16686819eff8582dc7ab6eea4334cb304f69f8..15cde98d138a85a7d9ea02512e6f7c587d7667fc 100644 (file)
@@ -56,7 +56,7 @@ llama_memory_hybrid::llama_memory_hybrid(
         n_seq_max
     )) {}
 
-llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     do {
         balloc.split_reset();
 
@@ -82,31 +82,31 @@ llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ball
 
         // prepare the recurrent batches first
         if (!mem_recr->prepare(ubatches)) {
-            // TODO: will the recurrent cache be in an undefined state at this point?
+            // TODO: will the recurrent cache be in an undefined context at this point?
             LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
-            return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+            return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
         }
 
         // prepare the attention cache
         auto heads_attn = mem_attn->prepare(ubatches);
         if (heads_attn.empty()) {
             LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
-            return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+            return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
         }
 
-        return std::make_unique<llama_memory_hybrid_state>(
+        return std::make_unique<llama_memory_hybrid_context>(
                 this, std::move(heads_attn), std::move(ubatches));
     } while(false);
 
-    return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
-llama_memory_state_ptr llama_memory_hybrid::init_full() {
-    return std::make_unique<llama_memory_hybrid_state>(this);
+llama_memory_context_ptr llama_memory_hybrid::init_full() {
+    return std::make_unique<llama_memory_hybrid_context>(this);
 }
 
-llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
-    return std::make_unique<llama_memory_hybrid_state>(this, lctx, optimize);
+llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_memory_hybrid_context>(this, lctx, optimize);
 }
 
 bool llama_memory_hybrid::get_can_shift() const {
@@ -176,39 +176,39 @@ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
     return mem_recr.get();
 }
 
-llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {}
+llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}
 
-llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) :
-    state_attn(mem->get_mem_attn()->init_full()),
-    state_recr(mem->get_mem_recr()->init_full()),
-    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
+    ctx_attn(mem->get_mem_attn()->init_full()),
+    ctx_recr(mem->get_mem_recr()->init_full()),
+    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
 
-llama_memory_hybrid_state::llama_memory_hybrid_state(
+llama_memory_hybrid_context::llama_memory_hybrid_context(
         llama_memory_hybrid * mem,
               llama_context * lctx,
                        bool   optimize) :
-    state_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
-    state_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
-    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+    ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
+    ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
 
-llama_memory_hybrid_state::llama_memory_hybrid_state(
+llama_memory_hybrid_context::llama_memory_hybrid_context(
               llama_memory_hybrid * mem,
             std::vector<uint32_t>   heads_attn,
         std::vector<llama_ubatch>   ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
-    state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(),                        this->ubatches)),
-    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
+    ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(),                        this->ubatches)),
+    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
 
-bool llama_memory_hybrid_state::next() {
+bool llama_memory_hybrid_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    state_attn->next();
-    state_recr->next();
+    ctx_attn->next();
+    ctx_recr->next();
 
     if (++i_next >= ubatches.size()) {
         return false;
@@ -217,30 +217,30 @@ bool llama_memory_hybrid_state::next() {
     return true;
 }
 
-bool llama_memory_hybrid_state::apply() {
+bool llama_memory_hybrid_context::apply() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     bool res = true;
 
-    res = res & state_attn->apply();
-    res = res & state_recr->apply();
+    res = res & ctx_attn->apply();
+    res = res & ctx_recr->apply();
 
     return res;
 }
 
-llama_memory_status llama_memory_hybrid_state::get_status() const {
+llama_memory_status llama_memory_hybrid_context::get_status() const {
     return status;
 }
 
-const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const {
+const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
     return ubatches[i_next];
 }
 
-const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const {
-    return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
+const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
+    return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
 }
 
-const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const {
-    return static_cast<const llama_memory_recurrent_state *>(state_recr.get());
+const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
+    return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
 }
index 4d27ab896aa05177b8711b7fad6a1e39b13653e6..f0c2420e9a2df5bc04093c07e972a837b8ccd1aa 100644 (file)
@@ -49,14 +49,14 @@ public:
     // llama_memory_i
     //
 
-    llama_memory_state_ptr init_batch(
+    llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;
 
-    llama_memory_state_ptr init_full() override;
+    llama_memory_context_ptr init_full() override;
 
-    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool get_can_shift() const override;
 
@@ -90,27 +90,27 @@ private:
     const std::unique_ptr<llama_memory_recurrent> mem_recr;
 };
 
-class llama_memory_hybrid_state : public llama_memory_state_i {
+class llama_memory_hybrid_context : public llama_memory_context_i {
 public:
     // init failure
-    explicit llama_memory_hybrid_state(llama_memory_status status);
+    explicit llama_memory_hybrid_context(llama_memory_status status);
 
     // init full
-    explicit llama_memory_hybrid_state(llama_memory_hybrid * mem);
+    explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
 
     // init update
-    explicit llama_memory_hybrid_state(
+    explicit llama_memory_hybrid_context(
         llama_memory_hybrid * mem,
               llama_context * lctx,
                        bool   optimize);
 
     // init success
-    llama_memory_hybrid_state(
+    llama_memory_hybrid_context(
               llama_memory_hybrid * mem,
             std::vector<uint32_t>   heads_attn,
         std::vector<llama_ubatch>   ubatches);
 
-    ~llama_memory_hybrid_state() = default;
+    ~llama_memory_hybrid_context() = default;
 
     bool next()  override;
     bool apply() override;
@@ -119,11 +119,11 @@ public:
     const llama_ubatch & get_ubatch() const override;
 
     //
-    // llama_memory_hybrid_state
+    // llama_memory_hybrid_context
     //
 
-    const llama_kv_cache_unified_state * get_state_attn() const;
-    const llama_memory_recurrent_state * get_state_recr() const;
+    const llama_kv_cache_unified_context * get_attn() const;
+    const llama_memory_recurrent_context * get_recr() const;
 
 private:
     // the index of the next ubatch to process
@@ -131,8 +131,8 @@ private:
 
     std::vector<llama_ubatch> ubatches;
 
-    const llama_memory_state_ptr state_attn;
-    const llama_memory_state_ptr state_recr;
+    const llama_memory_context_ptr ctx_attn;
+    const llama_memory_context_ptr ctx_recr;
 
     const llama_memory_status status;
 };
index b064da0084c5295f3a28b93c9dbfb1a15452f6f1..1b1e95d567a6cf9bec8284ffe40fa2863462a82e 100644 (file)
@@ -362,7 +362,7 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_state_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     std::vector<llama_ubatch> ubatches;
 
     while (true) {
@@ -383,21 +383,21 @@ llama_memory_state_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & b
     }
 
     if (!prepare(ubatches)) {
-        return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+        return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }
 
-    return std::make_unique<llama_memory_recurrent_state>(this, std::move(ubatches));
+    return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
 }
 
-llama_memory_state_ptr llama_memory_recurrent::init_full() {
-    return std::make_unique<llama_memory_recurrent_state>(this);
+llama_memory_context_ptr llama_memory_recurrent::init_full() {
+    return std::make_unique<llama_memory_recurrent_context>(this);
 }
 
-llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
+llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
     GGML_UNUSED(lctx);
     GGML_UNUSED(optimize);
 
-    return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
+    return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_NO_UPDATE);
 }
 
 bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
@@ -1040,22 +1040,22 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
 }
 
 //
-// llama_memory_recurrent_state
+// llama_memory_recurrent_context
 //
 
-llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {}
+llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
 
-llama_memory_recurrent_state::llama_memory_recurrent_state(
+llama_memory_recurrent_context::llama_memory_recurrent_context(
         llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
 }
 
-llama_memory_recurrent_state::llama_memory_recurrent_state(
+llama_memory_recurrent_context::llama_memory_recurrent_context(
         llama_memory_recurrent * mem,
         std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
 
-llama_memory_recurrent_state::~llama_memory_recurrent_state() = default;
+llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
 
-bool llama_memory_recurrent_state::next() {
+bool llama_memory_recurrent_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     if (++i_next >= ubatches.size()) {
@@ -1065,7 +1065,7 @@ bool llama_memory_recurrent_state::next() {
     return true;
 }
 
-bool llama_memory_recurrent_state::apply() {
+bool llama_memory_recurrent_context::apply() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     mem->find_slot(ubatches[i_next]);
@@ -1073,40 +1073,40 @@ bool llama_memory_recurrent_state::apply() {
     return true;
 }
 
-llama_memory_status llama_memory_recurrent_state::get_status() const {
+llama_memory_status llama_memory_recurrent_context::get_status() const {
     return status;
 }
 
-const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const {
+const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     return ubatches[i_next];
 }
 
-uint32_t llama_memory_recurrent_state::get_n_rs() const {
+uint32_t llama_memory_recurrent_context::get_n_rs() const {
     return is_full ? mem->size : mem->n;
 }
 
-uint32_t llama_memory_recurrent_state::get_head() const {
+uint32_t llama_memory_recurrent_context::get_head() const {
     return is_full ? 0 : mem->head;
 }
 
-int32_t llama_memory_recurrent_state::get_rs_z() const {
+int32_t llama_memory_recurrent_context::get_rs_z() const {
     return is_full ? 0 : mem->rs_z;
 }
 
-uint32_t llama_memory_recurrent_state::get_size() const {
+uint32_t llama_memory_recurrent_context::get_size() const {
     return mem->size;
 }
 
-ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const {
+ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
     return mem->r_l[il];
 }
 
-ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const {
+ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
     return mem->s_l[il];
 }
 
-int32_t llama_memory_recurrent_state::s_copy(int i) const {
+int32_t llama_memory_recurrent_context::s_copy(int i) const {
     return  mem->cells[i + mem->head].src0;
 }
index be58dae7cfe33184cfe1e08bc9c95c1b9e1ffa49..4d094f9a05788cb3fa18b29bc188de1926578225 100644 (file)
@@ -11,8 +11,8 @@
 // llama_memory_recurrent
 //
 
-// TODO: extract the cache state used for graph computation into llama_memory_recurrent_state_i
-//       see the implementation of llama_kv_cache_unified_state_i for an example how to do it
+// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
+//       see the implementation of llama_kv_cache_unified_context_i for an example how to do it
 class llama_memory_recurrent : public llama_memory_i {
 public:
 
@@ -34,14 +34,14 @@ public:
     // llama_memory_i
     //
 
-    llama_memory_state_ptr init_batch(
+    llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;
 
-    llama_memory_state_ptr init_full() override;
+    llama_memory_context_ptr init_full() override;
 
-    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
 
     void clear(bool data) override;
 
@@ -125,24 +125,24 @@ private:
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };
 
-class llama_memory_recurrent_state : public llama_memory_state_i {
+class llama_memory_recurrent_context : public llama_memory_context_i {
 public:
     // used for errors
-    llama_memory_recurrent_state(llama_memory_status status);
+    llama_memory_recurrent_context(llama_memory_status status);
 
-    // used to create a full-cache state
-    llama_memory_recurrent_state(
+    // used to create a full-cache or update context
+    llama_memory_recurrent_context(
             llama_memory_recurrent * mem);
 
-    // used to create a state from a batch
-    llama_memory_recurrent_state(
+    // used to create a batch processing context from a batch
+    llama_memory_recurrent_context(
             llama_memory_recurrent * mem,
             std::vector<llama_ubatch> ubatches);
 
-    virtual ~llama_memory_recurrent_state();
+    virtual ~llama_memory_recurrent_context();
 
     //
-    // llama_memory_state_i
+    // llama_memory_context_i
     //
 
     bool next()  override;
@@ -152,7 +152,7 @@ public:
     const llama_ubatch & get_ubatch() const override;
 
     //
-    // llama_memory_recurrent_state specific API
+    // llama_memory_recurrent_context specific API
     //
 
     uint32_t get_n_rs() const;
index d2ef0c2a3b4aafaae17a2bede37b97340f21b64a..16b7e5ee2484a05b293de97ee3816cb0bcab6239 100644 (file)
@@ -3,7 +3,6 @@
 #include "llama.h"
 
 #include <memory>
-#include <vector>
 
 struct llama_ubatch;
 
@@ -28,23 +27,21 @@ enum llama_memory_status {
     LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
 };
 
-// helper function for combining the status of two memory states
+// helper function for combining the status of two memory contexts
 // useful for implementing hybrid memory types (e.g. iSWA)
 llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
 
-// the interface for managing the memory state during batch processing
+// the interface for managing the memory context during batch processing
 // this interface is implemented per memory type. see:
-//   - llama_kv_cache_unified_state
-//   - llama_kv_cache_unified_iswa_state
+//   - llama_kv_cache_unified_context
+//   - llama_kv_cache_unified_iswa_context
 //   ...
 //
-// the only method that can mutate the memory and the memory state is llama_memory_i::apply()
-//
-// TODO: rename to llama_memory_context_i ?
-struct llama_memory_state_i {
-    virtual ~llama_memory_state_i() = default;
+// the only method that should mutate the memory and the memory context is llama_memory_i::apply()
+struct llama_memory_context_i {
+    virtual ~llama_memory_context_i() = default;
 
-    // consume the current ubatch from the state and proceed to the next one
+    // consume the current ubatch from the context and proceed to the next one
     // return false if we are done
     virtual bool next() = 0;
 
@@ -55,11 +52,11 @@ struct llama_memory_state_i {
     // get the current ubatch
     virtual const llama_ubatch & get_ubatch() const = 0;
 
-    // get the status of the memory state - used for error handling and checking if any updates would be applied
+    // get the status of the memory context - used for error handling and checking if any updates would be applied
     virtual llama_memory_status get_status() const = 0;
 };
 
-using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
+using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
 
 // general concept of LLM memory
 // the KV cache is a type of LLM memory, but there can be other types
@@ -67,19 +64,19 @@ struct llama_memory_i {
     virtual ~llama_memory_i() = default;
 
     // split the input batch into a set of ubatches and verify that they can fit into the cache
-    // return a state object containing the ubatches and KV cache state required to process them
-    // check the llama_memory_state_i::get_status() for the result
-    virtual llama_memory_state_ptr init_batch(
+    // return a context object containing the ubatches and memory state required to process them
+    // check the llama_memory_context_i::get_status() for the result
+    virtual llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) = 0;
 
     // simulate full cache, used for allocating worst-case compute buffers
-    virtual llama_memory_state_ptr init_full() = 0;
+    virtual llama_memory_context_ptr init_full() = 0;
 
     // prepare for any pending memory updates, such as shifts, defrags, etc.
     // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
-    virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
+    virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
 
     // getters
     virtual bool get_can_shift() const = 0;
index e2c82017f689018106e09e9ef499f9b18f8c5352..9b19da984081eeeb621ced55f4d1a5e8a987086f 100644 (file)
@@ -9171,9 +9171,9 @@ struct llm_build_mamba : public llm_graph_context {
                ggml_tensor * cur,
         const llama_ubatch & ubatch,
                        int   il) const {
-        const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+        const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-        const auto kv_head = kv_state->get_head();
+        const auto kv_head = mctx_cur->get_head();
 
         const int64_t d_conv  = hparams.ssm_d_conv;
         const int64_t d_inner = hparams.ssm_d_inner;
@@ -9191,8 +9191,8 @@ struct llm_build_mamba : public llm_graph_context {
         GGML_ASSERT(ubatch.equal_seqs);
         GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
-        ggml_tensor * conv_states_all = kv_state->get_r_l(il);
-        ggml_tensor * ssm_states_all  = kv_state->get_s_l(il);
+        ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+        ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
 
         // (ab)using the KV cache to store the states
         ggml_tensor * conv = build_rs(
@@ -11916,7 +11916,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
             ggml_tensor * x_prev,
             const llama_ubatch & ubatch,
             int   il) const {
-        const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+        const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
         const auto n_tokens = ubatch.n_tokens;
         const auto n_seqs = ubatch.n_seqs;
@@ -11926,7 +11926,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
         const auto n_head = n_embd / head_size;
         const auto n_head_kv = hparams.n_head_kv(il);
 
-        const auto kv_head = kv_state->get_head();
+        const auto kv_head = mctx_cur->get_head();
 
         const auto & layer = model.layers[il];
 
@@ -12038,7 +12038,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
         }
 
         ggml_tensor * wkv_state = build_rs(
-                inp, gf, kv_state->get_s_l(il),
+                inp, gf, mctx_cur->get_s_l(il),
                 hparams.n_embd_s(), n_seqs);
 
         ggml_tensor * wkv_output;
@@ -12057,9 +12057,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
                     wkv_state,
                     ggml_view_1d(
                         ctx0,
-                        kv_state->get_s_l(il),
+                        mctx_cur->get_s_l(il),
                         hparams.n_embd_s() * n_seqs,
-                        hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
+                        hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
                         )
                     )
                 );
@@ -12313,7 +12313,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
             ggml_tensor *& first_layer_value,
             const llama_ubatch & ubatch,
             int   il) const {
-        const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+        const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
         const auto n_tokens = ubatch.n_tokens;
         const auto n_seqs = ubatch.n_seqs;
@@ -12322,7 +12322,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
         const auto head_count = n_embd / head_size;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
 
-        const auto kv_head = kv_state->get_head();
+        const auto kv_head = mctx_cur->get_head();
 
         const auto & layer = model.layers[il];
 
@@ -12393,7 +12393,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
         a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
 
         ggml_tensor * wkv_state = build_rs(
-                inp, gf, kv_state->get_s_l(il),
+                inp, gf, mctx_cur->get_s_l(il),
                 hparams.n_embd_s(), n_seqs);
 
         ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
@@ -12407,9 +12407,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
                     wkv_state,
                     ggml_view_1d(
                         ctx0,
-                        kv_state->get_s_l(il),
+                        mctx_cur->get_s_l(il),
                         hparams.n_embd_s() * n_seqs,
-                        hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
+                        hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
                         )
                     )
                 );