]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cparams : rename LLAMA_MAX_PARALLEL_SEQUENCES to LLAMA_MAX_SEQ (#14188)
authorGeorgi Gerganov <redacted>
Sun, 15 Jun 2025 07:08:58 +0000 (10:08 +0300)
committerGitHub <redacted>
Sun, 15 Jun 2025 07:08:58 +0000 (10:08 +0300)
ggml-ci

src/llama-batch.cpp
src/llama-context.cpp
src/llama-cparams.cpp
src/llama-cparams.h
src/llama-kv-cache-unified.cpp
src/llama-kv-cells.h

index 2265db9b235b8b3f6d61e5ca33156c88fde89348..a9f4a3d4c45c5da6fe8ada2d164d424eea7774e7 100644 (file)
@@ -289,10 +289,10 @@ llama_batch_allocr::llama_batch_allocr() {
     const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
     debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
 
-    seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
-    seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    seq_pos.resize(LLAMA_MAX_SEQ);
+    seq_cpl.resize(LLAMA_MAX_SEQ);
     for (auto & cur : seq_cpl) {
-        cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+        cur.resize(LLAMA_MAX_SEQ);
     }
 }
 
@@ -322,8 +322,8 @@ bool llama_batch_allocr::init(
     if (batch.seq_id) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
                     return false;
                 }
             }
@@ -355,8 +355,8 @@ bool llama_batch_allocr::init(
         pos.resize(batch.n_tokens);
 
         // initialize the starting position for each sequence based on the positions in the memory
-        llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
-        for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        llama_pos p0[LLAMA_MAX_SEQ];
+        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (!memory) {
                 p0[s] = 0;
             } else {
@@ -480,7 +480,7 @@ bool llama_batch_allocr::init(
     // consistency checks
     //
 
-    for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
         if (seq_pos[s].empty()) {
             continue;
         }
@@ -497,8 +497,8 @@ bool llama_batch_allocr::init(
     }
 
     if (memory) {
-        for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
-            for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
+        for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
+            for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
                 if (seq_cpl[s0][s1]) {
                     if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                         memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
index 47c60e960dc0199f2e034313ff8a2c506a768ec4..3a113d1bcfb2a48df98b2c1150dc717bcdc3ad8c 100644 (file)
@@ -29,8 +29,8 @@ llama_context::llama_context(
     const auto & hparams = model.hparams;
 
     cparams.n_seq_max = std::max(1u, params.n_seq_max);
-    if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
-        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
+    if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
+        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
     }
 
     cparams.n_threads        = params.n_threads;
@@ -1023,8 +1023,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
-            llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
-            for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            llama_pos pos_min[LLAMA_MAX_SEQ];
+            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
                 pos_min[s] = std::numeric_limits<llama_pos>::max();
             }
 
@@ -1035,7 +1035,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
                 pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
             }
 
-            for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
                 if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
                     continue;
                 }
index f7b36590fe3e3f0a9dc9dfee9e7f749d2610aee7..a3e7a37ee36d78fd6af7d5e3aba8762ecbb636b6 100644 (file)
@@ -1,5 +1,5 @@
 #include "llama-cparams.h"
 
 size_t llama_max_parallel_sequences(void) {
-    return LLAMA_MAX_PARALLEL_SEQUENCES;
+    return LLAMA_MAX_SEQ;
 }
index 51ebe5d17efa7212b825f47ada5e101883038a9f..118615d5bd2d59f4e640c661592cf91ee0d3fda3 100644 (file)
@@ -4,8 +4,7 @@
 
 #include <cstdint>
 
-// TODO: rename to something shorter
-#define LLAMA_MAX_PARALLEL_SEQUENCES 64
+#define LLAMA_MAX_SEQ 64
 
 struct llama_cparams {
     uint32_t n_ctx;           // context size used during inference
index d4e92eab3a1796c13153faafd705fc8594a458ab..03107057079ca6c1ff3d6a70096a56d62e89c349 100644 (file)
@@ -572,7 +572,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
         }
 
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (cells.seq_pos_min(s) < 0) {
                 continue;
             }
@@ -652,8 +652,8 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
 
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
-    llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
-    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
+    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
         seq_pos_max_rm[s] = -1;
     }
 
@@ -684,7 +684,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
     // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
     //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
     //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
         if (seq_pos_max_rm[s] == -1) {
             continue;
         }
index acf30aebec69b071bc0b36223771e71ce0a5f9fd..1d4e70f4d321249882287e0bf6b1f56f1c8110dc 100644 (file)
@@ -23,7 +23,7 @@ public:
 
         used.clear();
 
-        for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
             seq_pos[s].clear();
         }
     }
@@ -240,7 +240,7 @@ public:
     llama_seq_id seq_get(uint32_t i) const {
         assert(seq[i].count() == 1);
 
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
                 return s;
             }
@@ -253,7 +253,7 @@ public:
     // return -1 if the sequence is not present
     llama_pos seq_pos_min(llama_seq_id seq_id) const {
         assert(seq_id >= 0);
-        assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
+        assert(seq_id < LLAMA_MAX_SEQ);
 
         if (seq_pos[seq_id].empty()) {
             return -1;
@@ -266,7 +266,7 @@ public:
     // return -1 if the sequence is not present
     llama_pos seq_pos_max(llama_seq_id seq_id) const {
         assert(seq_id >= 0);
-        assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
+        assert(seq_id < LLAMA_MAX_SEQ);
 
         if (seq_pos[seq_id].empty()) {
             return -1;
@@ -384,20 +384,20 @@ private:
     //
     std::vector<llama_pos> shift;
 
-    using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
+    using bits_t = std::bitset<LLAMA_MAX_SEQ>;
 
     // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
     std::vector<bits_t> seq;
 
     // the set seq_pos[s] tells us which positions are currently present for sequence s
     // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
-    std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
+    std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
 
     // helper functions for updating `seq_pos`, once cell at a time:
 
     // remove cell i
     void seq_pos_rm(uint32_t i) {
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
                 seq_pos[s].erase(pos[i]);
             }
@@ -406,7 +406,7 @@ private:
 
     // add cell i
     void seq_pos_add(uint32_t i) {
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
                 seq_pos[s].insert(pos[i]);
             }