]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
examples : add llama_init_from_gpt_params() common function (#1290)
authorRon Evans <redacted>
Tue, 2 May 2023 20:39:51 +0000 (22:39 +0200)
committerGitHub <redacted>
Tue, 2 May 2023 20:39:51 +0000 (23:39 +0300)
Signed-off-by: deadprogram <redacted>
examples/common.cpp
examples/common.h
examples/embedding/embedding.cpp
examples/main/main.cpp
examples/perplexity/perplexity.cpp

index 2bf0dc597eeadcbc56830fdeb85f93ef1a3e724c..9b23b1f63159b95e8a6694d4de515473a082ac7f 100644 (file)
@@ -405,6 +405,37 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
     return res;
 }
 
+struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
+    auto lparams = llama_context_default_params();
+
+    lparams.n_ctx      = params.n_ctx;
+    lparams.n_parts    = params.n_parts;
+    lparams.seed       = params.seed;
+    lparams.f16_kv     = params.memory_f16;
+    lparams.use_mmap   = params.use_mmap;
+    lparams.use_mlock  = params.use_mlock;
+
+    llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);
+
+    if (lctx == NULL) {
+        fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
+        return NULL;
+    }
+
+    if (!params.lora_adapter.empty()) {
+        int err = llama_apply_lora_from_file(lctx,
+                                             params.lora_adapter.c_str(),
+                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
+                                             params.n_threads);
+        if (err != 0) {
+            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+            return NULL;
+        }
+    }
+
+    return lctx;
+}
+
 /* Keep track of current color of output, and emit ANSI code if it changes. */
 void set_console_color(console_state & con_st, console_color_t color) {
     if (con_st.use_color && con_st.color != color) {
index 627696e30a4f64cde08a15e525897ee0ee042953..138d0ded0344e4e5671da9ceeb73e5208dfb3ec7 100644 (file)
@@ -77,6 +77,12 @@ std::string gpt_random_prompt(std::mt19937 & rng);
 
 std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos);
 
+//
+// Model utils
+//
+
+struct llama_context * llama_init_from_gpt_params(const gpt_params & params);
+
 //
 // Console utils
 //
index 1e9d8a8ce75d5b67d5fa5aba9dd975e5897e4640..e4b729128b5717f0dbb52458ba4d281401cb3b55 100644 (file)
@@ -35,24 +35,10 @@ int main(int argc, char ** argv) {
     llama_context * ctx;
 
     // load the model
-    {
-        auto lparams = llama_context_default_params();
-
-        lparams.n_ctx      = params.n_ctx;
-        lparams.n_parts    = params.n_parts;
-        lparams.seed       = params.seed;
-        lparams.f16_kv     = params.memory_f16;
-        lparams.logits_all = params.perplexity;
-        lparams.use_mmap   = params.use_mmap;
-        lparams.use_mlock  = params.use_mlock;
-        lparams.embedding  = params.embedding;
-
-        ctx = llama_init_from_file(params.model.c_str(), lparams);
-
-        if (ctx == NULL) {
-            fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
+    ctx = llama_init_from_gpt_params(params);
+    if (ctx == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return 1;
     }
 
     // print system information
index 54836b3652ecf3f8d18156a8cfd4257324669e59..a10256abf362fda4ea71829095c8dd33b3d6e4fd 100644 (file)
@@ -101,34 +101,11 @@ int main(int argc, char ** argv) {
     llama_context * ctx;
     g_ctx = &ctx;
 
-    // load the model
-    {
-        auto lparams = llama_context_default_params();
-
-        lparams.n_ctx      = params.n_ctx;
-        lparams.n_parts    = params.n_parts;
-        lparams.seed       = params.seed;
-        lparams.f16_kv     = params.memory_f16;
-        lparams.use_mmap   = params.use_mmap;
-        lparams.use_mlock  = params.use_mlock;
-
-        ctx = llama_init_from_file(params.model.c_str(), lparams);
-
-        if (ctx == NULL) {
-            fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
-    }
-
-    if (!params.lora_adapter.empty()) {
-        int err = llama_apply_lora_from_file(ctx,
-                                             params.lora_adapter.c_str(),
-                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
-                                             params.n_threads);
-        if (err != 0) {
-            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
-            return 1;
-        }
+    // load the model and apply lora adapter, if any
+    ctx = llama_init_from_gpt_params(params);
+    if (ctx == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return 1;
     }
 
     // print system information
index d474bc50f72e4b0dde6bb2200b3d22f4ab874500..299a19999d8ce26e50b19510e204078e40d249da 100644 (file)
@@ -122,36 +122,11 @@ int main(int argc, char ** argv) {
 
     llama_context * ctx;
 
-    // load the model
-    {
-        auto lparams = llama_context_default_params();
-
-        lparams.n_ctx      = params.n_ctx;
-        lparams.n_parts    = params.n_parts;
-        lparams.seed       = params.seed;
-        lparams.f16_kv     = params.memory_f16;
-        lparams.logits_all = params.perplexity;
-        lparams.use_mmap   = params.use_mmap;
-        lparams.use_mlock  = params.use_mlock;
-        lparams.embedding  = params.embedding;
-
-        ctx = llama_init_from_file(params.model.c_str(), lparams);
-
-        if (ctx == NULL) {
-            fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
-    }
-
-    if (!params.lora_adapter.empty()) {
-        int err = llama_apply_lora_from_file(ctx,
-                                             params.lora_adapter.c_str(),
-                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
-                                             params.n_threads);
-        if (err != 0) {
-            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
-            return 1;
-        }
+    // load the model and apply lora adapter, if any
+    ctx = llama_init_from_gpt_params(params);
+    if (ctx == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return 1;
     }
 
     // print system information