};
}
-struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) {
+struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
if (embd) {
- batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd);
+ batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
} else {
- batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
+ batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
}
- batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
- batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
- batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
- for (int i = 0; i < n_tokens; ++i) {
+ batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
+ batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
+ batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
+ for (int i = 0; i < n_tokens_alloc; ++i) {
batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
}
- batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+ batch.seq_id[n_tokens_alloc] = nullptr;
+
+ batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
return batch;
}
if (batch.pos) free(batch.pos);
if (batch.n_seq_id) free(batch.n_seq_id);
if (batch.seq_id) {
- for (int i = 0; i < batch.n_tokens; ++i) {
+ for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
free(batch.seq_id[i]);
}
free(batch.seq_id);