]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : remove cfg smooth factor as it is only a reparameterization of the guidance...
authorGuillaume "Vermeille" Sanchez <redacted>
Fri, 21 Jul 2023 10:58:36 +0000 (12:58 +0200)
committerGitHub <redacted>
Fri, 21 Jul 2023 10:58:36 +0000 (13:58 +0300)
examples/common.cpp
examples/common.h
examples/main/main.cpp
llama.cpp
llama.h

index 476d565949debe155b0175750eea151871d6d0f1..09901959956f9d92686719722afe6f4e2a245b3e 100644 (file)
@@ -260,12 +260,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.cfg_scale = std::stof(argv[i]);
-        } else if (arg == "--cfg-smooth-factor") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params.cfg_smooth_factor = std::stof(argv[i]);
         } else if (arg == "-b" || arg == "--batch-size") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -509,7 +503,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  --cfg-negative-prompt PROMPT \n");
     fprintf(stderr, "                        negative prompt to use for guidance. (default: empty)\n");
     fprintf(stderr, "  --cfg-scale N         strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
-    fprintf(stderr, "  --cfg-smooth-factor N smooth factor between old and new logits (default: %f, 1.0 = no smoothing)\n", params.cfg_smooth_factor);
     fprintf(stderr, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stderr, "  --rope-freq-base N    RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
     fprintf(stderr, "  --rope-freq-scale N   RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
index 037a4eecba5b5a0a7cc396a36118847504842f90..69170dfc098502ca20b9f67cacaeb782910297d6 100644 (file)
@@ -55,7 +55,6 @@ struct gpt_params {
     // https://arxiv.org/abs/2306.17806
     std::string cfg_negative_prompt;       // string to help guidance
     float       cfg_scale         = 1.f;   // How strong is guidance
-    float       cfg_smooth_factor = 1.f;   // Smooth factor between old and new logits
 
     std::string model             = "models/7B/ggml-model.bin"; // model path
     std::string model_alias       = "unknown"; // model alias
index bcbcf12b0ac19d93a36ce13e215b095e147fa67e..656382f8161dd3593e878927be56acaed1f9bce8 100644 (file)
@@ -557,7 +557,7 @@ int main(int argc, char ** argv) {
                 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
                 if (ctx_guidance) {
-                    llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale, params.cfg_smooth_factor);
+                    llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale);
                 }
 
                 // Apply penalties
index 23e746d625fe44be9651d0841988b2f5d896dc11..3b0024e1284793040b0c66ca4607446b0e731d25 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -2218,8 +2218,7 @@ void llama_sample_classifier_free_guidance(
           struct llama_context * ctx,
         llama_token_data_array * candidates,
           struct llama_context * guidance_ctx,
-                         float   scale,
-                         float   smooth_factor) {
+                         float   scale) {
     int64_t t_start_sample_us = ggml_time_us();
 
     assert(ctx);
@@ -2240,16 +2239,7 @@ void llama_sample_classifier_free_guidance(
     for (int i = 0; i < n_vocab; ++i) {
         float logit_guidance = logits_guidance[i];
         float logit_base = logits_base[i];
-        logits_guidance[i] = scale * (logit_base - logit_guidance) + logit_guidance;
-    }
-
-    llama_log_softmax(logits_guidance, n_vocab);
-
-    for (int i = 0; i < n_vocab; ++i) {
-        float logit_base = logits_base[i];
-        float logit_guidance = logits_guidance[i];
-
-        candidates->data[i].logit = smooth_factor * logit_guidance + (1.f - smooth_factor) * logit_base;
+        candidates->data[i].logit = scale * (logit_base - logit_guidance) + logit_guidance;
     }
 
     if (ctx) {
diff --git a/llama.h b/llama.h
index c565f6a001cf3320d0b8fdafa429c6eff888171d..bbf28e68684cf60dc65f08f3eba1ff26565abcee 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -344,13 +344,11 @@ extern "C" {
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
     /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
     /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
-    /// @params smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only original logits.
     LLAMA_API void llama_sample_classifier_free_guidance(
               struct llama_context * ctx,
             llama_token_data_array * candidates,
               struct llama_context * guidance_ctx,
-                             float   scale,
-                             float   smooth_factor);
+                             float   scale);
 
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
     LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);