]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : remove write/read of output ids/logits/embeddings (#18862)
authorDaniel Bevenius <redacted>
Mon, 23 Feb 2026 06:04:30 +0000 (07:04 +0100)
committerGitHub <redacted>
Mon, 23 Feb 2026 06:04:30 +0000 (07:04 +0100)
* llama : remove write/read of output ids/logits/embeddings

This commit removes the write/read of output ids, logits and
embeddings from the llama context state.

Refs: https://github.com/ggml-org/llama.cpp/pull/18862#issuecomment-3756330941

* completion : add replying of session state

This commit updates the session handing in the completion tool to handle
the that logits are no longer stored in the session file. Instead, we
need to replay the last token to get the logits for sampling.

* common : add common_prompt_batch_decode function

This commit adds a new function which is responsible for decoding prompt
and optionally handle the saving for session data.

* update save-state.cpp to use llama_state_load_file

This commit updates the save-load-state example to utilize the new
llama_state_load_file function for loading the model state from a file.
And it also replays the last token after loading since this state is now
stored before the last token is processed.

* examples : set n_seq_max = 2 for ctx3

This commit updates the save-load-state example to set the n_seq_max
parameter to 2 when initializing the ctx3 context.

The motivation for this change is that using 1 as n_parallel/n_seq_max
the context only supports one sequence, but the test laster tries to
use a second sequence which results in the following error:
```console
main : loaded state with 4 tokens
main : seq 0 copied, 225760 bytes
main : kv cache cleared
find_slot: seq_id=1 >= n_seq_max=1 Try using a bigger --parallel value
state_read_meta: failed to find available cells in kv cache
```
This seems to only happen for recurrent/hybrid models.

common/common.cpp
common/common.h
examples/save-load-state/save-load-state.cpp
src/llama-context.cpp
tools/completion/completion.cpp

index 75116ed6f3e6454e7cbaf4b2d280f4a33d035ef7..53bddc4ef2f53f3983eabb2bc27ca96395ca06f3 100644 (file)
@@ -1760,3 +1760,65 @@ float lr_opt::get_lr(float epoch) const {
     LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
     return r;
 }
+
+bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos) {
+    llama_batch batch = llama_batch_get_one(&last_token, 1);
+    batch.pos = &pos;
+    if (llama_decode(ctx, batch)) {
+        LOG_ERR("%s: failed to replay last token\n", __func__);
+        return false;
+    }
+    return true;
+}
+
+bool common_prompt_batch_decode(
+              struct llama_context * ctx,
+    const std::vector<llama_token> & tokens,
+                               int & n_past,
+                               int   n_batch,
+                  std::string_view   state_path,
+                              bool   save_state) {
+    const int n_eval = tokens.size();
+    if (n_eval == 0) {
+        return true;
+    }
+
+    if (save_state && n_eval > 1) {
+        const int n_tokens_before_last = n_eval - 1;
+
+        GGML_ASSERT(n_eval <= n_batch);
+
+        // Decode all but the last token so we can save the memory state before decoding the last token.
+        // This is done so we can restore the session state later and replay the last token.
+        // Memory implementations in recurrent/hybrid models don't support removing tokens from their
+        // memory, so we can't just remove the last token from the memory and replay the last token which
+        // is the reason for this logic.
+        if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_tokens_before_last))) {
+            LOG_ERR("%s : failed to eval\n", __func__);
+            return false;
+        }
+        n_past += n_tokens_before_last;
+
+        llama_state_save_file(ctx, state_path.data(), tokens.data(), n_tokens_before_last);
+        LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.data(), n_tokens_before_last);
+
+        llama_token last_token = tokens.back();
+        llama_batch batch = llama_batch_get_one(&last_token, 1);
+        int32_t pos = n_past;
+        batch.pos = &pos;
+
+        if (llama_decode(ctx, batch)) {
+            LOG_ERR("%s : failed to eval last token\n", __func__);
+            return false;
+        }
+        n_past++;
+    } else {
+        if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_eval))) {
+            LOG_ERR("%s : failed to eval\n", __func__);
+            return false;
+        }
+        n_past += n_eval;
+    }
+
+    return true;
+}
index a4c431172d0cab926121fb985abaed2e0911818d..1fa1728656204f15cf79d17c2cd813b6510c3bdf 100644 (file)
@@ -804,6 +804,23 @@ void common_batch_add(
     const std::vector<llama_seq_id> & seq_ids,
                                bool   logits);
 
