From: slaren
Date: Mon, 28 Aug 2023 08:31:39 +0000 (+0200)
Subject: gpt-2 : use ggml-alloc (#486)
X-Git-Tag: upstream/0.0.1642~1265
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=3e551c0bda2a1c47f88f178fa3a79ea2c80d15ed;p=pkg%2Fggml%2Fsources%2Fggml

gpt-2 : use ggml-alloc (#486)

* gpt-2 : use ggml-alloc

* move function comment to gpt2_eval

* gpt-2 : clarifying comment

---------

Co-authored-by: Georgi Gerganov
---

diff --git a/examples/gpt-2/main.cpp b/examples/gpt-2/main.cpp
index dd134d48..87917e3d 100644
--- a/examples/gpt-2/main.cpp
+++ b/examples/gpt-2/main.cpp
@@ -1,4 +1,5 @@
 #include "ggml/ggml.h"
+#include "ggml/ggml-alloc.h"
 
 #include "common.h"
 #include "common-ggml.h"
@@ -380,21 +381,12 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) {
     return true;
 }
 
-// evaluate the transformer
-//
-//   - model:     the model
-//   - n_threads: number of threads to use
-//   - n_past:    the context size so far
-//   - embd_inp:  the embeddings of the tokens in the context
-//   - embd_w:    the predicted logits for the next token
-//
-bool gpt2_eval(
+// build the computation graph
+struct ggml_cgraph * gpt2_graph(
         const gpt2_model & model,
-        const int n_threads,
+        struct ggml_allocr * allocr,
         const int n_past,
-        const std::vector<gpt_vocab::id> & embd_inp,
-              std::vector<float>         & embd_w,
-              size_t                     & mem_per_token) {
+        const std::vector<gpt_vocab::id> & embd_inp) {
     const int N = embd_inp.size();
 
     const auto & hparams = model.hparams;
@@ -403,39 +395,41 @@ bool gpt2_eval(
     const int n_layer = hparams.n_layer;
     const int n_ctx   = hparams.n_ctx;
     const int n_head  = hparams.n_head;
-    const int n_vocab = hparams.n_vocab;
 
-    static size_t buf_size = 256u*1024*1024;
-    static void * buf = malloc(buf_size);
-
-    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
-        const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
-        //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
-
-        // reallocate
-        buf_size = buf_size_new;
-        buf = realloc(buf, buf_size);
-        if (buf == nullptr) {
-            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
-            return false;
-        }
-    }
+    // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
+    static size_t buf_size = ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead();
+    static std::vector<uint8_t> buf(buf_size);
 
     struct ggml_init_params params = {
         /*.mem_size   =*/ buf_size,
-        /*.mem_buffer =*/ buf,
-        /*.no_alloc   =*/ false,
+        /*.mem_buffer =*/ buf.data(),
+        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph()
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf = {};
+
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-    memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
+    ggml_allocr_alloc(allocr, embd);
+
+    // avoid writing to tensors if we are only measuring the memory usage
+    if (!ggml_allocr_is_measure(allocr)) {
+        memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
+    }
 
     struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-    for (int i = 0; i < N; ++i) {
-        ((int32_t *) position->data)[i] = n_past + i;
+    ggml_allocr_alloc(allocr, position);
+    if (!ggml_allocr_is_measure(allocr)) {
+        for (int i = 0; i < N; ++i) {
+            ((int32_t *) position->data)[i] = n_past + i;
+        }
+    }
+
+    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+    ggml_allocr_alloc(allocr, KQ_scale);
+    if (!ggml_allocr_is_measure(allocr)) {
+        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
     }
 
     // wte + wpe
@@ -490,8 +484,8 @@ bool gpt2_eval(
             struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
             struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
 
-            ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
-            ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
         }
 
         // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
@@ -531,18 +525,17 @@ bool gpt2_eval(
         // KQ_scaled = KQ / sqrt(n_embd/n_head)
        // [n_past + N, N, 12]
         struct ggml_tensor * KQ_scaled =
-            ggml_scale_inplace(ctx0,
+            ggml_scale(ctx0,
                     KQ,
-                    ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
-                    );
+                    KQ_scale);
 
         // KQ_masked = mask_past(KQ_scaled)
         // [n_past + N, N, 12]
-        struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+        struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
 
         // KQ = soft_max(KQ_masked)
         // [n_past + N, N, 12]
-        struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+        struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
         // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
         // [n_past + N, 64, 12]
@@ -669,17 +662,60 @@ bool gpt2_eval(
     inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
 
     // logits -> probs
-    //inpL = ggml_soft_max_inplace(ctx0, inpL);
+    //inpL = ggml_soft_max(ctx0, inpL);
+
+    ggml_build_forward_expand(gf, inpL);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
+// evaluate the transformer
+//
+//   - model:     the model
+//   - allocr:    ggml_allocr to use to allocate the compute buffer
+//   - n_threads: number of threads to use
+//   - n_past:    the context size so far
+//   - embd_inp:  the embeddings of the tokens in the context
+//   - embd_w:    the predicted logits for the next token
+//
+bool gpt2_eval(
+        const gpt2_model & model,
+        struct ggml_allocr * allocr,
+        const int n_threads,
+        const int n_past,
+        const std::vector<gpt_vocab::id> & embd_inp,
+              std::vector<float>         & embd_w) {
+    const int N = embd_inp.size();
+
+    const auto & hparams = model.hparams;
+
+    const int n_vocab = hparams.n_vocab;
+
+    // reset the allocator to free all the memory allocated during the previous inference
+    ggml_allocr_reset(allocr);
+
+    struct ggml_cgraph * gf = gpt2_graph(model, allocr, n_past, embd_inp);
+
+    // allocate tensors
+    ggml_allocr_alloc_graph(allocr, gf);
 
     // run the computation
-    ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);
+    static std::vector<uint8_t> work_buffer;
+    work_buffer.resize(plan.work_size);
+    plan.work_data = work_buffer.data();
+    ggml_graph_compute(gf, &plan);
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print   (&gf);
     //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
     //}
 
+    // in this case, the output tensor is the last one in the graph
+    struct ggml_tensor * inpL = gf->nodes[gf->n_nodes - 1];
+
     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
 
@@ -687,13 +723,6 @@ bool gpt2_eval(
     embd_w.resize(n_vocab);
     memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
 
-    if (mem_per_token == 0) {
-        mem_per_token = ggml_used_mem(ctx0)/N;
-    }
-    //printf("used_mem = %zu\n", ggml_used_mem(ctx0));
-
-    ggml_free(ctx0);
-
     return true;
 }
 
@@ -739,6 +768,30 @@ int main(int argc, char ** argv) {
         test_gpt_tokenizer(vocab, params.token_test);
     }
 
+    // keep this buffer alive while evaluating the model
+    std::vector<uint8_t> compute_buffer;
+
+    struct ggml_allocr * allocr = NULL;
+    // allocate the compute buffer
+    {
+        allocr = ggml_allocr_new_measure(GGML_MEM_ALIGN);
+
+        // create the worst case graph for memory usage estimation
+        int n_tokens = std::min(model.hparams.n_ctx, params.n_batch);
+        int n_past = model.hparams.n_ctx - n_tokens;
+        struct ggml_cgraph * gf = gpt2_graph(model, allocr, n_past, std::vector<gpt_vocab::id>(n_tokens, 0));
+
+        // compute the required memory
+        size_t mem_size = ggml_allocr_alloc_graph(allocr, gf) + GGML_MEM_ALIGN;
+
+        // recreate the allocator with the required memory
+        ggml_allocr_free(allocr);
+        compute_buffer.resize(mem_size);
+        allocr = ggml_allocr_new(compute_buffer.data(), mem_size, GGML_MEM_ALIGN);
+
+        fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0/1024.0);
+    }
+
     int n_past = 0;
 
     int64_t t_sample_us = 0;
@@ -762,16 +815,12 @@ int main(int argc, char ** argv) {
     // this reduces the memory usage during inference, at the cost of a bit of speed at the beginning
     std::vector<gpt_vocab::id> embd;
 
-    // determine the required inference memory per token:
-    size_t mem_per_token = 0;
-    gpt2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
-
     for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
         // predict
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
 
-            if (!gpt2_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
+            if (!gpt2_eval(model, allocr, params.n_threads, n_past, embd, logits)) {
                 printf("Failed to predict\n");
                 return 1;
             }
@@ -830,7 +879,6 @@ int main(int argc, char ** argv) {
         const int64_t t_main_end_us = ggml_time_us();
 
         printf("\n\n");
-        printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
         printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
         printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
         printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h
index 2074fd04..1baa6cea 100644
--- a/include/ggml/ggml.h
+++ b/include/ggml/ggml.h
@@ -214,6 +214,11 @@
 #define GGML_MAX_OP_PARAMS 32
 #define GGML_DEFAULT_N_THREADS 4
 
+#if UINTPTR_MAX == 0xFFFFFFFF
+    #define GGML_MEM_ALIGN 4
+#else
+    #define GGML_MEM_ALIGN 16
+#endif
 
 #define GGML_EXIT_SUCCESS 0
 #define GGML_EXIT_ABORTED 1
diff --git a/src/ggml.c b/src/ggml.c
index af031cc2..5922c330 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -157,12 +157,6 @@ typedef void * thread_ret_t;
 //#define GGML_SOFT_MAX_ACCELERATE
 #endif
 
-#if UINTPTR_MAX == 0xFFFFFFFF
-    #define GGML_MEM_ALIGN 4
-#else
-    #define GGML_MEM_ALIGN 16
-#endif
-
 //
 // logging
 //
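
Note: the standalone sketch below is not part of the patch. It condenses the allocation pattern the commit introduces (build the graph in a no_alloc context, measure the worst case with a measuring allocator, then recreate the allocator on a buffer of the measured size and reuse it for every evaluation). build_graph() and its 4x4 matmul inputs are hypothetical stand-ins for gpt2_graph(); only ggml/ggml-alloc calls that already appear in this commit, plus ggml_new_tensor_2d/ggml_mul_mat/ggml_get_f32_1d, are used.

// ggml-alloc usage sketch (illustrative only, not from the commit)
#include "ggml/ggml.h"
#include "ggml/ggml-alloc.h"

#include <cstdio>
#include <vector>

// hypothetical stand-in for gpt2_graph(): builds a tiny graph in a no_alloc
// context, reserves its input tensors through the allocator and only writes
// to them when the allocator is not in measure mode
static struct ggml_cgraph * build_graph(struct ggml_allocr * allocr) {
    static size_t buf_size = ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead();
    static std::vector<uint8_t> buf(buf_size);

    struct ggml_init_params params = {
        /*.mem_size   =*/ buf_size,
        /*.mem_buffer =*/ buf.data(),
        /*.no_alloc   =*/ true, // tensor data is allocated later by ggml_allocr_alloc_graph()
    };
    struct ggml_context * ctx0 = ggml_init(params);
    struct ggml_cgraph  * gf   = ggml_new_graph(ctx0);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 4, 4);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 4, 4);
    ggml_allocr_alloc(allocr, a);
    ggml_allocr_alloc(allocr, b);
    if (!ggml_allocr_is_measure(allocr)) {
        ggml_set_f32(a, 1.0f);
        ggml_set_f32(b, 2.0f);
    }

    ggml_build_forward_expand(gf, ggml_mul_mat(ctx0, a, b));

    ggml_free(ctx0);
    return gf;
}

int main() {
    // 1) measure: build the graph once with a measuring allocator to find the
    //    worst-case compute buffer size
    struct ggml_allocr * allocr = ggml_allocr_new_measure(GGML_MEM_ALIGN);
    struct ggml_cgraph * gf     = build_graph(allocr);
    size_t mem_size = ggml_allocr_alloc_graph(allocr, gf) + GGML_MEM_ALIGN;
    ggml_allocr_free(allocr);

    // 2) recreate the allocator on a buffer of the measured size;
    //    the buffer must stay alive for as long as the allocator is used
    std::vector<uint8_t> compute_buffer(mem_size);
    allocr = ggml_allocr_new(compute_buffer.data(), mem_size, GGML_MEM_ALIGN);
    fprintf(stderr, "compute buffer size: %zu bytes\n", mem_size);

    // 3) per evaluation: reset the allocator, rebuild the graph, allocate its
    //    tensors and run it
    ggml_allocr_reset(allocr);
    gf = build_graph(allocr);
    ggml_allocr_alloc_graph(allocr, gf);

    struct ggml_cplan plan = ggml_graph_plan(gf, /*n_threads =*/ 1);
    std::vector<uint8_t> work_buffer(plan.work_size);
    plan.work_data = work_buffer.data();
    ggml_graph_compute(gf, &plan);

    // the output is the last node in the graph
    struct ggml_tensor * result = gf->nodes[gf->n_nodes - 1];
    fprintf(stderr, "result[0] = %f\n", ggml_get_f32_1d(result, 0));

    ggml_allocr_free(allocr);
    return 0;
}

The KQ_scale tensor in the patch follows the same guard: a scalar that used to be created with ggml_new_f32() (which writes data while the graph is being built) becomes a graph input that is only written when the allocator is not in measure mode.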