From: Georgi Gerganov
Date: Fri, 20 Oct 2023 07:12:39 +0000 (+0300)
Subject: gpt-2 : fix allocr worst-case when n_parallel > prompt size
X-Git-Tag: upstream/0.0.1642~1212
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=53590e37e2665814bc3ba1ffe46a47d2e10868ed;p=pkg%2Fggml%2Fsources%2Fggml

gpt-2 : fix allocr worst-case when n_parallel > prompt size
---

diff --git a/examples/gpt-2/main-batched.cpp b/examples/gpt-2/main-batched.cpp
index 76066ce8..3ba665da 100644
--- a/examples/gpt-2/main-batched.cpp
+++ b/examples/gpt-2/main-batched.cpp
@@ -1032,31 +1032,22 @@ int main(int argc, char ** argv) {
     // keep this buffer alive while evaluating the model
     ggml_backend_buffer_t buf_compute;
 
-    // create a gpt2_batch
-    // we use this object to submit token data for decoding
     const int n_parallel = params.n_parallel;
-    gpt2_batch batch = gpt2_batch_init(std::max(embd_inp.size(), (size_t)n_parallel), 0);
-
-    // evaluate the initial prompt
-    batch.n_tokens = embd_inp.size();
+    const int n_batch_max = std::max(embd_inp.size(), (size_t)n_parallel);
 
-    for (int32_t i = 0; i < batch.n_tokens; i++) {
-        batch.token[i]  = embd_inp[i];
-        batch.pos[i]    = i;
-        batch.seq_id[i] = 0;
-        batch.logits[i] = false;
-    }
-
-    // gpt2_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    // create a gpt2_batch
+    // we use this object to submit token data for decoding
+    gpt2_batch batch = gpt2_batch_init(n_batch_max, 0);
 
+    // prepare required memory and allocate the compute buffer
     struct ggml_allocr * allocr = NULL;
-    // allocate the compute buffer
     {
-        // alignment required by the backend
+        // alignment required by the backend
         size_t align = ggml_backend_get_alignment(model.backend);
         allocr = ggml_allocr_new_measure(align);
 
+        batch.n_tokens = n_batch_max;
+
         // create the worst case graph for memory usage estimation
         struct ggml_cgraph * gf = gpt2_graph(model, allocr, batch);
 
@@ -1076,6 +1067,19 @@ int main(int argc, char ** argv) {
 
     std::vector<float> logits;
 
+    // evaluate the initial prompt
+    batch.n_tokens = embd_inp.size();
+
+    for (int32_t i = 0; i < batch.n_tokens; i++) {
+        batch.token[i]  = embd_inp[i];
+        batch.pos[i]    = i;
+        batch.seq_id[i] = 0;
+        batch.logits[i] = false;
+    }
+
+    // gpt2_decode will output logits only for the last token of the prompt
+    batch.logits[batch.n_tokens - 1] = true;
+
     if (gpt2_decode(model, allocr, batch, params.n_threads, logits) != 0) {
         printf("%s: gpt2_decode() failed\n", __func__);
         return 1;
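
Why the fix matters: ggml_allocr sizes the compute buffer by allocating a
worst-case graph against a virtual "measure" buffer, and that estimate only
holds if no later graph is larger. Before this commit the measurement used
batch.n_tokens = embd_inp.size() (the prompt length), but after the prompt
each gpt2_decode step submits up to n_parallel tokens (one per sequence), so
when n_parallel exceeds the prompt size the real graph could outgrow the
measured buffer. The sketch below condenses the measure-then-allocate flow
after the fix; it reuses gpt2_graph/gpt2_batch from this example and assumes
the pre-gallocr ggml_allocr calls of this era (ggml_allocr_alloc_graph,
ggml_allocr_free, ggml_backend_alloc_buffer, ggml_allocr_new_from_buffer) —
a sketch, not the exact file contents.

    // the largest batch ever decoded: the whole prompt at once, or one
    // token per parallel sequence during generation
    const int n_batch_max = std::max(embd_inp.size(), (size_t)n_parallel);

    gpt2_batch batch = gpt2_batch_init(n_batch_max, 0);

    // 1) measure: allocate the worst-case graph against a virtual buffer
    size_t align = ggml_backend_get_alignment(model.backend);
    struct ggml_allocr * allocr = ggml_allocr_new_measure(align);

    batch.n_tokens = n_batch_max; // worst case, not embd_inp.size()
    struct ggml_cgraph * gf = gpt2_graph(model, allocr, batch);
    size_t mem_size = ggml_allocr_alloc_graph(allocr, gf);
    ggml_allocr_free(allocr);

    // 2) allocate: a real buffer of the measured size, with a fresh
    //    allocator over it for the actual decode calls
    ggml_backend_buffer_t buf_compute = ggml_backend_alloc_buffer(model.backend, mem_size);
    allocr = ggml_allocr_new_from_buffer(buf_compute);

Each subsequent gpt2_decode resets this allocator and re-allocates its graph
from the same buffer, so measuring with n_batch_max guarantees that both the
prompt pass and the widest n_parallel-token generation pass fit.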