+// decodes a single batch of tokens for a prompt and manages session tokens
+//
+// Note: We save state before the last token so that we can replay it to ensure
+// compatibility with all memory types. Recurrent/hybrid models cannot remove
+// tokens from memory, so this approach works across all model architectures.
+bool common_prompt_batch_decode(
+              struct llama_context * ctx,
+    const std::vector<llama_token> & embd,
+                               int & n_past,
+                               int   n_batch,
+                  std::string_view   state_path,
+                              bool   save_state);
+
+// replays the last token after loading state to regenerate logits
+// used after loading session state to ensure the sampling context has valid logits
+bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos);
+
 //
 // Vocab utils
 //
index 39d4464663dc6ce0582b633663494081637a53a2..5e35dcd6030666d4b537ade803aee7542aa8935d 100644 (file)
@@ -5,12 +5,15 @@
 #include <vector>
 #include <cstdio>
 
+
 int main(int argc, char ** argv) {
     common_params params;
 
     params.prompt = "The quick brown fox";
     params.sampling.seed = 1234;
 
+    const std::string_view state_file = "dump_state.bin";
+
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
         return 1;
     }
@@ -53,35 +56,16 @@ int main(int argc, char ** argv) {
     // tokenize prompt
     auto tokens = common_tokenize(ctx, params.prompt, true);
 
-    // prepare the batch
-    llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
-    for (size_t i = 0; i < tokens.size(); i++) {
-        common_batch_add(batch, tokens[i], i, {0}, false);
-    }
-    batch.logits[batch.n_tokens - 1] = true; // generate next token
-
-    // evaluate prompt
-    llama_decode(ctx, batch);
-    n_past += batch.n_tokens;
-
-    // save state (rng, logits, embedding and kv_cache) to file
-    {
-        std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
-        const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());
-
-        FILE *fp_write = fopen("dump_state.bin", "wb");
-        fwrite(state_mem.data(), 1, written, fp_write);
-        fclose(fp_write);
-
-        fprintf(stderr, "%s : serialized state into %zd out of a maximum of %zd bytes\n", __func__, written, state_mem.size());
+    const bool save_state = true;
+    if (!common_prompt_batch_decode(ctx, tokens, n_past, params.n_batch, state_file, save_state)) {
+        return 1;
     }
 
-    // save state (last tokens)
-    const auto n_past_saved = n_past;
-
     // first run
     printf("\nfirst run: %s", params.prompt.c_str());
 
+    llama_batch batch = llama_batch_init(1, 0, 1);
+
     for (auto i = 0; i < params.n_predict; i++) {
         auto next_token     = llama_sampler_sample(smpl, ctx, -1);
         auto next_token_str = common_token_to_piece(ctx, next_token);
@@ -111,27 +95,23 @@ int main(int argc, char ** argv) {
 
     printf("\nsecond run: %s", params.prompt.c_str());
 
-    // load state (rng, logits, embedding and kv_cache) from file
-    {
-        std::vector<uint8_t> state_mem;
-
-        FILE * fp_read = fopen("dump_state.bin", "rb");
-        fseek(fp_read, 0, SEEK_END);
-        state_mem.resize(ftell(fp_read));
-        fseek(fp_read, 0, SEEK_SET);
-        const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
-        fclose(fp_read);
-
-        if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
-            fprintf(stderr, "\n%s : failed to read state\n", __func__);
-            return 1;
-        }
+    // load state from file
+    std::vector<llama_token> unused_sts(tokens.size()); // unused session tokens.
+    size_t n_token_count_out = 0;
 
-        fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
+    if (!llama_state_load_file(ctx2, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
+        fprintf(stderr, "\n%s : failed to load state\n", __func__);
+        return 1;
     }
 
+    fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
+
     // restore state (last tokens)
-    n_past = n_past_saved;
+    n_past = n_token_count_out;
+    if (!common_replay_last_token(ctx2, tokens.back(), n_past)) {
+        return 1;
+    }
+    ++n_past;
 
     // second run
     for (auto i = 0; i < params.n_predict; i++) {
@@ -160,7 +140,9 @@ int main(int argc, char ** argv) {
     }
 
     // make new context
-    llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params));
+    auto params_ctx3 = common_context_params_to_llama(params);
+    params_ctx3.n_seq_max = 2;
+    llama_context * ctx3 = llama_init_from_model(model, params_ctx3);
 
     llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
 
@@ -169,26 +151,21 @@ int main(int argc, char ** argv) {
     printf("\nsingle seq run: %s", params.prompt.c_str());
 
     // load state (rng, logits, embedding and kv_cache) from file
-    {
-        std::vector<uint8_t> state_mem;
-
-        FILE * fp_read = fopen("dump_state.bin", "rb");
-        fseek(fp_read, 0, SEEK_END);
-        state_mem.resize(ftell(fp_read));
-        fseek(fp_read, 0, SEEK_SET);
-        const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
-        fclose(fp_read);
+    n_token_count_out = 0;
 
-        if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
-            fprintf(stderr, "\n%s : failed to read state\n", __func__);
-            return 1;
-        }
-
-        fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
+    if (!llama_state_load_file(ctx3, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
+        fprintf(stderr, "\n%s : failed to load state\n", __func__);
+        return 1;
     }
 
+    fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
+
     // restore state (last tokens)
-    n_past = n_past_saved;
+    n_past = n_token_count_out;
+    if (!common_replay_last_token(ctx3, tokens.back(), n_past)) {
+        return 1;
+    }
+    ++n_past;
 
     // save seq 0 and load into seq 1
     {
index 7cd0bfc0d2405e3176a6138d614f60ecce9f6a64..98d055d34ef70a8b86aa8cec867a922a9f442328 100644 (file)
@@ -2440,64 +2440,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
         // TODO: add more model-specific info which should prevent loading the session file if not identical
     }
 
-    // write output ids
-    {
-        LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
-
-        const auto n_outputs    = this->n_outputs;
-        const auto & output_ids = this->output_ids;
-
-        std::vector<int32_t> w_output_pos;
-
-        w_output_pos.resize(n_outputs);
-
-        // build a more compact representation of the output ids
-        for (size_t i = 0; i < n_batch(); ++i) {
-            // map an output id to a position in the batch
-            int64_t pos = output_ids[i];
-            if (pos >= 0) {
-                GGML_ASSERT(pos < n_outputs);
-                w_output_pos[pos] = i;
-            }
-        }
-
-        io.write(&n_outputs, sizeof(n_outputs));
-
-        if (n_outputs) {
-            io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
-        }
-    }
-
-    // [TAG_CONTEXT_STATE_LOGITS]
-    // write logits
-    {
-        LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
-
-        const uint64_t logits_size = std::min((uint64_t) this->logits.size, (uint64_t) n_outputs * model.vocab.n_tokens());
-
-        io.write(&logits_size, sizeof(logits_size));
-
-        if (logits_size) {
-            io.write(logits.data, logits_size * sizeof(float));
-        }
-    }
-
-    // write embeddings
-    {
-        LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
-
-        const uint64_t embd_size = std::min((uint64_t) this->embd.size, (uint64_t) n_outputs * model.hparams.n_embd);
-
-        io.write(&embd_size, sizeof(embd_size));
-
-        if (embd_size) {
-            io.write(embd.data, embd_size * sizeof(float));
-        }
-    }
-
-    // TODO: handle sampling buffers and samplers state ?
-    //       https://github.com/ggml-org/llama.cpp/pull/17004
-
     if (memory != nullptr) {
         LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
         memory->state_write(io);
@@ -2523,70 +2465,6 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
         // TODO: add more info which needs to be identical but which is not verified otherwise
     }
 
-    // read output ids
-    {
-        LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);
-
-        auto n_outputs = this->n_outputs;
-        io.read_to(&n_outputs, sizeof(n_outputs));
-
-        if (n_outputs > output_reserve(n_outputs)) {
-            throw std::runtime_error("could not reserve outputs");
-        }
-
-        std::vector<int32_t> output_pos;
-
-        if (n_outputs) {
-            output_pos.resize(n_outputs);
-            io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
-
-            for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
-                int32_t id = output_pos[i];
-                if ((uint32_t) id >= n_batch()) {
-                    throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
-                }
-                this->output_ids[id] = i;
-            }
-
-            this->n_outputs = n_outputs;
-        }
-    }
-
-    // read logits
-    {
-        LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);
-
-        uint64_t logits_size;
-        io.read_to(&logits_size, sizeof(logits_size));
-
-        if (this->logits.size < logits_size) {
-            throw std::runtime_error("logits buffer too small");
-        }
-
-        if (logits_size) {
-            io.read_to(this->logits.data, logits_size * sizeof(float));
-        }
-    }
-
-    // read embeddings
-    {
-        LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);
-
-        uint64_t embd_size;
-        io.read_to(&embd_size, sizeof(embd_size));
-
-        if (this->embd.size < embd_size) {
-            throw std::runtime_error("embeddings buffer too small");
-        }
-
-        if (embd_size) {
-            io.read_to(this->embd.data, embd_size * sizeof(float));
-        }
-    }
-
-    // TODO: handle sampling buffers and samplers state ?
-    //       https://github.com/ggml-org/llama.cpp/pull/17004
-
     if (memory) {
         LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
 
index 977132756f7f9bb084782a9665610832fd2f06d8..aed2c0e38fb45e7bda66a4ee905e6f7564d90230 100644 (file)
@@ -387,6 +387,17 @@ int main(int argc, char ** argv) {
         }
 
         session_do_save = !path_session.empty() && n_match < embd_inp.size() && !params.prompt_cache_ro;
+
+        // Logits are not stored as part of the session state so we need to
+        // "replay" the last token to get logits for sampling.
+        if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) {
+            if (!common_replay_last_token(ctx, session_tokens.back(), n_match)) {
+                return 1;
+            }
+
+            session_do_save = false;
+            LOG_INF("%s: replayed last token from session\n", __func__);
+        }
     }
 
     // number of tokens to keep when resetting context
