]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : more consistent names of count variables (#5994)
authorGeorgi Gerganov <redacted>
Mon, 11 Mar 2024 15:49:47 +0000 (17:49 +0200)
committerGitHub <redacted>
Mon, 11 Mar 2024 15:49:47 +0000 (17:49 +0200)
* llama : more consistent names of count variables

ggml-ci

* llama : n_parallel -> n_seq_max

* common : fix param name

* examples : fix param name

README.md
common/common.cpp
examples/batched-bench/batched-bench.cpp
examples/batched/batched.cpp
examples/main/main.cpp
examples/perplexity/perplexity.cpp
llama.cpp
llama.h

index 98fdc6808ffe32766643235886cb8df721787ed6..54bf84bec67bb3ebd67ad4d8f09c0440affa796d 100644 (file)
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
 
 ### Recent API changes
 
-- [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_max_seq()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328
+- [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_seq_max()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328
 - [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796
 - [2024 Mar 3] `struct llama_context_params` https://github.com/ggerganov/llama.cpp/pull/5849
 
index 16ef4d7f74dd99979a525350226bef0a14abebdc..2f38ac632b45a91c9444bb56ae3030c7a6261a80 100644 (file)
@@ -1288,7 +1288,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
 
     cparams.n_ctx             = params.n_ctx;
     cparams.n_batch           = params.n_batch;
-    cparams.n_parallel        = params.n_parallel;
+    cparams.n_seq_max         = params.n_parallel;
     cparams.n_threads         = params.n_threads;
     cparams.n_threads_batch   = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
     cparams.seed              = params.seed;
@@ -1786,17 +1786,17 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) {
     static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
 
     printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
-        view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
+        view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
 
     llama_kv_cache_view_cell * c_curr = view.cells;
     llama_seq_id * cs_curr = view.cells_sequences;
 
-    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
         if (i % row_size == 0) {
             printf("\n%5d: ", i);
         }
         int seq_count = 0;
-        for (int j = 0; j < view.n_max_seq; j++) {
+        for (int j = 0; j < view.n_seq_max; j++) {
             if (cs_curr[j] >= 0) { seq_count++; }
         }
         putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
@@ -1809,14 +1809,14 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
     static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
 
     printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
-        view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
+        view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
 
     std::unordered_map<llama_seq_id, size_t> seqs;
     llama_kv_cache_view_cell * c_curr = view.cells;
     llama_seq_id * cs_curr = view.cells_sequences;
 
-    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
-        for (int j = 0; j < view.n_max_seq; j++) {
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
+        for (int j = 0; j < view.n_seq_max; j++) {
             if (cs_curr[j] < 0) { continue; }
             if (seqs.find(cs_curr[j]) == seqs.end()) {
                 if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
@@ -1835,11 +1835,11 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
 
     c_curr = view.cells;
     cs_curr = view.cells_sequences;
-    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
         if (i % row_size == 0) {
             printf("\n%5d: ", i);
         }
-        for (int j = 0; j < view.n_max_seq; j++) {
+        for (int j = 0; j < view.n_seq_max; j++) {
             if (cs_curr[j] >= 0) {
                 const auto & it = seqs.find(cs_curr[j]);
                 putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');
index dff6c68ec2e696d081d15202cba3486d2f513558..22bc93bca817cb5321eb1a24929e98c56a1b5059 100644 (file)
@@ -106,7 +106,7 @@ int main(int argc, char ** argv) {
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
 
     // ensure enough sequences are available
-    ctx_params.n_parallel = *std::max_element(n_pl.begin(), n_pl.end());
+    ctx_params.n_seq_max = *std::max_element(n_pl.begin(), n_pl.end());
 
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
index dde4d5a068e24695bc3bc6f01cdd9e4a5c876724..ee1f8f1bf5dd2245e78f833245d2aee1b48a4b9c 100644 (file)
@@ -80,7 +80,7 @@ int main(int argc, char ** argv) {
     ctx_params.seed  = 1234;
     ctx_params.n_ctx = n_kv_req;
     ctx_params.n_batch = std::max(n_len, n_parallel);
-    ctx_params.n_parallel      = n_parallel;
+    ctx_params.n_seq_max       = n_parallel;
     ctx_params.n_threads       = params.n_threads;
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
 
index 47059e582a0d499493e7e8d4d73534d39ba88830..e2d07a6319d506950443fca905293fe686d0c5e6 100644 (file)
@@ -878,6 +878,7 @@ int main(int argc, char ** argv) {
                     const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
                     const auto line_inp = ::llama_tokenize(ctx, buffer,              false, false);
                     const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
+
                     LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
 
                     embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end());
index 293eb52c33653e3bfdc062509fcac5e2bdf0ad33..fdfc8f5dcd3e5e66d3aa50062e19edd061e95d7a 100644 (file)
@@ -841,7 +841,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     const int n_batch = params.n_batch;
 
     const int max_tasks_per_batch = 32;
-    const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx));
+    const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
     llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
@@ -1118,7 +1118,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
     const int n_batch = params.n_batch;
 
     const int max_tasks_per_batch = 128;
-    const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_max_seq(ctx));
+    const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
     llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
@@ -1470,7 +1470,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
     const int n_batch = params.n_batch;
 
     const int max_tasks_per_batch = 32;
-    const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx));
+    const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
     llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
index 6aa88ce54baef46a575d3da88d910f52ab3b7053..98ec147aee5d545b0c9a947ae6ac88133873c2a3 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -12538,7 +12538,7 @@ struct llama_context_params llama_context_default_params() {
         /*.seed                        =*/ LLAMA_DEFAULT_SEED,
         /*.n_ctx                       =*/ 512,
         /*.n_batch                     =*/ 512,
-        /*.n_parallel                  =*/ 1,
+        /*.n_seq_max                   =*/ 1,
         /*.n_threads                   =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
         /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
         /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
@@ -12700,7 +12700,7 @@ struct llama_context * llama_new_context_with_model(
     auto       & cparams = ctx->cparams;
 
     cparams.n_batch          = params.n_batch;
-    // TODO: maybe add n_parallel here too
+    // TODO: maybe add n_seq_max here too
     cparams.n_threads        = params.n_threads;
     cparams.n_threads_batch  = params.n_threads_batch;
     cparams.yarn_ext_factor  = params.yarn_ext_factor;
@@ -12767,7 +12767,7 @@ struct llama_context * llama_new_context_with_model(
     // Mamba only needs a constant number of KV cache cells per sequence
     if (model->arch == LLM_ARCH_MAMBA) {
         // Mamba needs at least as many KV cells as there are sequences kept at any time
-        kv_size = std::max((uint32_t) 1, params.n_parallel);
+        kv_size = std::max((uint32_t) 1, params.n_seq_max);
         // it's probably best to keep as much precision as possible for the states
         type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
         type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
@@ -13024,7 +13024,7 @@ uint32_t llama_n_batch(const struct llama_context * ctx) {
     return ctx->cparams.n_batch;
 }
 
-uint32_t llama_n_max_seq(const struct llama_context * ctx) {
+uint32_t llama_n_seq_max(const struct llama_context * ctx) {
     return ctx->kv_self.size;
 }
 
@@ -13188,10 +13188,10 @@ int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const
     }
 }
 
-struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) {
+struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max) {
     struct llama_kv_cache_view result = {
         /*.n_cells            = */ 0,
-        /*.n_max_seq          = */ n_max_seq,
+        /*.n_seq_max          = */ n_seq_max,
         /*.token_count        = */ 0,
         /*.used_cells         = */ llama_get_kv_cache_used_cells(ctx),
         /*.max_contiguous     = */ 0,
@@ -13219,7 +13219,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
         void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
         GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
         view->cells = (struct llama_kv_cache_view_cell *)p;
-        p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_max_seq * view->n_cells);
+        p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
         GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
         view->cells_sequences = (llama_seq_id *)p;
     }
@@ -13233,7 +13233,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
     uint32_t max_contig = 0;
     int32_t max_contig_idx = -1;
 
-    for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_max_seq) {
+    for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_seq_max) {
         const size_t curr_size = kv_cells[i].seq_id.size();
         token_count += curr_size;
         c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
@@ -13250,7 +13250,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
 
         int seq_idx = 0;
         for (const llama_seq_id it : kv_cells[i].seq_id) {
-            if (seq_idx >= view->n_max_seq) {
+            if (seq_idx >= view->n_seq_max) {
                 break;
             }
             cs_curr[seq_idx] = it;
@@ -13259,7 +13259,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
         if (seq_idx != 0) {
             used_cells++;
         }
-        for (; seq_idx < view->n_max_seq; seq_idx++) {
+        for (; seq_idx < view->n_seq_max; seq_idx++) {
             cs_curr[seq_idx] = -1;
         }
     }
@@ -13921,12 +13921,12 @@ int32_t llama_tokenize(
                   const char * text,
                      int32_t   text_len,
                  llama_token * tokens,
-                     int32_t   n_max_tokens,
+                     int32_t   n_tokens_max,
                         bool   add_bos,
                         bool   special) {
     auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special);
 
-    if (n_max_tokens < (int) res.size()) {
+    if (n_tokens_max < (int) res.size()) {
         // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
         return -((int) res.size());
     }
diff --git a/llama.h b/llama.h
index ccf65ca4e87a51a59fba1dea76222269b5ff9cec..446899da6e38ea6299d54becd71d0db0686bc48b 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -235,7 +235,7 @@ extern "C" {
         uint32_t seed;              // RNG seed, -1 for random
         uint32_t n_ctx;             // text context, 0 = from model
         uint32_t n_batch;           // prompt processing maximum batch size
-        uint32_t n_parallel;        // number of parallel sequences (i.e. distinct states for recurrent models)
+        uint32_t n_seq_max;         // max number of sequences (i.e. distinct states for recurrent models)
         uint32_t n_threads;         // number of threads to use for generation
         uint32_t n_threads_batch;   // number of threads to use for batch processing
 
@@ -377,7 +377,7 @@ extern "C" {
 
     LLAMA_API uint32_t llama_n_ctx      (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
-    LLAMA_API uint32_t llama_n_max_seq  (const struct llama_context * ctx);
+    LLAMA_API uint32_t llama_n_seq_max  (const struct llama_context * ctx);
 
     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
     LLAMA_API enum llama_rope_type  llama_rope_type (const struct llama_model * model);
@@ -456,7 +456,7 @@ extern "C" {
         // Maximum number of sequences that can exist in a cell. It's not an error
         // if there are more sequences in a cell than this value, however they will
         // not be visible in the view cells_sequences.
-        int32_t n_max_seq;
+        int32_t n_seq_max;
 
         // Number of tokens in the cache. For example, if there are two populated
         // cells, the first with 1 sequence id in it and the second with 2 sequence
@@ -476,12 +476,12 @@ extern "C" {
         // Information for an individual cell.
         struct llama_kv_cache_view_cell * cells;
 
-        // The sequences for each cell. There will be n_max_seq items per cell.
+        // The sequences for each cell. There will be n_seq_max items per cell.
         llama_seq_id * cells_sequences;
     };
 
     // Create an empty KV cache view. (use only for debugging purposes)
-    LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);
+    LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max);
 
     // Free a KV cache view. (use only for debugging purposes)
     LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
@@ -708,7 +708,7 @@ extern "C" {
 
     /// @details Convert the provided text into tokens.
     /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
-    /// @return Returns the number of tokens on success, no more than n_max_tokens
+    /// @return Returns the number of tokens on success, no more than n_tokens_max
     /// @return Returns a negative number on failure - the number of tokens that would have been returned
     /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
     ///                Does not insert a leading space.
@@ -717,7 +717,7 @@ extern "C" {
                       const char * text,
                          int32_t   text_len,
                      llama_token * tokens,
-                         int32_t   n_max_tokens,
+                         int32_t   n_tokens_max,
                             bool   add_bos,
                             bool   special);