]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
gpt-2 : fix allocr worst-case when n_parallel > prompt size
authorGeorgi Gerganov <redacted>
Fri, 20 Oct 2023 07:12:39 +0000 (10:12 +0300)
committerGeorgi Gerganov <redacted>
Fri, 20 Oct 2023 07:12:39 +0000 (10:12 +0300)
examples/gpt-2/main-batched.cpp

index 76066ce8a28cd243d44c630660afcccacaae457b..3ba665da2da2a0c35651e61627aeb7dac5601790 100644 (file)
@@ -1032,31 +1032,22 @@ int main(int argc, char ** argv) {
     // keep this buffer alive while evaluating the model
     ggml_backend_buffer_t buf_compute;
 
-    // create a gpt2_batch
-    // we use this object to submit token data for decoding
     const int n_parallel = params.n_parallel;
-    gpt2_batch batch = gpt2_batch_init(std::max(embd_inp.size(), (size_t)n_parallel), 0);
-
-    // evaluate the initial prompt
-    batch.n_tokens = embd_inp.size();
+    const int n_batch_max = std::max(embd_inp.size(), (size_t)n_parallel);
 
-    for (int32_t i = 0; i < batch.n_tokens; i++) {
-        batch.token[i]  = embd_inp[i];
-        batch.pos[i]    = i;
-        batch.seq_id[i] = 0;
-        batch.logits[i] = false;
-    }
-
-    // gpt2_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    // create a gpt2_batch
+    // we use this object to submit token data for decoding
+    gpt2_batch batch = gpt2_batch_init(n_batch_max, 0);
 
+    // prepare required memory and allocate the compute buffer
     struct ggml_allocr * allocr = NULL;
-    // allocate the compute buffer
     {
-         // alignment required by the backend
+        // alignment required by the backend
         size_t align = ggml_backend_get_alignment(model.backend);
         allocr = ggml_allocr_new_measure(align);
 
+        batch.n_tokens = n_batch_max;
+
         // create the worst case graph for memory usage estimation
         struct ggml_cgraph * gf = gpt2_graph(model, allocr, batch);
 
@@ -1076,6 +1067,19 @@ int main(int argc, char ** argv) {
 
     std::vector<float> logits;
 
+    // evaluate the initial prompt
+    batch.n_tokens = embd_inp.size();
+
+    for (int32_t i = 0; i < batch.n_tokens; i++) {
+        batch.token[i]  = embd_inp[i];
+        batch.pos[i]    = i;
+        batch.seq_id[i] = 0;
+        batch.logits[i] = false;
+    }
+
+    // gpt2_decode will output logits only for the last token of the prompt
+    batch.logits[batch.n_tokens - 1] = true;
+
     if (gpt2_decode(model, allocr, batch, params.n_threads, logits) != 0) {
         printf("%s: gpt2_decode() failed\n", __func__);
         return 1;