]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
examples : add save_load_state example (#1150)
authorxaedes <redacted>
Mon, 24 Apr 2023 16:23:31 +0000 (18:23 +0200)
committerGitHub <redacted>
Mon, 24 Apr 2023 16:23:31 +0000 (19:23 +0300)
* add save_load_state example

* use <cstdio> instead of <iostream> and fprintf / printf instead of cout

* renamed save-load-state example files replacing underscores by dashes

examples/CMakeLists.txt
examples/save-load-state/CMakeLists.txt [new file with mode: 0644]
examples/save-load-state/save-load-state.cpp [new file with mode: 0644]

index 67a7cea543a4013b0717f2cd740789891c130657..be35363f58a03b8bfce0ca9cfda2ad9c92b2787d 100644 (file)
@@ -34,4 +34,5 @@ else()
     add_subdirectory(quantize-stats)
     add_subdirectory(perplexity)
     add_subdirectory(embedding)
+    add_subdirectory(save-load-state)
 endif()
diff --git a/examples/save-load-state/CMakeLists.txt b/examples/save-load-state/CMakeLists.txt
new file mode 100644 (file)
index 0000000..cff79fa
--- /dev/null
@@ -0,0 +1,4 @@
+set(TARGET save-load-state)
+add_executable(${TARGET} save-load-state.cpp)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
new file mode 100644 (file)
index 0000000..39aa7f8
--- /dev/null
@@ -0,0 +1,128 @@
+#include <vector>
+#include <cstdio>
+#include <chrono>
+
+#include "common.h"
+#include "llama.h"
+#include "llama.cpp"
+
+using namespace std;
+
+int main(int argc, char ** argv) {
+    gpt_params params;
+    params.model = "models/llama-7B/ggml-model.bin";
+    params.seed = 42;
+    params.n_threads = 4;
+    params.repeat_last_n = 64;
+    params.prompt = "The quick brown fox";
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    auto lparams = llama_context_default_params();
+
+    lparams.n_ctx      = params.n_ctx;
+    lparams.n_parts    = params.n_parts;
+    lparams.seed       = params.seed;
+    lparams.f16_kv     = params.memory_f16;
+    lparams.use_mmap   = params.use_mmap;
+    lparams.use_mlock  = params.use_mlock;
+
+    auto n_past = 0;
+    auto last_n_tokens_data = vector<llama_token>(params.repeat_last_n, 0);
+
+    // init
+    auto ctx = llama_init_from_file(params.model.c_str(), lparams);
+    auto tokens = vector<llama_token>(params.n_ctx);
+    auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), tokens.size(), true);
+
+    if (n_prompt_tokens < 1) {
+        fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
+        return 1;
+    }
+
+    // evaluate prompt
+
+    llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads);
+
+    last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
+    n_past += n_prompt_tokens;
+
+    // Save state (rng, logits, embedding and kv_cache) to file
+    FILE *fp_write = fopen("dump_state.bin", "wb");
+    auto state_size = llama_get_state_size(ctx);
+    auto state_mem = new uint8_t[state_size];
+    llama_copy_state_data(ctx, state_mem); // could also copy directly to memory mapped file
+    fwrite(state_mem, 1, state_size, fp_write);
+    fclose(fp_write);
+
+    // save state (last tokens)
+    auto last_n_tokens_data_saved = vector<llama_token>(last_n_tokens_data);
+    auto n_past_saved = n_past;
+
+    // first run
+    printf("\n%s", params.prompt.c_str());
+    for (auto i = 0; i < params.n_predict; i++) {
+        auto next_token = llama_sample_top_p_top_k(
+            ctx,
+            &last_n_tokens_data.back() - params.repeat_last_n,
+            params.repeat_last_n,
+            40,
+            1.0,
+            1.0,
+            1.1);
+        auto next_token_str = llama_token_to_str(ctx, next_token);
+        last_n_tokens_data.push_back(next_token);
+        printf("%s", next_token_str);
+        if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
+            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+            return 1;
+        }
+        n_past += 1;
+    }
+    printf("\n\n");
+
+    // free old model
+    llama_free(ctx);
+
+    // load new model
+
+    auto ctx2 = llama_init_from_file(params.model.c_str(), lparams);
+
+    // Load state (rng, logits, embedding and kv_cache) from file
+    FILE *fp_read = fopen("dump_state.bin", "rb");
+    auto state_size2 = llama_get_state_size(ctx2);
+    if (state_size != state_size2) {
+        fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
+    }
+    fread(state_mem, 1, state_size, fp_read);
+    llama_set_state_data(ctx2, state_mem);  // could also read directly from memory mapped file
+    fclose(fp_read);
+
+    // restore state (last tokens)
+    last_n_tokens_data = last_n_tokens_data_saved;
+    n_past = n_past_saved;
+
+    // second run
+    for (auto i = 0; i < params.n_predict; i++) {
+        auto next_token = llama_sample_top_p_top_k(
+            ctx2,
+            &last_n_tokens_data.back() - params.repeat_last_n,
+            params.repeat_last_n,
+            40,
+            1.0,
+            1.0,
+            1.1);
+        auto next_token_str = llama_token_to_str(ctx2, next_token);
+        last_n_tokens_data.push_back(next_token);
+        printf("%s", next_token_str);
+        if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
+            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+            return 1;
+        }
+        n_past += 1;
+    }
+    printf("\n\n");
+    return 0;
+}