]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
batch : fix sequence id ownership (#17915)
authorGeorgi Gerganov <redacted>
Thu, 11 Dec 2025 12:29:47 +0000 (14:29 +0200)
committerGitHub <redacted>
Thu, 11 Dec 2025 12:29:47 +0000 (14:29 +0200)
* batch : fix sequence id ownage

* cont : reduce allocations

src/llama-batch.cpp
src/llama-batch.h

index 86a1a4ba187eef004e17f5a5185367b889470520..386fab04ac9c7a01dc77ef0bb7b4fcd35d81e548 100644 (file)
@@ -695,6 +695,8 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
     udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
     udata->output    .resize(n_tokens);
 
+    udata->seq_id_data.reserve(n_tokens);
+
     seq_set_t seq_set_unq;
 
     for (size_t i = 0; i < idxs.size(); ++i) {
@@ -716,11 +718,13 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
         }
 
         udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
-        udata->seq_id[i]   = batch.seq_id[idxs[i]];
         udata->output[i]   = batch.logits[idxs[i]];
 
         for (int s = 0; s < udata->n_seq_id[i]; ++s) {
-            seq_set_unq.set(udata->seq_id[i][s]);
+            const llama_seq_id seq_id = batch.seq_id[idxs[i]][s];
+
+            udata->seq_id_data.push_back(seq_id);
+            seq_set_unq.set(seq_id);
         }
 
         if (udata->output[i]) {
@@ -728,6 +732,12 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
         }
     }
 
+    llama_seq_id * seq_id_ptr = udata->seq_id_data.data();
+    for (size_t i = 0; i < idxs.size(); ++i) {
+        udata->seq_id[i] = seq_id_ptr;
+        seq_id_ptr += udata->n_seq_id[i];
+    }
+
     for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
             udata->seq_idx[s] = udata->seq_id_unq.size();
index 209cf3699de23b07b993b0971e4ca27aaa897b0b..8e6fac0efabba3d5f1b3dd528ebb9e3d2c9ec99a 100644 (file)
@@ -56,13 +56,15 @@ struct llama_ubatch {
         std::vector<float>          embd;
         std::vector<llama_pos>      pos;
         std::vector<int32_t>        n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id *> seq_id;      // these point into the seq_id_data below
         std::vector<llama_seq_id>   seq_id_unq;
         std::vector<int32_t>        seq_idx;
         std::vector<int8_t>         output;
+
+        std::vector<llama_seq_id> seq_id_data;
     };
 
-    // the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
+    // the llama_ubatch pointers above point to this data if set. otherwise - point to external non-owning data
     std::shared_ptr<data_t> data;
 };