]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
save-load-state : fix example + add ci test (#3655)
authorGeorgi Gerganov <redacted>
Tue, 17 Oct 2023 16:12:46 +0000 (19:12 +0300)
committerGitHub <redacted>
Tue, 17 Oct 2023 16:12:46 +0000 (19:12 +0300)
* save-load-state : fix example (close #3606)

* ci : add test for save-load-state example

ggml-ci

ci/run.sh
examples/save-load-state/save-load-state.cpp

index 34c9129c1154c791bdc77f184f19e2265f652a31..2e33438312e850f61664ed634dded0733e4663ec 100755 (executable)
--- a/ci/run.sh
+++ b/ci/run.sh
@@ -208,6 +208,8 @@ function gg_run_open_llama_3b_v2 {
     (time ./bin/perplexity --model ${model_q5_k} -f ${wiki_test_60} -c 128 -b 128 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log
     (time ./bin/perplexity --model ${model_q6_k} -f ${wiki_test_60} -c 128 -b 128 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log
 
+    (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+
     function check_ppl {
         qnt="$1"
         ppl=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1)
@@ -296,6 +298,7 @@ function gg_sum_open_llama_3b_v2 {
     gg_printf '- q4_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q4_k.log)"
     gg_printf '- q5_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q5_k.log)"
     gg_printf '- q6_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q6_k.log)"
+    gg_printf '- save-load-state: \n```\n%s\n```\n' "$(cat $OUT/${ci}-save-load-state.log)"
     gg_printf '- shakespeare (f16):\n```\n%s\n```\n' "$(cat $OUT/${ci}-ppl-shakespeare-f16.log)"
     gg_printf '- shakespeare (f16 lora):\n```\n%s\n```\n' "$(cat $OUT/${ci}-ppl-shakespeare-lora-f16.log)"
     gg_printf '- shakespeare (q8_0):\n```\n%s\n```\n' "$(cat $OUT/${ci}-ppl-shakespeare-q8_0.log)"
@@ -382,6 +385,8 @@ function gg_run_open_llama_7b_v2 {
     (time ./bin/perplexity --model ${model_q5_k} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log
     (time ./bin/perplexity --model ${model_q6_k} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log
 
+    (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+
     function check_ppl {
         qnt="$1"
         ppl=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1)
@@ -470,6 +475,7 @@ function gg_sum_open_llama_7b_v2 {
     gg_printf '- q4_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q4_k.log)"
     gg_printf '- q5_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q5_k.log)"
     gg_printf '- q6_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q6_k.log)"
+    gg_printf '- save-load-state: \n```\n%s\n```\n' "$(cat $OUT/${ci}-save-load-state.log)"
     gg_printf '- shakespeare (f16):\n```\n%s\n```\n' "$(cat $OUT/${ci}-ppl-shakespeare-f16.log)"
     gg_printf '- shakespeare (f16 lora):\n```\n%s\n```\n' "$(cat $OUT/${ci}-ppl-shakespeare-lora-f16.log)"
     #gg_printf '- shakespeare (q8_0):\n```\n%s\n```\n' "$(cat $OUT/${ci}-ppl-shakespeare-q8_0.log)"
index f9e3c98a38a409a06d57eb2745b94f4ac4af1c59..38d05f4d328e7ab29622849d91f014b6b5914eaa 100644 (file)
@@ -8,10 +8,7 @@
 
 int main(int argc, char ** argv) {
     gpt_params params;
-    llama_sampling_params & sparams = params.sampling_params;
-    params.seed = 42;
-    params.n_threads = 4;
-    sparams.repeat_last_n = 64;
+
     params.prompt = "The quick brown fox";
 
     if (!gpt_params_parse(argc, argv, params)) {
@@ -25,56 +22,49 @@ int main(int argc, char ** argv) {
     }
 
     auto n_past = 0;
-    auto last_n_tokens_data = std::vector<llama_token>(sparams.repeat_last_n, 0);
+
+    std::string result0;
+    std::string result1;
 
     // init
     llama_model * model;
     llama_context * ctx;
 
-    std::tie(model, ctx) = llama_init_from_gpt_params( params );
-    if (model == nullptr) {
-        return 1;
-    }
-    if (ctx == nullptr) {
-        llama_free_model(model);
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    if (model == nullptr || ctx == nullptr) {
+        fprintf(stderr, "%s : failed to init\n", __func__);
         return 1;
     }
+
+    // tokenize prompt
     auto tokens = llama_tokenize(ctx, params.prompt, true);
-    auto n_prompt_tokens = tokens.size();
-    if (n_prompt_tokens < 1) {
-        fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
-        llama_free(ctx);
-        llama_free_model(model);
-        return 1;
-    }
 
     // evaluate prompt
-    llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0));
+    llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), n_past, 0));
+    n_past += tokens.size();
 
-    last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
-    n_past += n_prompt_tokens;
-
-    const size_t state_size = llama_get_state_size(ctx);
-    uint8_t * state_mem = new uint8_t[state_size];
-
-    // Save state (rng, logits, embedding and kv_cache) to file
+    // save state (rng, logits, embedding and kv_cache) to file
     {
-        FILE *fp_write = fopen("dump_state.bin", "wb");
-        llama_copy_state_data(ctx, state_mem); // could also copy directly to memory mapped file
-        fwrite(state_mem, 1, state_size, fp_write);
-        fclose(fp_write);
+        std::vector<uint8_t> state_mem(llama_get_state_size(ctx));
+
+        {
+            FILE *fp_write = fopen("dump_state.bin", "wb");
+            llama_copy_state_data(ctx, state_mem.data()); // could also copy directly to memory mapped file
+            fwrite(state_mem.data(), 1, state_mem.size(), fp_write);
+            fclose(fp_write);
+        }
     }
 
     // save state (last tokens)
-    const auto last_n_tokens_data_saved = std::vector<llama_token>(last_n_tokens_data);
     const auto n_past_saved = n_past;
 
     // first run
-    printf("\n%s", params.prompt.c_str());
+    printf("\nfirst run: %s", params.prompt.c_str());
 
     for (auto i = 0; i < params.n_predict; i++) {
         auto * logits = llama_get_logits(ctx);
         auto n_vocab = llama_n_vocab(model);
+
         std::vector<llama_token_data> candidates;
         candidates.reserve(n_vocab);
         for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@@ -83,9 +73,10 @@ int main(int argc, char ** argv) {
         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
         auto next_token = llama_sample_token(ctx, &candidates_p);
         auto next_token_str = llama_token_to_piece(ctx, next_token);
-        last_n_tokens_data.push_back(next_token);
 
         printf("%s", next_token_str.c_str());
+        result0 += next_token_str;
+
         if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx);
@@ -103,32 +94,28 @@ int main(int argc, char ** argv) {
     // make new context
     auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
 
-    // Load state (rng, logits, embedding and kv_cache) from file
+    printf("\nsecond run: %s", params.prompt.c_str());
+
+    // load state (rng, logits, embedding and kv_cache) from file
     {
-        FILE *fp_read = fopen("dump_state.bin", "rb");
-        if (state_size != llama_get_state_size(ctx2)) {
-            fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
-            llama_free(ctx2);
-            llama_free_model(model);
-            return 1;
-        }
+        std::vector<uint8_t> state_mem(llama_get_state_size(ctx2));
 
-        const size_t ret = fread(state_mem, 1, state_size, fp_read);
-        if (ret != state_size) {
+        FILE * fp_read = fopen("dump_state.bin", "rb");
+
+        const size_t ret = fread(state_mem.data(), 1, state_mem.size(), fp_read);
+        if (ret != state_mem.size()) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
             return 1;
         }
 
-        llama_set_state_data(ctx2, state_mem);  // could also read directly from memory mapped file
+        llama_set_state_data(ctx2, state_mem.data());
+
         fclose(fp_read);
     }
 
-    delete[] state_mem;
-
     // restore state (last tokens)
-    last_n_tokens_data = last_n_tokens_data_saved;
     n_past = n_past_saved;
 
     // second run
@@ -143,10 +130,11 @@ int main(int argc, char ** argv) {
         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
         auto next_token = llama_sample_token(ctx2, &candidates_p);
         auto next_token_str = llama_token_to_piece(ctx2, next_token);
-        last_n_tokens_data.push_back(next_token);
 
         printf("%s", next_token_str.c_str());
-        if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
+        result1 += next_token_str;
+
+        if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1, n_past, 0))) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
@@ -155,10 +143,17 @@ int main(int argc, char ** argv) {
         n_past += 1;
     }
 
-    printf("\n\n");
+    printf("\n");
 
     llama_free(ctx2);
     llama_free_model(model);
 
+    if (result0 != result1) {
+        fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__);
+        return 1;
+    }
+
+    fprintf(stderr, "\n%s : success\n", __func__);
+
     return 0;
 }