]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama: store mrope data in KV cell (#16825)
authorXuan-Son Nguyen <redacted>
Wed, 29 Oct 2025 17:09:18 +0000 (18:09 +0100)
committerGitHub <redacted>
Wed, 29 Oct 2025 17:09:18 +0000 (18:09 +0100)
* llama: store mrope data in KV cell

* correct x,y ordering

* address review comments

* add consistency checks

* Update src/llama-kv-cache.cpp

Co-authored-by: Georgi Gerganov <redacted>
* add TODO

* fix asan error

* kv-cells : improve ext handling

* cont : fix headers

---------

Co-authored-by: Georgi Gerganov <redacted>
src/llama-batch.cpp
src/llama-batch.h
src/llama-kv-cache.cpp
src/llama-kv-cells.h
tools/mtmd/mtmd.cpp
tools/mtmd/mtmd.h

index 55d89eca0ad94938d471cdd1f6d714850e0ead79..6cb118f684e40200ee0a947d7dd2831e55b71c79 100644 (file)
@@ -215,6 +215,7 @@ bool llama_batch_allocr::init(
             /*.n_seq_tokens =*/ (uint32_t) 1,
             /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
             /*.n_seqs_unq   =*/ (uint32_t) this->seq_id_unq.size(),
+            /*.n_pos        =*/ n_pos_per_embd,
             /*.token        =*/ batch.token,
             /*.embd         =*/ batch.embd,
             /*.pos          =*/ batch.pos,
@@ -251,46 +252,58 @@ bool llama_batch_allocr::init(
     // consistency checks
     //
 
-    for (uint32_t s = 0; s < n_seq_max; ++s) {
-        if (seq_pos[s].empty()) {
-            continue;
+    if (n_pos_per_embd > 1) {
+        // M-RoPE case: allow position to "jump" forward only (non-continuous positions are allowed)
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
+            if (seq_pos[s].empty()) {
+                continue;
+            }
+
+            const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+
+            if (p0 >= 0 && p0 >= seq_pos_min(s)) {
+                LLAMA_LOG_ERROR(
+                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                        " for M-RoPE, it is required that the position satisfies: X < Y\n",
+                        __func__, s, s, p0, s, seq_pos_min(s));
+
+                return false;
+            }
         }
+    } else {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
+            if (seq_pos[s].empty()) {
+                continue;
+            }
 
-        const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+            const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
 
-        if (p0 >= 0) {
-            bool ok = true;
+            if (p0 >= 0) {
+                bool ok = true;
 
-            if (batch.token) {
                 if (seq_pos_min(s) != p0 + 1) {
                     ok = false;
                 }
-            } else {
-                assert(batch.embd);
 
-                // for embeddings (typically used as vision input), we allow them to have repeating positions
-                // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
-                if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
-                    ok = false;
+                if (!ok) {
+                    LLAMA_LOG_ERROR(
+                            "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                            " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                            " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                            " it is required that the sequence positions remain consecutive: Y = X + 1\n",
+                            __func__, s, s, p0, s, seq_pos_min(s));
+
+                    return false;
                 }
             }
 
-            if (!ok) {
-                LLAMA_LOG_ERROR(
-                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
-                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
-                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
-                        " it is required that the sequence positions remain consecutive: Y = X + 1\n",
-                        __func__, s, s, p0, s, seq_pos_min(s));
-
+            if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
+                LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
                 return false;
             }
         }
-
-        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
-            LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
-            return false;
-        }
     }
 
     if (memory) {
@@ -389,6 +402,7 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
         /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs       =*/ n_seqs,
         /*.n_seqs_unq   =*/ n_seqs,
+        /*.n_pos        =*/ n_pos_per_embd,
 
         /*.token        =*/ udata->token.data(),
         /*.embd         =*/ nullptr,
@@ -710,6 +724,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
         /*.n_seq_tokens =*/ n_tokens/n_seqs,
         /*.n_seqs       =*/ n_seqs,
         /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),
+        /*.n_pos        =*/ n_pos_per_embd,
 
         /*.token        =*/ batch.token ? udata->token.data() : nullptr,
         /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,
index 0dc8cebd2a7b3045022172dd670cd0e13b20389e..209cf3699de23b07b993b0971e4ca27aaa897b0b 100644 (file)
@@ -17,6 +17,16 @@ struct llama_ubatch {
         return b_equal_seqs != 0;
     }
 
+    // typical for M-RoPE cases:
+    //   0 - sequantial position of the tokens/embeddings in the sequence
+    //   1 - y position in the image
+    //   2 - x position in the image
+    //   3 - other
+    bool is_pos_2d() const {
+        // TODO @ngxson : we may need to check for model arch when more models use >1 positions
+        return n_pos >= 3;
+    }
+
     uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
                            //       otherwise address sanitizer complains
     // TODO: whole_seqs for embeddings?
@@ -25,6 +35,7 @@ struct llama_ubatch {
     uint32_t n_seq_tokens; // tokens per sequence set
     uint32_t n_seqs;       // sequence sets in the ubatch
     uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
+    uint32_t n_pos;        // number of position inputs for each token/embedding
 
     // seq_id_unq: unique sequence ids in the ubatch
     // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
@@ -33,7 +44,7 @@ struct llama_ubatch {
     //                          // size               | idx | val
     llama_token  *  token;      // [n_tokens]         | i   | id, token
     float        *  embd;       // [n_embd, n_tokens] | i   | embd
-    llama_pos    *  pos;        // [n_tokens]         | i   | pos
+    llama_pos    *  pos;        // [n_tokens*n_pos]   | i   | pos
     int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
     llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
     llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
index 6d5dd6051e782e386b8b3ee057d442f160828e27..17627b6ccbb1ed39d21fb16e57ff514e927b6723 100644 (file)
@@ -338,6 +338,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
             llama_pos pos   = v_cells[s0].pos_get(i);
             llama_pos shift = v_cells[s0].get_shift(i);
 
+            llama_kv_cell_ext ext = v_cells[s0].ext_get(i);
+
             if (shift != 0) {
                 pos -= shift;
                 assert(pos >= 0);
@@ -349,6 +351,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
             if (shift != 0) {
                 v_cells[s1].pos_add(i, shift);
             }
+
+            v_cells[s1].ext_set(i, ext);
         }
     }
 
@@ -383,6 +387,7 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
 
 void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
 
     auto & cells = v_cells[seq_to_stream[seq_id]];
     auto & head  = v_heads[seq_to_stream[seq_id]];
@@ -427,6 +432,7 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
 
 void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
 
     auto & cells = v_cells[seq_to_stream[seq_id]];
 
@@ -900,6 +906,14 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
 
             cells.pos_set(idx, ubatch.pos[i]);
 
+            if (ubatch.is_pos_2d()) {
+                llama_kv_cell_ext ext {
+                    /*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
+                    /*.y =*/ ubatch.pos[i + ubatch.n_tokens],
+                };
+                cells.ext_set(idx, ext);
+            }
+
             for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
                 cells.seq_add(idx, ubatch.seq_id[i][s]);
             }
@@ -1247,6 +1261,11 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
 
                 const llama_pos p1 = ubatch->pos[i];
 
+                // for M-RoPE
+                const bool is_2d = ubatch->is_pos_2d();
+                const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
+                const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens]   : 0;
+
                 const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
 
                 for (uint32_t j = 0; j < n_kv; ++j) {
@@ -1266,6 +1285,14 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
                         continue;
                     }
 
+                    // M-RoPE causal mask
+                    if (causal_attn && is_2d && p0 == p1) {
+                        const auto & p0_ext = cells.ext_get(j);
+                        if (p0_ext.is_2d_gt(p1_x, p1_y)) {
+                            continue;
+                        }
+                    }
+
                     // apply SWA if any
                     if (is_masked_swa(p0, p1)) {
                         continue;
@@ -1559,6 +1586,9 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t
             io.write(&pos,      sizeof(pos));
             io.write(&n_seq_id, sizeof(n_seq_id));
 
+            // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it
+            //       see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
+
             for (const auto & seq_id : seq_ids) {
                 io.write(&seq_id, sizeof(seq_id));
             }
@@ -1704,6 +1734,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
             return false;
         }
 
+        // TODO: we cannot yet restore llama_kv_cell_ext as the apply_ubatch() does not support it yet
+        //       see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
         apply_ubatch(sinfo, ubatch);
 
         const auto head_cur = sinfo.head();
index 8f6bf01456c8fb734230b00170f6d2b84b869d89..10063bf4272ef7e9583751c49d488c2130b40884 100644 (file)
@@ -5,9 +5,27 @@
 
 #include <bitset>
 #include <cassert>
-#include <vector>
-#include <set>
+#include <cstring>
 #include <map>
+#include <set>
+#include <vector>
+
+struct llama_kv_cell_ext {
+    // 2D spatial positions, typically used for M-RoPE
+    llama_pos x = 0;
+    llama_pos y = 0;
+
+    // return true if the current 2D spatial position is greater than other
+    bool is_2d_gt(llama_pos ox, llama_pos oy) const {
+        return (y > oy) || (y == oy && x > ox);
+    }
+
+    void reset() {
+        static_assert(std::is_trivially_copyable_v<llama_kv_cell_ext>);
+
+        memset(this, 0, sizeof(*this));
+    }
+};
 
 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
@@ -16,6 +34,7 @@ public:
     void reset() {
         for (uint32_t i = 0; i < pos.size(); ++i) {
             pos[i]   = -1;
+            ext[i].reset();
             shift[i] =  0;
             seq[i].reset();
         }
@@ -43,6 +62,7 @@ public:
 
     void resize(uint32_t n) {
         pos.resize(n);
+        ext.resize(n);
         shift.resize(n);
         seq.resize(n);
 
@@ -108,6 +128,7 @@ public:
             const auto idx = i + j;
 
             res.pos[j] = pos[idx];
+            res.ext[j] = ext[idx];
             res.seq[j] = seq[idx];
 
             assert(shift[idx] == 0);
@@ -126,6 +147,7 @@ public:
             const auto idx = idxs[j];
 
             res.pos[j] = pos[idx];
+            res.ext[j] = ext[idx];
             res.seq[j] = seq[idx];
 
             assert(shift[idx] == 0);
@@ -154,6 +176,7 @@ public:
             }
 
             pos[idx] = other.pos[j];
+            ext[idx] = other.ext[j];
             seq[idx] = other.seq[j];
 
             if (pos[idx] != -1) {
@@ -184,6 +207,7 @@ public:
             }
 
             pos[idx] = other.pos[j];
+            ext[idx] = other.ext[j];
             seq[idx] = other.seq[j];
 
             if (pos[idx] != -1) {
@@ -203,6 +227,7 @@ public:
         seq[i].reset();
 
         pos[i] = -1;
+        ext[i].reset();
         shift[i] = 0;
 
         used.erase(i);
@@ -221,6 +246,7 @@ public:
 
         if (seq[i].none()) {
             pos[i] = -1;
+            ext[i].reset();
             shift[i] = 0;
 
             used.erase(i);
@@ -250,6 +276,7 @@ public:
             seq[i].reset();
 
             pos[i] = -1;
+            ext[i].reset();
             shift[i] = 0;
 
             used.erase(i);
@@ -340,6 +367,13 @@ public:
         return pos[i];
     }
 
+    const llama_kv_cell_ext & ext_get(uint32_t i) const {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        return ext[i];
+    }
+
     // note: call only if the cell is not empty
     llama_pos get_shift(uint32_t i) const {
         assert(i < pos.size());
@@ -368,6 +402,11 @@ public:
         used.insert(i);
     }
 
+    void ext_set(uint32_t i, llama_kv_cell_ext p) {
+        assert(i < ext.size());
+        ext[i] = p;
+    }
+
     // pos[i] = pos[i] + d
     // sets "has_shift" to true
     // note: call only if the cell is not empty
@@ -424,6 +463,9 @@ private:
 
     std::vector<llama_pos> pos;
 
+    // stores extra info per cell
+    std::vector<llama_kv_cell_ext> ext;
+
     // this array accumulates any applied shifts to the pos array since the last reset_shift() call
     // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
     //
index 3b901bfac82159e3b52787ae27f11507d1815440..1a48f903efac93022b3f32825ff0ba56d0495a08 100644 (file)
@@ -5,6 +5,15 @@
 
 #include "llama.h"
 
+// fix problem with std::min and std::max
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+#   define NOMINMAX
+#endif
+#include <windows.h>
+#endif
+
 #include <algorithm>
 #include <cerrno>
 #include <cstdio>
@@ -1031,7 +1040,9 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
 
 llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
     if (image_tokens->use_mrope_pos) {
-        return 1; // for M-RoPE, the whole image is 1 in temporal dimension
+        // for M-RoPE, temporal dimension = max(t,h,w)
+        // t is omitted as we don't support video input
+        return std::max(image_tokens->nx, image_tokens->ny);
     }
     return image_tokens->n_tokens();
 }
index f4ea07d3ad521328c8f7e6987912dc9db60bcf1d..0b5d2ba0c763418fb0d51b6291b51cc46c8b0ca7 100644 (file)
@@ -153,7 +153,7 @@ MTMD_API const mtmd_image_tokens *  mtmd_input_chunk_get_tokens_image(const mtmd
 MTMD_API size_t                     mtmd_input_chunk_get_n_tokens    (const mtmd_input_chunk * chunk);
 // returns nullptr for ID on text chunk
 MTMD_API const char *               mtmd_input_chunk_get_id          (const mtmd_input_chunk * chunk);
-// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
+// number of temporal positions (equals to max(t,h,w) for M-RoPE; equals to n_tokens otherwise)
 MTMD_API llama_pos                  mtmd_input_chunk_get_n_pos       (const mtmd_input_chunk * chunk);
 
 // in case you want to use custom logic to handle the chunk (i.e. KV cache management)
@@ -171,7 +171,7 @@ MTMD_API size_t       mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * i
 MTMD_API size_t       mtmd_image_tokens_get_nx      (const mtmd_image_tokens * image_tokens);
 MTMD_API size_t       mtmd_image_tokens_get_ny      (const mtmd_image_tokens * image_tokens);
 MTMD_API const char * mtmd_image_tokens_get_id      (const mtmd_image_tokens * image_tokens); // TODO: deprecate
-// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
+// number of temporal positions (equals to max(t,h,w) for M-RoPE; equals to n_tokens otherwise)
 MTMD_API llama_pos    mtmd_image_tokens_get_n_pos   (const mtmd_image_tokens * image_tokens); // TODO: deprecate
 
 // tokenize an input text prompt and a list of bitmaps (images/audio)