]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
gpt-2 : use ggml-alloc (#486)
authorslaren <redacted>
Mon, 28 Aug 2023 08:31:39 +0000 (10:31 +0200)
committerGitHub <redacted>
Mon, 28 Aug 2023 08:31:39 +0000 (11:31 +0300)
* gpt-2 : use ggml-alloc

* move function comment to gpt2_eval

* gpt-2 : clarifying comment

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/gpt-2/main.cpp
include/ggml/ggml.h
src/ggml.c

index dd134d482072bf6d094843bdd3fa824dd87fadab..87917e3d04aa4f47849952b4b4b900ff87d09018 100644 (file)
@@ -1,4 +1,5 @@
 #include "ggml/ggml.h"
+#include "ggml/ggml-alloc.h"
 
 #include "common.h"
 #include "common-ggml.h"
@@ -380,21 +381,12 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
     return true;
 }
 
-// evaluate the transformer
-//
-//   - model:     the model
-//   - n_threads: number of threads to use
-//   - n_past:    the context size so far
-//   - embd_inp:  the embeddings of the tokens in the context
-//   - embd_w:    the predicted logits for the next token
-//
-bool gpt2_eval(
+// build the computation graph
+struct ggml_cgraph * gpt2_graph(
         const gpt2_model & model,
-        const int n_threads,
+        struct ggml_allocr * allocr,
         const int n_past,
-        const std::vector<gpt_vocab::id> & embd_inp,
-              std::vector<float>         & embd_w,
-              size_t                     & mem_per_token) {
+        const std::vector<gpt_vocab::id> & embd_inp) {
     const int N = embd_inp.size();
 
     const auto & hparams = model.hparams;
@@ -403,39 +395,41 @@ bool gpt2_eval(
     const int n_layer = hparams.n_layer;
     const int n_ctx   = hparams.n_ctx;
     const int n_head  = hparams.n_head;
-    const int n_vocab = hparams.n_vocab;
 
-    static size_t buf_size = 256u*1024*1024;
-    static void * buf = malloc(buf_size);
-
-    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
-        const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
-        //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
-
-        // reallocate
-        buf_size = buf_size_new;
-        buf = realloc(buf, buf_size);
-        if (buf == nullptr) {
-            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
-            return false;
-        }
-    }
+    // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
+    static size_t buf_size = ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead();
+    static std::vector<uint8_t> buf(buf_size);
 
     struct ggml_init_params params = {
         /*.mem_size   =*/ buf_size,
-        /*.mem_buffer =*/ buf,
-        /*.no_alloc   =*/ false,
+        /*.mem_buffer =*/ buf.data(),
+        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph()
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf = {};
+
+    struct ggml_cgraph  * gf = ggml_new_graph(ctx0);
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-    memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
+    ggml_allocr_alloc(allocr, embd);
+
+    // avoid writing to tensors if we are only measuring the memory usage
+    if (!ggml_allocr_is_measure(allocr)) {
+        memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
+    }
 
     struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-    for (int i = 0; i < N; ++i) {
-        ((int32_t *) position->data)[i] = n_past + i;
+    ggml_allocr_alloc(allocr, position);
+    if (!ggml_allocr_is_measure(allocr)) {
+        for (int i = 0; i < N; ++i) {
+            ((int32_t *) position->data)[i] = n_past + i;
+        }
+    }
+
+    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+    ggml_allocr_alloc(allocr, KQ_scale);
+    if (!ggml_allocr_is_measure(allocr)) {
+        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
     }
 
     // wte + wpe
@@ -490,8 +484,8 @@ bool gpt2_eval(
                 struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
                 struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
 
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
             }
 
             // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
@@ -531,18 +525,17 @@ bool gpt2_eval(
             // KQ_scaled = KQ / sqrt(n_embd/n_head)
             // [n_past + N, N, 12]
             struct ggml_tensor * KQ_scaled =
-                ggml_scale_inplace(ctx0,
+                ggml_scale(ctx0,
                         KQ,
-                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
-                        );
+                        KQ_scale);
 
             // KQ_masked = mask_past(KQ_scaled)
             // [n_past + N, N, 12]
-            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
 
             // KQ = soft_max(KQ_masked)
             // [n_past + N, N, 12]
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
             // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
             // [n_past + N, 64, 12]
@@ -669,17 +662,60 @@ bool gpt2_eval(
     inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
 
     // logits -> probs
-    //inpL = ggml_soft_max_inplace(ctx0, inpL);
+    //inpL = ggml_soft_max(ctx0, inpL);
+
+    ggml_build_forward_expand(gf, inpL);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
+// evaluate the transformer
+//
+//   - model:     the model
+//   - allocr:    ggml_allocr to use to allocate the compute buffer
+//   - n_threads: number of threads to use
+//   - n_past:    the context size so far
+//   - embd_inp:  the embeddings of the tokens in the context
+//   - embd_w:    the predicted logits for the next token
+//
+bool gpt2_eval(
+        const gpt2_model & model,
+        struct ggml_allocr * allocr,
+        const int n_threads,
+        const int n_past,
+        const std::vector<gpt_vocab::id> & embd_inp,
+              std::vector<float>         & embd_w) {
+    const int N = embd_inp.size();
+
+    const auto & hparams = model.hparams;
+
+    const int n_vocab = hparams.n_vocab;
+
+    // reset the allocator to free all the memory allocated during the previous inference
+    ggml_allocr_reset(allocr);
+
+    struct ggml_cgraph * gf = gpt2_graph(model, allocr, n_past, embd_inp);
+
+    // allocate tensors
+    ggml_allocr_alloc_graph(allocr, gf);
 
     // run the computation
-    ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);
+    static std::vector<uint8_t> work_buffer;
+    work_buffer.resize(plan.work_size);
+    plan.work_data = work_buffer.data();
+    ggml_graph_compute(gf, &plan);
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print   (&gf);
     //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
     //}
 
+    // in this case, the output tensor is the last one in the graph
+    struct ggml_tensor * inpL = gf->nodes[gf->n_nodes - 1];
+
     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
 
@@ -687,13 +723,6 @@ bool gpt2_eval(
     embd_w.resize(n_vocab);
     memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
 
-    if (mem_per_token == 0) {
-        mem_per_token = ggml_used_mem(ctx0)/N;
-    }
-    //printf("used_mem = %zu\n", ggml_used_mem(ctx0));
-
-    ggml_free(ctx0);
-
     return true;
 }
 
@@ -739,6 +768,30 @@ int main(int argc, char ** argv) {
         test_gpt_tokenizer(vocab, params.token_test);
     }
 
+    // keep this buffer alive while evaluating the model
+    std::vector<uint8_t> compute_buffer;
+
+    struct ggml_allocr * allocr = NULL;
+    // allocate the compute buffer
+    {
+        allocr = ggml_allocr_new_measure(GGML_MEM_ALIGN);
+
+        // create the worst case graph for memory usage estimation
+        int n_tokens = std::min(model.hparams.n_ctx, params.n_batch);
+        int n_past = model.hparams.n_ctx - n_tokens;
+        struct ggml_cgraph * gf = gpt2_graph(model, allocr, n_past, std::vector<gpt_vocab::id>(n_tokens, 0));
+
+        // compute the required memory
+        size_t mem_size = ggml_allocr_alloc_graph(allocr, gf) + GGML_MEM_ALIGN;
+
+        // recreate the allocator with the required memory
+        ggml_allocr_free(allocr);
+        compute_buffer.resize(mem_size);
+        allocr = ggml_allocr_new(compute_buffer.data(), mem_size, GGML_MEM_ALIGN);
+
+        fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0/1024.0);
+    }
+
     int n_past = 0;
 
     int64_t t_sample_us  = 0;
@@ -762,16 +815,12 @@ int main(int argc, char ** argv) {
     // this reduces the memory usage during inference, at the cost of a bit of speed at the beginning
     std::vector<gpt_vocab::id> embd;
 
-    // determine the required inference memory per token:
-    size_t mem_per_token = 0;
-    gpt2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
-
     for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
         // predict
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
 
-            if (!gpt2_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
+            if (!gpt2_eval(model, allocr, params.n_threads, n_past, embd, logits)) {
                 printf("Failed to predict\n");
                 return 1;
             }
@@ -830,7 +879,6 @@ int main(int argc, char ** argv) {
         const int64_t t_main_end_us = ggml_time_us();
 
         printf("\n\n");
-        printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
         printf("%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
         printf("%s:   sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
         printf("%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
index 2074fd04b0f5c36338c613f250d468ddbb4b3ef8..1baa6cea133794744969f79d8e79d5067bc4a510 100644 (file)
 #define GGML_MAX_OP_PARAMS     32
 #define GGML_DEFAULT_N_THREADS 4
 
+#if UINTPTR_MAX == 0xFFFFFFFF
+    #define GGML_MEM_ALIGN 4
+#else
+    #define GGML_MEM_ALIGN 16
+#endif
 
 #define GGML_EXIT_SUCCESS 0
 #define GGML_EXIT_ABORTED 1
index af031cc2363755d49a01e2cb862bea82d50a3057..5922c330bd2eee4d81cdea5f03f236fb2dc993c5 100644 (file)
@@ -157,12 +157,6 @@ typedef void * thread_ret_t;
 //#define GGML_SOFT_MAX_ACCELERATE
 #endif
 
-#if UINTPTR_MAX == 0xFFFFFFFF
-    #define GGML_MEM_ALIGN 4
-#else
-    #define GGML_MEM_ALIGN 16
-#endif
-
 //
 // logging
 //