#include <cassert>
#include <cstring>
#include <algorithm>
+#include <sstream>
llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
// clear empty sequences
);
}
-llama_batch_allocr::llama_batch_allocr() = default;
+llama_batch_allocr::llama_batch_allocr() {
+ const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
+ debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
+}
bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
clear();
n_outputs += batch.logits[i] != 0;
}
+ if (debug > 0) {
+ LLAMA_LOG_DEBUG("%s: input batch info (p0 = %d):\n", __func__, p0);
+ LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
+ LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token);
+ LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd);
+ LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos);
+ LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) batch.n_seq_id);
+ LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) batch.seq_id);
+ LLAMA_LOG_DEBUG("%s: logits = %p\n", __func__, (void *) batch.logits);
+ LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
+
+ if (debug > 1) {
+ int seq_id_max = 0;
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
+ for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+ for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+ seq_id_max = std::max(seq_id_max, batch.seq_id[i][s]);
+ }
+ }
+ }
+ ++seq_id_max;
+
+ LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
+ std::vector<int8_t> seq_id(seq_id_max);
+
+ for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+ seq_id[batch.seq_id[i][s]] = 1;
+ }
+
+ std::stringstream ss;
+ for (int s = 0; s < seq_id_max; ++s) {
+ if (seq_id[s]) {
+ ss << s%10;
+ } else {
+ ss << ".";
+ }
+ }
+
+ LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
+ __func__, i, batch.token[i], vocab.token_to_piece(batch.token[i]).c_str(),
+ batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
+ }
+ LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+ }
+ }
+
return true;
}