]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
context : fix index overflow on huge outputs (#15080)
authorcompilade <redacted>
Tue, 5 Aug 2025 09:27:45 +0000 (05:27 -0400)
committerGitHub <redacted>
Tue, 5 Aug 2025 09:27:45 +0000 (11:27 +0200)
* context : fix overflow when re-ordering huge outputs

* context : fix logits size overflow for huge batches

src/llama-context.cpp

index 958bcc0477f7b4d47dce72b559badae1c6b8e799..26a5cf9c3f8dbc3550c77262a17420675b91b3fa 100644 (file)
@@ -786,7 +786,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     const auto & hparams = model.hparams;
 
     const int64_t n_embd  = hparams.n_embd;
-    const int32_t n_vocab = model.vocab.n_tokens();
+    const int64_t n_vocab = model.vocab.n_tokens();
 
     // note: during encode, we always pass the full sequence starting from pos = 0
     if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
@@ -959,7 +959,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
 
-    const int32_t n_vocab = vocab.n_tokens();
+    const int64_t n_vocab = vocab.n_tokens();
     const int64_t n_embd  = hparams.n_embd;
 
     // when computing embeddings, all tokens are output
@@ -1328,21 +1328,21 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
 }
 
 void llama_context::output_reorder() {
-    const uint32_t n_vocab = model.vocab.n_tokens();
+    const uint64_t n_vocab = model.vocab.n_tokens();
     const uint64_t n_embd  = model.hparams.n_embd;
 
-    for (uint32_t s = 0; s < output_swaps.size(); ++s) {
-        const uint32_t i0 = output_swaps[s].i0;
-        const uint32_t i1 = output_swaps[s].i1;
+    for (size_t s = 0; s < output_swaps.size(); ++s) {
+        const uint64_t i0 = output_swaps[s].i0;
+        const uint64_t i1 = output_swaps[s].i1;
 
         if (logits_size > 0) {
-            for (uint32_t k = 0; k < n_vocab; k++) {
+            for (uint64_t k = 0; k < n_vocab; k++) {
                 std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
             }
         }
 
         if (embd_size > 0) {
-            for (uint32_t k = 0; k < n_embd; k++) {
+            for (uint64_t k = 0; k < n_embd; k++) {
                 std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
             }
         }