]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
main : print total token count and tokens consumed so far (#4874)
authorpudepiedj <redacted>
Thu, 11 Jan 2024 16:14:52 +0000 (16:14 +0000)
committerGitHub <redacted>
Thu, 11 Jan 2024 16:14:52 +0000 (18:14 +0200)
* Token count changes

* Add show token count

* Updating before PR

* Two requested changes

* Move param def posn

common/common.cpp
common/common.h
examples/main/main.cpp
llama.cpp

index 4e89fe516e0a9bfc7b93f1bdd1e84a708ccc6dfe..bfcd6d4dfe5d15968907ab00ff6f6b24354c780c 100644 (file)
@@ -630,6 +630,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.ppl_stride = std::stoi(argv[i]);
+        } else if (arg == "-stc" || arg == "--show_token_count") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.token_interval = std::stoi(argv[i]);
         } else if (arg == "--ppl-output-type") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -944,6 +950,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("                        advanced option to override model metadata by key. may be specified multiple times.\n");
     printf("                        types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
+    printf("  -stc N --show_token_count N\n");
+    printf("                        show consumed tokens every N tokens\n");
     printf("\n");
 #ifndef LOG_DISABLE_LOGS
     log_print_usage();
index e2bbfc258b6467cb24e5d40a6e28cd54ab148368..a295e88b05044f732eb22b5c19ccb46232cb2cbf 100644 (file)
@@ -64,6 +64,7 @@ struct gpt_params {
     int32_t n_beams                         = 0;     // if non-zero then use beam search of given width.
     int32_t grp_attn_n                      = 1;     // group-attention factor
     int32_t grp_attn_w                      = 512;   // group-attention width
+    int32_t token_interval                  = 512;   // show token count every 512 tokens
     float   rope_freq_base                  = 0.0f;  // RoPE base frequency
     float   rope_freq_scale                 = 0.0f;  // RoPE frequency scaling factor
     float   yarn_ext_factor                 = -1.0f; // YaRN extrapolation mix factor
@@ -242,4 +243,3 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
 
 // Dump the KV cache view showing individual sequences in each cell (long output).
 void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
-
index 5ea67051f36546aa5d11c6c7fb80489a94e44252..1f35febbd181a614b2738b330a8d86b860c39846 100644 (file)
@@ -500,7 +500,7 @@ int main(int argc, char ** argv) {
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
         if (!embd.empty()) {
-            // Note: n_ctx - 4 here is to match the logic for commandline prompt handling via
+            // Note: (n_ctx - 4) here is to match the logic for commandline prompt handling via
             // --prompt or --file which uses the same value.
             int max_embd_size = n_ctx - 4;
 
@@ -650,6 +650,10 @@ int main(int argc, char ** argv) {
                 n_past += n_eval;
 
                 LOG("n_past = %d\n", n_past);
+                // Display total tokens alongside total time
+                if (n_past % params.token_interval == 0) {
+                    printf("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
+                }
             }
 
             if (!embd.empty() && !path_session.empty()) {
index e1f1932baecf15721e35ad164b41831eeac45637..aaadfa444637e9a785ff89144d29b875bf3411cb 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -10921,7 +10921,7 @@ void llama_print_timings(struct llama_context * ctx) {
             __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
     LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
             __func__, timings.t_eval_ms, timings.n_eval, timings.t_eval_ms / timings.n_eval, 1e3 / timings.t_eval_ms * timings.n_eval);
-    LLAMA_LOG_INFO("%s:       total time = %10.2f ms\n", __func__, (timings.t_end_ms - timings.t_start_ms));
+    LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval));
 }
 
 void llama_reset_timings(struct llama_context * ctx) {