]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Reduce memory usage and allocate enough memory for largest context (#473)
authorGeorgi Gerganov <redacted>
Fri, 24 Mar 2023 21:17:37 +0000 (23:17 +0200)
committerGitHub <redacted>
Fri, 24 Mar 2023 21:17:37 +0000 (23:17 +0200)
* Reduce memory usage and allocate enough memory for large contexts

* Simpler scratch buffer usage

* Reenable BLAS for quantized mul_mat

* Fix number of layers in 30B and 65B

* Fix KV cache size for F32

ggml.c
llama.cpp
main.cpp
utils.cpp
utils.h

diff --git a/ggml.c b/ggml.c
index 92b857a0007ac53dfa82873d36e017626c5c4176..cfdf427df1249604f60d813f003e15f2f8911a1d 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -5846,7 +5846,8 @@ static bool ggml_compute_forward_mul_mat_use_blas(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
               struct ggml_tensor * dst) {
-    UNUSED(src0);
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
 
     const int ne10 = src1->ne[0];
 
@@ -5856,7 +5857,14 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     // TODO: find the optimal values for these
     if (ggml_is_contiguous(src0) &&
         ggml_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
-        //printf("BLAS: %d %d %d\n", ne0, ne1, ne10);
+
+        //// disable BLAS for Q4_0 and Q4_1
+        //// looks like there is no benefit and we only waste a lot of memory
+        //if (src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1) {
+        //    return false;
+        //}
+
+        //printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);
         return true;
     }
 
index 9a93409cccb4cb530a09e76716d49df2428678c5..9d48ccd4c79e38318d3dea3e8e938ca4f652df45 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -5,12 +5,25 @@
 #include <cinttypes>
 #include <fstream>
 #include <random>
+#include <map>
 #include <unordered_map>
 #include <queue>
 #include <regex>
 #include <cassert>
 #include <cstring>
 
+#define LLAMA_USE_SCRATCH
+#define LLAMA_MAX_SCRATCH_BUFFERS 16
+
+#define LLAMA_ASSERT(x) \
+    do { \
+        if (!(x)) { \
+            fprintf(stderr, "LLAMA_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+            abort(); \
+        } \
+    } while (0)
+
+
 // determine number of model parts based on the dimension
 static const std::unordered_map<int, int> LLAMA_N_PARTS = {
     { 4096, 1 },
@@ -19,6 +32,52 @@ static const std::unordered_map<int, int> LLAMA_N_PARTS = {
     { 8192, 8 },
 };
 
+// available llama models
+enum e_model {
+    MODEL_UNKNOWN,
+    MODEL_7B,
+    MODEL_13B,
+    MODEL_30B,
+    MODEL_65B,
+};
+
+static const size_t MB = 1024*1024;
+
+// computed for n_ctx == 2048
+// TODO: dynamically determine these sizes
+//       needs modifications in ggml
+
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
+    { MODEL_7B,    512ull*MB },
+    { MODEL_13B,   512ull*MB },
+    { MODEL_30B,   512ull*MB },
+    { MODEL_65B,   512ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
+    { MODEL_7B,    512ull*MB },
+    { MODEL_13B,   512ull*MB },
+    { MODEL_30B,   512ull*MB },
+    { MODEL_65B,   512ull*MB },
+};
+
+// 2*n_embd*n_ctx*n_layer*sizeof(float16)
+static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
+    { MODEL_7B,   1026ull*MB },
+    { MODEL_13B,  1608ull*MB },
+    { MODEL_30B,  3124ull*MB },
+    { MODEL_65B,  5120ull*MB },
+};
+
+// this is mostly needed for temporary mul_mat buffers to dequantize the data
+// not actually needed if BLAS is disabled
+static const std::map<e_model, size_t> MEM_REQ_EVAL = {
+    { MODEL_7B,   768ull*MB },
+    { MODEL_13B, 1024ull*MB },
+    { MODEL_30B, 1280ull*MB },
+    { MODEL_65B, 1536ull*MB },
+};
+
 // default hparams (LLaMA 7B)
 struct llama_hparams {
     int32_t n_vocab = 32000;
@@ -50,7 +109,20 @@ struct llama_layer {
     struct ggml_tensor * w3;
 };
 
+struct llama_kv_cache {
+    struct ggml_tensor * k;
+    struct ggml_tensor * v;
+
+    struct ggml_context * ctx;
+
+    std::vector<uint8_t> buf;
+
+    int n; // number of tokens currently in the cache
+};
+
 struct llama_model {
+    e_model type = MODEL_UNKNOWN;
+
     llama_hparams hparams;
 
     struct ggml_tensor * tok_embeddings;
@@ -60,12 +132,18 @@ struct llama_model {
 
     std::vector<llama_layer> layers;
 
-    // key + value memory
-    struct ggml_tensor * memory_k;
-    struct ggml_tensor * memory_v;
-
-    //
+    // context
     struct ggml_context * ctx;
+
+    // key + value cache for the self attention
+    // TODO: move to llama_state
+    struct llama_kv_cache kv_self;
+
+    // the model memory buffer
+    std::vector<uint8_t> buf;
+
+    // tensors
+    int n_loaded;
     std::unordered_map<std::string, struct ggml_tensor *> tensors;
 };
 
@@ -105,8 +183,88 @@ struct llama_context {
 
     // input embedding (1-dimensional array: [n_embd])
     std::vector<float> embedding;
+
+    // memory buffers used to evaluate the model
+    // TODO: move in llama_state
+    std::vector<uint8_t> buf_compute;
+    std::vector<uint8_t> buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
+
+    int    buf_last = 0;
+    size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
+
+    void use_buf(struct ggml_context * ctx, int i) {
+#if defined(LLAMA_USE_SCRATCH)
+        size_t last_size = 0;
+
+        if (i == -1) {
+            last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
+        } else {
+            auto & buf = buf_scratch[i];
+            last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
+        }
+
+        if (buf_last >= 0) {
+            buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
+        }
+
+        buf_last = i;
+#else
+        (void) i;
+        (void) ctx;
+#endif
+    }
+
+    size_t get_buf_max_mem(int i) const {
+#if defined(LLAMA_USE_SCRATCH)
+        return buf_max_size[i];
+#else
+        (void) i;
+        return 0;
+#endif
+    }
 };
 
+//
+// kv cache
+//
+
+static bool kv_cache_init(
+        const struct llama_hparams & hparams,
+             struct llama_kv_cache & cache,
+                         ggml_type   wtype,
+                               int   n_ctx) {
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+
+    const int n_mem      = n_layer*n_ctx;
+    const int n_elements = n_embd*n_mem;
+
+    cache.buf.resize(2*n_elements*ggml_type_size(wtype) + 2u*MB);
+
+    struct ggml_init_params params;
+    params.mem_size   = cache.buf.size();
+    params.mem_buffer = cache.buf.data();
+
+    cache.ctx = ggml_init(params);
+
+    if (!cache.ctx) {
+        fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
+        return false;
+    }
+
+    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+
+    return true;
+}
+
+static void kv_cache_free(struct llama_kv_cache & cache) {
+    if (cache.ctx) {
+        ggml_free(cache.ctx);
+        cache.ctx = nullptr;
+    }
+}
+
 struct llama_context_params llama_context_default_params() {
     struct llama_context_params result = {
         /*.n_ctx      =*/ 512,
@@ -204,6 +362,22 @@ static bool llama_model_load(
             fprintf(stderr, "%s: use '--n_parts 1' if necessary\n", __func__);
         }
 
+        if (hparams.n_layer == 32) {
+            model.type = e_model::MODEL_7B;
+        }
+
+        if (hparams.n_layer == 40) {
+            model.type = e_model::MODEL_13B;
+        }
+
+        if (hparams.n_layer == 60) {
+            model.type = e_model::MODEL_30B;
+        }
+
+        if (hparams.n_layer == 80) {
+            model.type = e_model::MODEL_65B;
+        }
+
         fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
         fprintf(stderr, "%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
         fprintf(stderr, "%s: n_embd  = %d\n", __func__, hparams.n_embd);
@@ -214,6 +388,7 @@ static bool llama_model_load(
         fprintf(stderr, "%s: f16     = %d\n", __func__, hparams.f16);
         fprintf(stderr, "%s: n_ff    = %d\n", __func__, n_ff);
         fprintf(stderr, "%s: n_parts = %d\n", __func__, n_parts);
+        fprintf(stderr, "%s: type    = %d\n", __func__, model.type);
     }
 
     // load vocab
@@ -307,11 +482,32 @@ static bool llama_model_load(
         fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
     }
 
+    // print memory requirements
+    {
+        const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
+
+        // this is the total memory required to run the inference
+        const size_t mem_required =
+            ctx_size +
+            MEM_REQ_SCRATCH0.at(model.type) +
+            MEM_REQ_SCRATCH1.at(model.type) +
+            MEM_REQ_EVAL.at    (model.type);
+
+        // this is the memory required by one llama_state
+        const size_t mem_required_state =
+            scale*MEM_REQ_KV_SELF.at(model.type);
+
+        fprintf(stderr, "%s: mem required  = %7.2f MB (+ %7.2f MB per state)\n", __func__,
+                mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
+    }
+
     // create the ggml context
     {
+        lctx.model.buf.resize(ctx_size);
+
         struct ggml_init_params params = {
-            /*.mem_size   =*/ ctx_size,
-            /*.mem_buffer =*/ NULL,
+            /*.mem_size   =*/ lctx.model.buf.size(),
+            /*.mem_buffer =*/ lctx.model.buf.data(),
         };
 
         model.ctx = ggml_init(params);
@@ -374,25 +570,6 @@ static bool llama_model_load(
         }
     }
 
-    // key + value memory
-    {
-        const auto & hparams = model.hparams;
-
-        const int n_embd  = hparams.n_embd;
-        const int n_layer = hparams.n_layer;
-        const int n_ctx   = hparams.n_ctx;
-
-        const int n_mem      = n_layer*n_ctx;
-        const int n_elements = n_embd*n_mem;
-
-        model.memory_k = ggml_new_tensor_1d(ctx, memory_type, n_elements);
-        model.memory_v = ggml_new_tensor_1d(ctx, memory_type, n_elements);
-
-        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
-
-        fprintf(stderr, "%s: memory_size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
-    }
-
     const size_t file_offset = fin.tellg();
 
     fin.close();
@@ -416,9 +593,10 @@ static bool llama_model_load(
 
         // load weights
         {
-            int n_tensors = 0;
             size_t total_size = 0;
 
+            model.n_loaded = 0;
+
             fprintf(stderr, "%s: ", __func__);
 
             while (true) {
@@ -583,7 +761,10 @@ static bool llama_model_load(
                 }
 
                 //fprintf(stderr, "%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
-                if (++n_tensors % 8 == 0) {
+                model.n_loaded++;
+
+                // progress
+                if (model.n_loaded % 8 == 0) {
                     fprintf(stderr, ".");
                     fflush(stderr);
                 }
@@ -591,7 +772,13 @@ static bool llama_model_load(
 
             fprintf(stderr, " done\n");
 
-            fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
+            fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, model.n_loaded);
+            if (model.n_loaded == 0) {
+                fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
+            } else if (model.n_loaded != (int) model.tensors.size()) {
+                fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
+                return false;
+            }
         }
 
         fin.close();
@@ -622,6 +809,10 @@ static bool llama_eval_internal(
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
 
+    auto & kv_self = model.kv_self;
+
+    LLAMA_ASSERT(!!kv_self.ctx);
+
     const int n_embd  = hparams.n_embd;
     const int n_layer = hparams.n_layer;
     const int n_ctx   = hparams.n_ctx;
@@ -630,27 +821,11 @@ static bool llama_eval_internal(
     const int n_rot   = hparams.n_embd/hparams.n_head;
 
     auto & mem_per_token = lctx.mem_per_token;
-
-    // TODO: fix this hardcoded size
-    static size_t buf_size = 2048u*1024*1024; // TMP !!!
-    static void * buf = malloc(buf_size);
-
-    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
-        const size_t buf_size_new = 1.3*(mem_per_token*N); // add 30% to account for ggml object overhead
-        //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
-
-        // reallocate
-        buf_size = buf_size_new;
-        buf = realloc(buf, buf_size);
-        if (buf == nullptr) {
-            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
-            return false;
-        }
-    }
+    auto & buf_compute   = lctx.buf_compute;
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_size,
-        /*.mem_buffer =*/ buf,
+        /*.mem_size   =*/ buf_compute.size(),
+        /*.mem_buffer =*/ buf_compute.data(),
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
@@ -667,6 +842,8 @@ static bool llama_eval_internal(
 
         struct ggml_tensor * cur;
 
+        lctx.use_buf(ctx0, 0);
+
         // norm
         {
             cur = ggml_rms_norm(ctx0, inpL);
@@ -685,8 +862,8 @@ static bool llama_eval_internal(
 
             // store key and value to memory
             if (N >= 1) {
-                struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
-                struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
+                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_embd, (ggml_element_size(kv_self.v)*n_embd)*(il*n_ctx + n_past));
 
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
@@ -707,7 +884,7 @@ static bool llama_eval_internal(
                 ggml_permute(ctx0,
                         ggml_rope(ctx0,
                             ggml_reshape_3d(ctx0,
-                                ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
+                                ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
                                 n_embd/n_head, n_head, n_past + N),
                             n_past, n_rot, 1),
                         0, 2, 1, 3);
@@ -733,7 +910,7 @@ static bool llama_eval_internal(
                 ggml_cpy(ctx0,
                     ggml_permute(ctx0,
                             ggml_reshape_3d(ctx0,
-                                ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
+                                ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.v)*n_embd),
                                 n_embd/n_head, n_head, n_past + N),
                             1, 2, 0, 3),
                     ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
@@ -755,6 +932,8 @@ static bool llama_eval_internal(
                     cur);
         }
 
+        lctx.use_buf(ctx0, 1);
+
         struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
 
         // feed-forward network
@@ -773,7 +952,6 @@ static bool llama_eval_internal(
                     model.layers[il].w3,
                     cur);
 
-
             cur = ggml_mul_mat(ctx0,
                     model.layers[il].w1,
                     cur);
@@ -788,17 +966,20 @@ static bool llama_eval_internal(
                     cur);
         }
 
-        cur  = ggml_add(ctx0, cur, inpFF);
+        cur = ggml_add(ctx0, cur, inpFF);
 
         // input for next layer
         inpL = cur;
     }
 
+    lctx.use_buf(ctx0, 0);
+
     // used at the end to optionally extract the embeddings
     struct ggml_tensor * embeddings = NULL;
 
     // norm
     {
+
         inpL = ggml_rms_norm(ctx0, inpL);
 
         // inpL = norm*inpL
@@ -810,9 +991,9 @@ static bool llama_eval_internal(
     }
 
     // lm_head
-    {
-        inpL = ggml_mul_mat(ctx0, model.output, inpL);
-    }
+    inpL = ggml_mul_mat(ctx0, model.output, inpL);
+
+    lctx.use_buf(ctx0, -1);
 
     // logits -> probs
     //inpL = ggml_soft_max(ctx0, inpL);
@@ -854,7 +1035,13 @@ static bool llama_eval_internal(
     if (mem_per_token == 0) {
         mem_per_token = ggml_used_mem(ctx0)/N;
     }
-    //fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
+
+#if 0
+    printf("\n%s: used_mem = %.3f MB, scratch -- %.3f MB %.3f MB\n", __func__,
+            ggml_used_mem(ctx0)/1024.0/1024.0,
+            lctx.get_buf_max_mem(0)/1024.0/1024.0,
+            lctx.get_buf_max_mem(1)/1024.0/1024.0);
+#endif
 
     ggml_free(ctx0);
 
@@ -1427,9 +1614,9 @@ struct llama_context * llama_init_from_file(
     ctx->rng = std::mt19937(params.seed);
     ctx->logits_all = params.logits_all;
 
-    ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
+    ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
-    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory,
+    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, memory_type,
                           params.vocab_only)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
         llama_free(ctx);
@@ -1448,6 +1635,17 @@ struct llama_context * llama_init_from_file(
 
     // reserve memory for context buffers
     {
+        if (!kv_cache_init(ctx->model.hparams, ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx)) {
+            fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
+            llama_free(ctx);
+            return nullptr;
+        }
+
+        {
+            const size_t memory_size = ggml_nbytes(ctx->model.kv_self.k) + ggml_nbytes(ctx->model.kv_self.v);
+            fprintf(stderr, "%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
+        }
+
         const auto & hparams = ctx->model.hparams;
         if (params.logits_all) {
             ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
@@ -1458,12 +1656,19 @@ struct llama_context * llama_init_from_file(
         if (params.embedding){
             ctx->embedding.reserve(hparams.n_embd);
         }
+
+        ctx->buf_compute.resize(MEM_REQ_EVAL.at(ctx->model.type));
+
+        ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
+        ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
     }
 
     return ctx;
 }
 
 void llama_free(struct llama_context * ctx) {
+    kv_cache_free(ctx->model.kv_self);
+
     if (ctx->model.ctx) {
         ggml_free(ctx->model.ctx);
     }
@@ -1619,4 +1824,3 @@ const char * llama_print_system_info(void) {
 
     return s.c_str();
 }
-
index 44437750eee2d89481634b1b5f9deaa5154c5b88..bc71a5494b2317909c911b1479a0bb7bbea50abb 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -217,11 +217,23 @@ int main(int argc, char ** argv) {
                 params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
-    // determine the required inference memory per token:
-    // TODO: better way to do that
-    {
-        const std::vector<llama_token> tmp = { 0, 1, 2, 3 };
-        llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
+    // determine the maximum memory usage needed to do inference for the given n_batch and n_predict parameters
+    // uncomment the "used_mem" line in llama.cpp to see the results
+    if (params.mem_test) {
+        {
+            const std::vector<llama_token> tmp(params.n_batch, 0);
+            llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
+        }
+
+        {
+            const std::vector<llama_token> tmp = { 0, };
+            llama_eval(ctx, tmp.data(), tmp.size(), params.n_predict - 1, params.n_threads);
+        }
+
+        llama_print_timings(ctx);
+        llama_free(ctx);
+
+        return 0;
     }
 
     if (params.perplexity) {
@@ -508,7 +520,6 @@ int main(int argc, char ** argv) {
 #endif
 
     llama_print_timings(ctx);
-
     llama_free(ctx);
 
     set_console_state(CONSOLE_STATE_DEFAULT);
index 10673fb825ddf436baf5d58265db67a430c48850..2f995c12d8da03375d2fa38f631fd1e95e9c14a9 100644 (file)
--- a/utils.cpp
+++ b/utils.cpp
@@ -79,8 +79,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_ctx = std::stoi(argv[i]);
-        } else if (arg == "--memory_f16") {
-            params.memory_f16 = true;
+        } else if (arg == "--memory_f32") {
+            params.memory_f16 = false;
         } else if (arg == "--top_p") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -111,6 +111,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_batch = std::stoi(argv[i]);
+            params.n_batch = std::min(512, params.n_batch);
         } else if (arg == "-m" || arg == "--model") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -131,6 +132,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.use_color = true;
         } else if (arg == "--mlock") {
             params.use_mlock = true;
+        } else if (arg == "--mtest") {
+            params.mem_test = true;
         } else if (arg == "-r" || arg == "--reverse-prompt") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -193,7 +196,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  --repeat_penalty N    penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty);
     fprintf(stderr, "  -c N, --ctx_size N    size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating\n");
-    fprintf(stderr, "  --memory_f16          use f16 instead of f32 for memory key+value\n");
+    fprintf(stderr, "  --memory_f32          use f32 instead of f16 for memory key+value\n");
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", params.temp);
     fprintf(stderr, "  --n_parts N           number of model parts (default: -1 = determine from dimensions)\n");
     fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
@@ -201,6 +204,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     if (ggml_mlock_supported()) {
         fprintf(stderr, "  --mlock               force system to keep model in RAM rather than swapping or compressing\n");
     }
+    fprintf(stderr, "  --mtest               compute maximum memory usage\n");
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "\n");
diff --git a/utils.h b/utils.h
index cf914990cdf7eca4f16b75a3ca79adbc161cd2bd..d469bc6a0b41682a3d4659dcb824bfb1d0f6de43 100644 (file)
--- a/utils.h
+++ b/utils.h
 //
 
 struct gpt_params {
-    int32_t seed          = -1;  // RNG seed
+    int32_t seed          = -1;   // RNG seed
     int32_t n_threads     = std::min(4, (int32_t) std::thread::hardware_concurrency());
-    int32_t n_predict     = 128; // new tokens to predict
-    int32_t repeat_last_n = 64;  // last n tokens to penalize
-    int32_t n_parts       = -1;  // amount of model parts (-1 = determine from model dimensions)
-    int32_t n_ctx         = 512; //context size
+    int32_t n_predict     = 128;  // new tokens to predict
+    int32_t repeat_last_n = 64;   // last n tokens to penalize
+    int32_t n_parts       = -1;   // amount of model parts (-1 = determine from model dimensions)
+    int32_t n_ctx         = 512;  // context size
+    int32_t n_batch       = 8;    // batch size for prompt processing
 
     // sampling parameters
     int32_t top_k = 40;
@@ -27,15 +28,13 @@ struct gpt_params {
     float   temp  = 0.80f;
     float   repeat_penalty  = 1.10f;
 
-    int32_t n_batch = 8; // batch size for prompt processing
-
     std::string model  = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";
 
 
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
 
-    bool memory_f16        = false; // use f16 instead of f32 for memory kv
+    bool memory_f16        = true;  // use f16 instead of f32 for memory kv
     bool random_prompt     = false; // do not randomize prompt if none provided
     bool use_color         = false; // use color to distinguish generations and inputs
     bool interactive       = false; // interactive mode
@@ -47,6 +46,7 @@ struct gpt_params {
     bool ignore_eos        = false; // do not stop generating after eos
     bool perplexity        = false; // compute perplexity over the prompt
     bool use_mlock         = false; // use mlock to keep model in memory
+    bool mem_test          = false; // compute maximum memory usage
 };
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);