]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
context : remove logits_all flag (#13284)
authorGeorgi Gerganov <redacted>
Thu, 8 May 2025 11:26:50 +0000 (14:26 +0300)
committerGitHub <redacted>
Thu, 8 May 2025 11:26:50 +0000 (14:26 +0300)
* context : remove logits_all flag

ggml-ci

* llama : remove logits_all flag + reorder llama_context_params

ggml-ci

common/arg.cpp
common/common.cpp
common/common.h
include/llama.h
src/llama-context.cpp
src/llama-context.h
tools/imatrix/imatrix.cpp
tools/main/main.cpp
tools/perplexity/perplexity.cpp

index 5e07e8a699b8fdc52f6e94f3930ad0c35a459011..9f87e9910b5405019a1a2367fbfeb8eab68762f6 100644 (file)
@@ -2097,13 +2097,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.cache_type_v = kv_cache_type_from_str(value);
         }
     ).set_env("LLAMA_ARG_CACHE_TYPE_V"));
-    add_opt(common_arg(
-        {"--perplexity", "--all-logits"},
-        string_format("return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false"),
-        [](common_params & params) {
-            params.logits_all = true;
-        }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
     add_opt(common_arg(
         {"--hellaswag"},
         "compute HellaSwag score over random tasks from datafile supplied with -f",
index 94f545f815c27023e727e45d87911eba2a10ae7f..bd20af233695c2f76594bd05fb1e0916f5e83096 100644 (file)
@@ -1096,7 +1096,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.n_threads         = params.cpuparams.n_threads;
     cparams.n_threads_batch   = params.cpuparams_batch.n_threads == -1 ?
                                 params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
-    cparams.logits_all        = params.logits_all;
     cparams.embeddings        = params.embedding;
     cparams.rope_scaling_type = params.rope_scaling_type;
     cparams.rope_freq_base    = params.rope_freq_base;
index 400f674b2283da3abee8014c3b6becbacccee5c9..90702245463cbbc968cc01b5e0ad016130a7bd65 100644 (file)
@@ -324,7 +324,6 @@ struct common_params {
     bool ctx_shift         = true;  // context shift on inifinite text generation
 
     bool input_prefix_bos  = false; // prefix BOS to user inputs, preceding input_prefix
-    bool logits_all        = false; // return logits for all tokens in the batch
     bool use_mmap          = true;  // use mmap for faster loads
     bool use_mlock         = false; // use mlock to keep model in memory
     bool verbose_prompt    = false; // print prompt tokens before generation
index 06c56395c139fb8cc936542cae2a124c16cfba71..e18e9b8da337f042011a49a6018d844da0063136 100644 (file)
@@ -351,19 +351,17 @@ extern "C" {
         enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
         enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
 
-        // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
-        // TODO: move at the end of the struct
-        bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
-        bool embeddings;  // if true, extract embeddings (together with logits)
-        bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
-        bool no_perf;     // whether to measure performance timings
-
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
         // currently works only with CPU execution
         ggml_abort_callback abort_callback;
         void *              abort_callback_data;
+
+        // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
+        bool embeddings;  // if true, extract embeddings (together with logits)
+        bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
+        bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
+        bool no_perf;     // whether to measure performance timings
     };
 
     // model quantization parameters
index 45591be992d8788303fc2f164984040ce103fff3..dadb87517f7ef17fb5b7a35bd8f0584a74328093 100644 (file)
@@ -116,8 +116,6 @@ llama_context::llama_context(
                 __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }
 
-    logits_all = params.logits_all;
-
     if (!hparams.vocab_only) {
         // GPU backends
         for (auto * dev : model.devices) {
@@ -890,7 +888,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs_all += batch.logits[i] != 0;
         }
-    } else if (logits_all || embd_pooled) {
+    } else if (embd_pooled) {
         n_outputs_all = n_tokens_all;
     } else {
         // keep last output only
@@ -1853,13 +1851,12 @@ llama_context_params llama_context_default_params() {
         /*.cb_eval_user_data           =*/ nullptr,
         /*.type_k                      =*/ GGML_TYPE_F16,
         /*.type_v                      =*/ GGML_TYPE_F16,
-        /*.logits_all                  =*/ false,
+        /*.abort_callback              =*/ nullptr,
+        /*.abort_callback_data         =*/ nullptr,
         /*.embeddings                  =*/ false,
         /*.offload_kqv                 =*/ true,
         /*.flash_attn                  =*/ false,
         /*.no_perf                     =*/ true,
-        /*.abort_callback              =*/ nullptr,
-        /*.abort_callback_data         =*/ nullptr,
     };
 
     return result;
index cf41ac57b9fba8135ce3d98c6aaee9de48f52bd8..5a080e67fcc4b8be582e9897ff60bd99b324a949 100644 (file)
@@ -187,9 +187,6 @@ private:
 
     std::unique_ptr<llama_memory_i> memory;
 
-    // TODO: remove
-    bool logits_all = false;
-
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t  logits_size = 0; // capacity (of floats) for logits
     float * logits      = nullptr;
index b4640f9faf229743205ec09bff7bf5c6942e8ed9..2c39278dba3d97304ac9a8e8a01094d6565a3d7a 100644 (file)
@@ -585,7 +585,6 @@ int main(int argc, char ** argv) {
     params.out_file = "imatrix.dat" ;
 
     params.n_ctx = 512;
-    params.logits_all = true;
     params.escape = false;
 
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_IMATRIX, print_usage)) {
index c59b941bf5e4759d450b1bbc96fb045db166b080..756297c257a6e5e026869f969cf9abe9c07bfeae 100644 (file)
@@ -99,14 +99,6 @@ int main(int argc, char ** argv) {
     console::init(params.simple_io, params.use_color);
     atexit([]() { console::cleanup(); });
 
-    if (params.logits_all) {
-        LOG_ERR("************\n");
-        LOG_ERR("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__);
-        LOG_ERR("************\n\n");
-
-        return 0;
-    }
-
     if (params.embedding) {
         LOG_ERR("************\n");
         LOG_ERR("%s: please use the 'embedding' tool for embedding calculations\n", __func__);
index 175f2804b5da007afeee7cf4381e0e3d7de48588..b5cdf5beb1b24a6e1926d6882c2e5e93cd48d5b7 100644 (file)
@@ -1554,7 +1554,10 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
             if (int(batch_indeces.size()) != num_answers) {
                 batch_indeces.resize(num_answers);
             }
-            for (int s = 0; s < num_answers; ++s) batch_indeces[s] = s0 + s;
+
+            for (int s = 0; s < num_answers; ++s) {
+                batch_indeces[s] = s0 + s;
+            }
 
             for (size_t i = 0; i < cur_task.common_prefix; ++i) {
                 //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
@@ -1970,7 +1973,6 @@ int main(int argc, char ** argv) {
     common_params params;
 
     params.n_ctx = 512;
-    params.logits_all = true;
     params.escape = false;
 
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {