]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
examples : use scratch buffers to reduce memory usage (#176)
authorGeorgi Gerganov <redacted>
Sat, 20 May 2023 17:56:35 +0000 (20:56 +0300)
committerGitHub <redacted>
Sat, 20 May 2023 17:56:35 +0000 (20:56 +0300)
* starcoder : example for using scratch buffers to reduce memory usage

* starcoder : bump scratch buffers to 256 MB

* examples : add scratch buffers to MPT and GPT-NeoX

examples/gpt-neox/main.cpp
examples/mpt/main.cpp
examples/starcoder/main.cpp
src/ggml.c

index dcb6ccf16ff58cab2b594fc6d0d125d473e2f83c..183585d5901d047074317eceffd590cb8bd10566 100644 (file)
@@ -445,6 +445,14 @@ bool gpt_neox_eval(
     static size_t buf_size = 256u*1024*1024;
     static void * buf = malloc(buf_size);
 
+    // use 2 scratch buffers
+    // TODO: very hacky solution - reimplement in a more elegant way
+    static size_t scr0_size = 256u*1024*1024;
+    static void * scr0 = malloc(scr0_size);
+
+    static size_t scr1_size = 256u*1024*1024;
+    static void * scr1 = malloc(scr1_size);
+
     if (mem_per_token > 0 && mem_per_token*N > buf_size) {
         const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
         //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
@@ -477,6 +485,8 @@ bool gpt_neox_eval(
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * cur;
 
+        ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
+
         // self-attention
         {
             {
@@ -580,6 +590,8 @@ bool gpt_neox_eval(
             }
         }
 
+        ggml_set_scratch(ctx0, { 0, scr1_size, scr1, });
+
         if (hparams.par_res == 0) {
             struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpL);
 
@@ -602,6 +614,8 @@ bool gpt_neox_eval(
         }
     }
 
+    ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
+
     // norm
     {
         inpL = ggml_norm(ctx0, inpL);
@@ -614,6 +628,8 @@ bool gpt_neox_eval(
                 ggml_repeat(ctx0, model.ln_f_b, inpL));
     }
 
+    ggml_set_scratch(ctx0, { 0, 0, nullptr, });
+
     // lm_head
     {
         inpL = ggml_mul_mat(ctx0, model.lmh_g, inpL);
index 2041d931113a41057d18752b22aff8f9376f675f..94cb44dcbab9efc569fa418f3335d56bca868526 100644 (file)
@@ -348,6 +348,14 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
     static size_t buf_size = 256u * 1024 * 1024;
     static void * buf = malloc(buf_size);
 
+    // use 2 scratch buffers
+    // TODO: very hacky solution - reimplement in a more elegant way
+    static size_t scr0_size = 256u*1024*1024;
+    static void * scr0 = malloc(scr0_size);
+
+    static size_t scr1_size = 256u*1024*1024;
+    static void * scr1 = malloc(scr1_size);
+
     if (mem_per_token > 0 && mem_per_token * N > buf_size) {
         const size_t buf_size_new = 1.1 * (mem_per_token * N); // add 10% to account for ggml object overhead
         // printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__,
@@ -380,6 +388,8 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
 
         struct ggml_tensor * cur;
 
+        ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
+
         // a = self.ln_1(x)
         {
             cur = ggml_norm(ctx0, inpL);
@@ -392,7 +402,6 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
         //  attn_bias=attn_bias, attention_mask=attention_mask,
         //  is_causal=is_causal)
         {
-
             // compute QKV
             cur = ggml_mul_mat(ctx0, model.layers[il].c_attn_wqkv_weight, cur);
 
@@ -475,6 +484,8 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
 
         inpL = ggml_add(ctx0, inpL, cur);
 
+        ggml_set_scratch(ctx0, { 0, scr1_size, scr1, });
+
         // m = self.ln_2(x)
         {
             cur = ggml_norm(ctx0, inpL);
@@ -499,6 +510,8 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
         inpL = ggml_add(ctx0, inpL, cur);
     }
 
+    ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
+
     // norm
     {
         inpL = ggml_norm(ctx0, inpL);
@@ -506,6 +519,8 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
         inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model.norm_f_weight, inpL), inpL);
     }
 
+    ggml_set_scratch(ctx0, { 0, 0, nullptr, });
+
     // output embedding weight tied to input embedding
     inpL = ggml_mul_mat(ctx0, model.wte_weight, inpL);
 
index 4960940522e83a6cc3e74447abb51f15a1b19890..c9d1d7ec7d55527e550e9c72b7d54d91d55ebec9 100644 (file)
@@ -416,6 +416,14 @@ bool starcoder_eval(
     static size_t buf_size = 256u*1024*1024;
     static void * buf = malloc(buf_size);
 
+    // use 2 scratch buffers
+    // TODO: very hacky solution - reimplement in a more elegant way
+    static size_t scr0_size = 256u*1024*1024;
+    static void * scr0 = malloc(scr0_size);
+
+    static size_t scr1_size = 256u*1024*1024;
+    static void * scr1 = malloc(scr1_size);
+
     if (mem_per_token > 0 && mem_per_token*N > buf_size) {
         const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
         //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
@@ -456,6 +464,8 @@ bool starcoder_eval(
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * cur;
 
+        ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
+
         // norm
         {
             // [ 768, N]
@@ -519,7 +529,7 @@ bool starcoder_eval(
                         ggml_reshape_3d(ctx0,
                             ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
                             n_embd/n_head, n_head, n_past + N),
-                        0, 2, 1, 3); //TODO: need to be tiled 
+                        0, 2, 1, 3); //TODO: need to be tiled
 
             // GG: flash attention
             //struct ggml_tensor * V =
@@ -602,6 +612,8 @@ bool starcoder_eval(
 
         struct ggml_tensor * inpFF = cur;
 
+        ggml_set_scratch(ctx0, { 0, scr1_size, scr1, });
+
         // feed-forward network
         {
             // norm
@@ -658,6 +670,8 @@ bool starcoder_eval(
         inpL = ggml_add(ctx0, cur, inpFF);
     }
 
+    ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
+
     // norm
     {
         // [ 768, N]
@@ -672,6 +686,8 @@ bool starcoder_eval(
                 ggml_repeat(ctx0, model.ln_f_b, inpL));
     }
 
+    ggml_set_scratch(ctx0, { 0, 0, nullptr, });
+
     // inpL = WTE * inpL
     // [ 768, 50257] - model.lm_head
     // [ 768, N]     - inpL
@@ -699,7 +715,7 @@ bool starcoder_eval(
     if (mem_per_token == 0) {
         mem_per_token = ggml_used_mem(ctx0)/N;
     }
-    //printf("used_mem = %zu\n", ggml_used_mem(ctx0));
+    //printf("used_mem = %zu MB\n", ggml_used_mem(ctx0)/(1024*1024));
 
     ggml_free(ctx0);
 
index 77a3d89f748e0bb67d4d82218f67bc0bc172e706..7612c86dcf06a634091fbc1af977601bed4b5394 100644 (file)
@@ -4077,7 +4077,8 @@ struct ggml_tensor * ggml_new_tensor_impl(
         };
     } else {
         if (ctx->scratch.offs + size_needed > ctx->scratch.size) {
-            GGML_PRINT("%s: not enough space in the scratch memory\n", __func__);
+            GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n",
+                    __func__, ctx->scratch.offs + size_needed, ctx->scratch.size);
             assert(false);
             return NULL;
         }