@@ -675,40 +686,27 @@ int main(int argc, char ** argv) {
             }
 
             if (!embd.empty()) {
-                int n_eval = (int) embd.size();
-                LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
-
-                GGML_ASSERT(n_eval <= params.n_batch);
-                if (llama_decode(ctx, llama_batch_get_one(embd.data(), n_eval))) {
-                    LOG_ERR("%s : failed to eval\n", __func__);
+                const bool is_last_batch = (n_consumed >= (int) embd_inp.size());
+                const bool save_now = session_do_save && is_last_batch;
+                if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, save_now)) {
                     return 1;
                 }
-
-                n_past += n_eval;
+                session_tokens.insert(session_tokens.end(), embd.begin(), embd.begin());
+                n_session_consumed = session_tokens.size();
+                session_do_save = false;
 
                 LOG_DBG("n_past = %d\n", n_past);
+
                 // Display total tokens alongside total time
                 if (params.n_print > 0 && n_past % params.n_print == 0) {
                     LOG_DBG("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
                 }
             }
-
-            if (!embd.empty() && !path_session.empty()) {
-                session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
-                n_session_consumed = session_tokens.size();
-            }
         }
 
         embd.clear();
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
-            // optionally save the session on first sample (for faster prompt loading next time)
-            if (session_do_save) {
-                session_do_save = false;
-                llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
-
-                LOG_DBG("saved session to %s\n", path_session.c_str());
-            }
 
             const llama_token id = common_sampler_sample(smpl, ctx, -1);