From: Georgi Gerganov Date: Sat, 20 May 2023 17:56:35 +0000 (+0300) Subject: examples : use scratch buffers to reduce memory usage (#176) X-Git-Tag: upstream/0.0.1642~1452 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=d6957552095ddb1a321a213be189cb70c94c5857;p=pkg%2Fggml%2Fsources%2Fggml examples : use scratch buffers to reduce memory usage (#176) * starcoder : example for using scratch buffers to reduce memory usage * starcoder : bump scratch buffers to 256 MB * examples : add scratch buffers to MPT and GPT-NeoX --- diff --git a/examples/gpt-neox/main.cpp b/examples/gpt-neox/main.cpp index dcb6ccf1..183585d5 100644 --- a/examples/gpt-neox/main.cpp +++ b/examples/gpt-neox/main.cpp @@ -445,6 +445,14 @@ bool gpt_neox_eval( static size_t buf_size = 256u*1024*1024; static void * buf = malloc(buf_size); + // use 2 scratch buffers + // TODO: very hacky solution - reimplement in a more elegant way + static size_t scr0_size = 256u*1024*1024; + static void * scr0 = malloc(scr0_size); + + static size_t scr1_size = 256u*1024*1024; + static void * scr1 = malloc(scr1_size); + if (mem_per_token > 0 && mem_per_token*N > buf_size) { const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new); @@ -477,6 +485,8 @@ bool gpt_neox_eval( for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * cur; + ggml_set_scratch(ctx0, { 0, scr0_size, scr0, }); + // self-attention { { @@ -580,6 +590,8 @@ bool gpt_neox_eval( } } + ggml_set_scratch(ctx0, { 0, scr1_size, scr1, }); + if (hparams.par_res == 0) { struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpL); @@ -602,6 +614,8 @@ bool gpt_neox_eval( } } + ggml_set_scratch(ctx0, { 0, scr0_size, scr0, }); + // norm { inpL = ggml_norm(ctx0, inpL); @@ -614,6 +628,8 @@ bool gpt_neox_eval( ggml_repeat(ctx0, model.ln_f_b, inpL)); } + ggml_set_scratch(ctx0, { 0, 0, nullptr, }); + // lm_head { inpL = ggml_mul_mat(ctx0, model.lmh_g, inpL); diff --git a/examples/mpt/main.cpp b/examples/mpt/main.cpp index 2041d931..94cb44dc 100644 --- a/examples/mpt/main.cpp +++ b/examples/mpt/main.cpp @@ -348,6 +348,14 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past, static size_t buf_size = 256u * 1024 * 1024; static void * buf = malloc(buf_size); + // use 2 scratch buffers + // TODO: very hacky solution - reimplement in a more elegant way + static size_t scr0_size = 256u*1024*1024; + static void * scr0 = malloc(scr0_size); + + static size_t scr1_size = 256u*1024*1024; + static void * scr1 = malloc(scr1_size); + if (mem_per_token > 0 && mem_per_token * N > buf_size) { const size_t buf_size_new = 1.1 * (mem_per_token * N); // add 10% to account for ggml object overhead // printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, @@ -380,6 +388,8 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past, struct ggml_tensor * cur; + ggml_set_scratch(ctx0, { 0, scr0_size, scr0, }); + // a = self.ln_1(x) { cur = ggml_norm(ctx0, inpL); @@ -392,7 +402,6 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past, // attn_bias=attn_bias, attention_mask=attention_mask, // is_causal=is_causal) { - // compute QKV cur = ggml_mul_mat(ctx0, model.layers[il].c_attn_wqkv_weight, cur); @@ -475,6 +484,8 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past, inpL = ggml_add(ctx0, inpL, cur); + ggml_set_scratch(ctx0, { 0, scr1_size, scr1, }); + // m = self.ln_2(x) { cur = ggml_norm(ctx0, inpL); @@ -499,6 +510,8 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past, inpL = ggml_add(ctx0, inpL, cur); } + ggml_set_scratch(ctx0, { 0, scr0_size, scr0, }); + // norm { inpL = ggml_norm(ctx0, inpL); @@ -506,6 +519,8 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past, inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model.norm_f_weight, inpL), inpL); } + ggml_set_scratch(ctx0, { 0, 0, nullptr, }); + // output embedding weight tied to input embedding inpL = ggml_mul_mat(ctx0, model.wte_weight, inpL); diff --git a/examples/starcoder/main.cpp b/examples/starcoder/main.cpp index 49609405..c9d1d7ec 100644 --- a/examples/starcoder/main.cpp +++ b/examples/starcoder/main.cpp @@ -416,6 +416,14 @@ bool starcoder_eval( static size_t buf_size = 256u*1024*1024; static void * buf = malloc(buf_size); + // use 2 scratch buffers + // TODO: very hacky solution - reimplement in a more elegant way + static size_t scr0_size = 256u*1024*1024; + static void * scr0 = malloc(scr0_size); + + static size_t scr1_size = 256u*1024*1024; + static void * scr1 = malloc(scr1_size); + if (mem_per_token > 0 && mem_per_token*N > buf_size) { const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new); @@ -456,6 +464,8 @@ bool starcoder_eval( for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * cur; + ggml_set_scratch(ctx0, { 0, scr0_size, scr0, }); + // norm { // [ 768, N] @@ -519,7 +529,7 @@ bool starcoder_eval( ggml_reshape_3d(ctx0, ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd), n_embd/n_head, n_head, n_past + N), - 0, 2, 1, 3); //TODO: need to be tiled + 0, 2, 1, 3); //TODO: need to be tiled // GG: flash attention //struct ggml_tensor * V = @@ -602,6 +612,8 @@ bool starcoder_eval( struct ggml_tensor * inpFF = cur; + ggml_set_scratch(ctx0, { 0, scr1_size, scr1, }); + // feed-forward network { // norm @@ -658,6 +670,8 @@ bool starcoder_eval( inpL = ggml_add(ctx0, cur, inpFF); } + ggml_set_scratch(ctx0, { 0, scr0_size, scr0, }); + // norm { // [ 768, N] @@ -672,6 +686,8 @@ bool starcoder_eval( ggml_repeat(ctx0, model.ln_f_b, inpL)); } + ggml_set_scratch(ctx0, { 0, 0, nullptr, }); + // inpL = WTE * inpL // [ 768, 50257] - model.lm_head // [ 768, N] - inpL @@ -699,7 +715,7 @@ bool starcoder_eval( if (mem_per_token == 0) { mem_per_token = ggml_used_mem(ctx0)/N; } - //printf("used_mem = %zu\n", ggml_used_mem(ctx0)); + //printf("used_mem = %zu MB\n", ggml_used_mem(ctx0)/(1024*1024)); ggml_free(ctx0); diff --git a/src/ggml.c b/src/ggml.c index 77a3d89f..7612c86d 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -4077,7 +4077,8 @@ struct ggml_tensor * ggml_new_tensor_impl( }; } else { if (ctx->scratch.offs + size_needed > ctx->scratch.size) { - GGML_PRINT("%s: not enough space in the scratch memory\n", __func__); + GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n", + __func__, ctx->scratch.offs + size_needed, ctx->scratch.size); assert(false); return NULL; }