    // keep this buffer alive while evaluating the model
    ggml_backend_buffer_t buf_compute;
-    // create a gpt2_batch
-    // we use this object to submit token data for decoding
    const int n_parallel = params.n_parallel;
-    gpt2_batch batch = gpt2_batch_init(std::max(embd_inp.size(), (size_t)n_parallel), 0);
-
-    // evaluate the initial prompt
-    batch.n_tokens = embd_inp.size();
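+    // the batch has to be able to hold either the full prompt or one token per parallel sequence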
+    const int n_batch_max = std::max(embd_inp.size(), (size_t)n_parallel);
-    for (int32_t i = 0; i < batch.n_tokens; i++) {
-        batch.token[i]  = embd_inp[i];
-        batch.pos[i]    = i;
-        batch.seq_id[i] = 0;
-        batch.logits[i] = false;
-    }
-
-    // gpt2_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    // create a gpt2_batch
+    // we use this object to submit token data for decoding
+    gpt2_batch batch = gpt2_batch_init(n_batch_max, 0);
+    // prepare required memory and allocate the compute buffer
    struct ggml_allocr * allocr = NULL;
-    // allocate the compute buffer
    {
-        // alignment required by the backend
+        // alignment required by the backend
        size_t align = ggml_backend_get_alignment(model.backend);
        allocr = ggml_allocr_new_measure(align);
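+        // pretend the batch is at its maximum size so the measured graph reflects the worst case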
+        batch.n_tokens = n_batch_max;
+
        // create the worst case graph for memory usage estimation
        struct ggml_cgraph * gf = gpt2_graph(model, allocr, batch);
    std::vector<float> logits;
+    // evaluate the initial prompt
+    batch.n_tokens = embd_inp.size();
+
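+    // the entire prompt is evaluated as part of sequence 0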
+    for (int32_t i = 0; i < batch.n_tokens; i++) {
+        batch.token[i]  = embd_inp[i];
+        batch.pos[i]    = i;
+        batch.seq_id[i] = 0;
+        batch.logits[i] = false;
+    }
+
+    // gpt2_decode will output logits only for the last token of the prompt
+    batch.logits[batch.n_tokens - 1] = true;
+
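+    // decode the prompt; logits will hold the output for its last token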
    if (gpt2_decode(model, allocr, batch, params.n_threads, logits) != 0) {
        printf("%s: gpt2_decode() failed\n", __func__);
        return 1;