From: compilade
Date: Tue, 5 Aug 2025 09:27:45 +0000 (-0400)
Subject: context : fix index overflow on huge outputs (#15080)
X-Git-Tag: upstream/0.0.6164~74
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=ee3a9fcf88fe5b5e1213711e05861b83cd4fdfe6;p=pkg%2Fggml%2Fsources%2Fllama.cpp

context : fix index overflow on huge outputs (#15080)

* context : fix overflow when re-ordering huge outputs

* context : fix logits size overflow for huge batches
---

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 958bcc04..26a5cf9c 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -786,7 +786,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     const auto & hparams = model.hparams;
 
     const int64_t n_embd  = hparams.n_embd;
-    const int32_t n_vocab = model.vocab.n_tokens();
+    const int64_t n_vocab = model.vocab.n_tokens();
 
     // note: during encode, we always pass the full sequence starting from pos = 0
     if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
@@ -959,7 +959,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
 
-    const int32_t n_vocab = vocab.n_tokens();
+    const int64_t n_vocab = vocab.n_tokens();
     const int64_t n_embd  = hparams.n_embd;
 
     // when computing embeddings, all tokens are output
@@ -1328,21 +1328,21 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
 }
 
 void llama_context::output_reorder() {
-    const uint32_t n_vocab = model.vocab.n_tokens();
+    const uint64_t n_vocab = model.vocab.n_tokens();
     const uint64_t n_embd  = model.hparams.n_embd;
 
-    for (uint32_t s = 0; s < output_swaps.size(); ++s) {
-        const uint32_t i0 = output_swaps[s].i0;
-        const uint32_t i1 = output_swaps[s].i1;
+    for (size_t s = 0; s < output_swaps.size(); ++s) {
+        const uint64_t i0 = output_swaps[s].i0;
+        const uint64_t i1 = output_swaps[s].i1;
 
         if (logits_size > 0) {
-            for (uint32_t k = 0; k < n_vocab; k++) {
+            for (uint64_t k = 0; k < n_vocab; k++) {
                 std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
             }
         }
 
         if (embd_size > 0) {
-            for (uint32_t k = 0; k < n_embd; k++) {
+            for (uint64_t k = 0; k < n_embd; k++) {
                 std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
             }
         }
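
The overflow being fixed comes from 32-bit index arithmetic in expressions like logits[i1*n_vocab + k]: once an output row index times the vocabulary size no longer fits in 32 bits, the flat index wraps and the swap touches the wrong row. Below is a minimal standalone sketch, not part of the patch and using hypothetical sizes, that shows the wrap and the 64-bit widening this commit applies.

// Minimal sketch of the wrap-around (not from the patch; sizes are hypothetical).
// With a ~128k vocab, roughly 33k+ output rows are enough for row*n_vocab to
// exceed UINT32_MAX when the multiplication is carried out in 32 bits.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_vocab = 128256; // hypothetical vocabulary size
    const uint32_t i1      = 40000;  // hypothetical output row being swapped

    // pre-patch behaviour: all operands are 32-bit, so the product wraps modulo 2^32
    const uint32_t wrapped = i1 * n_vocab;

    // post-patch behaviour: n_vocab (and the loop counters) are 64-bit,
    // so the index is computed in 64 bits and addresses the intended row
    const uint64_t widened = (uint64_t) i1 * n_vocab;

    printf("32-bit index: %u\n",   wrapped);                      // 835272704 (wrong)
    printf("64-bit index: %llu\n", (unsigned long long) widened); // 5130240000
    return 0;
}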