git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kv-cache : fix M-RoPE checkpoints (#20132)
author: Georgi Gerganov <redacted>
Fri, 6 Mar 2026 06:46:51 +0000 (08:46 +0200)
committer: GitHub <redacted>
Fri, 6 Mar 2026 06:46:51 +0000 (08:46 +0200)
src/llama-batch.cpp
src/llama-kv-cache.cpp

index 386fab04ac9c7a01dc77ef0bb7b4fcd35d81e548..6bf76939cddcbb7c06232ad7e88bf47e5bd57bc0 100644 (file)
@@ -394,11 +394,13 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
     clear();
     split_reset();
 
+    const int64_t n_pos_all = (int64_t) n_tokens*n_pos_per_embd;
+
     auto udata = std::make_shared<llama_ubatch::data_t>();
 
     udata->token     .resize(n_tokens);
     udata->embd      .clear();
-    udata->pos       .resize(n_tokens);
+    udata->pos       .resize(n_pos_all);
     udata->n_seq_id  .resize(n_tokens);
     udata->seq_id    .resize(n_tokens);
     udata->seq_id_unq.resize(0);
index 4031bafe9ecf101c858e57ae7cf5ea339ac604d2..d80e8a70bc277301539757d97d1206a233290760 100644 (file)
@@ -1760,8 +1760,10 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t
             io.write(&pos,      sizeof(pos));
             io.write(&n_seq_id, sizeof(n_seq_id));
 
-            // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it
-            //       see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
+            if (hparams.n_pos_per_embd() > 1) {
+                const llama_kv_cell_ext ext = cells.ext_get(i);
+                io.write(&ext, sizeof(ext));
+            }
 
             for (const auto & seq_id : seq_ids) {
                 io.write(&seq_id, sizeof(seq_id));
@@ -1895,6 +1897,14 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
                 return false;
             }
 
+            if (hparams.n_pos_per_embd() > 1) {
+                llama_kv_cell_ext ext;
+                io.read_to(&ext, sizeof(ext));
+
+                ubatch.pos[i + ubatch.n_tokens]   = ext.y;
+                ubatch.pos[i + ubatch.n_tokens*2] = ext.x;
+            }
+
             // read the sequence id, but directly discard it - we will use dest_seq_id instead
             {
                 llama_seq_id seq_id;