]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add support for file load progress reporting callbacks (#434)
authorJed Fox <redacted>
Sat, 25 Mar 2023 05:26:28 +0000 (01:26 -0400)
committerGitHub <redacted>
Sat, 25 Mar 2023 05:26:28 +0000 (07:26 +0200)
* File load progress reporting

* Move llama_progress_handler into llama_context_params

* Renames

* Use seekg to find file size instead

* More correct load progress

* Call progress callback more frequently

* Fix typo

llama.cpp
llama.h

index 447fa91f3190b3c995a9bcaba03a9e08f23903c6..14de611a97bc7fad385e56128f05a518bbac835d 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -267,14 +267,16 @@ static void kv_cache_free(struct llama_kv_cache & cache) {
 
 struct llama_context_params llama_context_default_params() {
     struct llama_context_params result = {
-        /*.n_ctx      =*/ 512,
-        /*.n_parts    =*/ -1,
-        /*.seed       =*/ 0,
-        /*.f16_kv     =*/ false,
-        /*.logits_all =*/ false,
-        /*.vocab_only =*/ false,
-        /*.use_mlock  =*/ false,
-        /*.embedding  =*/ false,
+        /*.n_ctx                       =*/ 512,
+        /*.n_parts                     =*/ -1,
+        /*.seed                        =*/ 0,
+        /*.f16_kv                      =*/ false,
+        /*.logits_all                  =*/ false,
+        /*.vocab_only                  =*/ false,
+        /*.use_mlock                   =*/ false,
+        /*.embedding                   =*/ false,
+        /*.progress_callback           =*/ nullptr,
+        /*.progress_callback_user_data =*/ nullptr,
     };
 
     return result;
@@ -290,7 +292,9 @@ static bool llama_model_load(
         int n_ctx,
         int n_parts,
         ggml_type memory_type,
-        bool vocab_only) {
+        bool vocab_only,
+        llama_progress_callback progress_callback,
+        void *progress_callback_user_data) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
 
     const int64_t t_start_us = ggml_time_us();
@@ -576,6 +580,10 @@ static bool llama_model_load(
 
     std::vector<uint8_t> tmp;
 
+    if (progress_callback) {
+        progress_callback(0.0, progress_callback_user_data);
+    }
+
     for (int i = 0; i < n_parts; ++i) {
         const int part_id = i;
         //const int part_id = n_parts - i - 1;
@@ -589,6 +597,10 @@ static bool llama_model_load(
 
         fin = std::ifstream(fname_part, std::ios::binary);
         fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
+
+        fin.seekg(0, fin.end);
+        const size_t file_size = fin.tellg();
+
         fin.seekg(file_offset);
 
         // load weights
@@ -764,6 +776,11 @@ static bool llama_model_load(
                 model.n_loaded++;
 
                 // progress
+                if (progress_callback) {
+                    double current_file_progress = double(size_t(fin.tellg()) - file_offset) / double(file_size - file_offset);
+                    double current_progress = (double(i) + current_file_progress) / double(n_parts);
+                    progress_callback(current_progress, progress_callback_user_data);
+                }
                 if (model.n_loaded % 8 == 0) {
                     fprintf(stderr, ".");
                     fflush(stderr);
@@ -786,6 +803,10 @@ static bool llama_model_load(
 
     lctx.t_load_us = ggml_time_us() - t_start_us;
 
+    if (progress_callback) {
+        progress_callback(1.0, progress_callback_user_data);
+    }
+
     return true;
 }
 
@@ -1617,7 +1638,8 @@ struct llama_context * llama_init_from_file(
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
     if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, memory_type,
-                          params.vocab_only)) {
+                          params.vocab_only, params.progress_callback,
+                          params.progress_callback_user_data)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
         llama_free(ctx);
         return nullptr;
diff --git a/llama.h b/llama.h
index 57123dbcc85eda86f903ed0fb5be79e6af2a4156..827abc1f26005baa2d2bcd27068ef60f326cfdda 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -45,6 +45,8 @@ extern "C" {
 
     } llama_token_data;
 
+    typedef void (*llama_progress_callback)(double progress, void *ctx);
+
     struct llama_context_params {
         int n_ctx;   // text context
         int n_parts; // -1 for default
@@ -55,6 +57,11 @@ extern "C" {
         bool vocab_only; // only load the vocabulary, no weights
         bool use_mlock;  // force system to keep model in RAM
         bool embedding;  // embedding mode only
+
+        // called with a progress value between 0 and 1, pass NULL to disable
+        llama_progress_callback progress_callback;
+        // context pointer passed to the progress callback
+        void * progress_callback_user_data;
     };
 
     LLAMA_API struct llama_context_params llama_context_default_params();