]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add classifier-free guidance (#2135)
authorBach Le <redacted>
Tue, 11 Jul 2023 16:18:43 +0000 (00:18 +0800)
committerGitHub <redacted>
Tue, 11 Jul 2023 16:18:43 +0000 (19:18 +0300)
* Initial implementation

* Remove debug print

* Restore signature of llama_init_from_gpt_params

* Free guidance context

* Make freeing of guidance_ctx conditional

* Make Classifier-Free Guidance a sampling function

* Correct typo. CFG already means context-free grammar.

* Record sampling time in llama_sample_classifier_free_guidance

* Shift all values by the max value before applying logsoftmax

* Fix styling based on review

examples/common.cpp
examples/common.h
examples/main/main.cpp
llama.cpp
llama.h

index fad16887de3c34ed72294fc0de26f67dbf8030d1..fd551c9cb2fcf91a82832169e79ff7e0a4f276d3 100644 (file)
@@ -236,6 +236,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.mirostat_tau = std::stof(argv[i]);
+        } else if (arg == "--cfg-negative-prompt") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.cfg_negative_prompt = argv[i];
+        } else if (arg == "--cfg-scale") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.cfg_scale = std::stof(argv[i]);
+        } else if (arg == "--cfg-smooth-factor") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.cfg_smooth_factor = std::stof(argv[i]);
         } else if (arg == "-b" || arg == "--batch-size") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -469,6 +487,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "                        modifies the likelihood of token appearing in the completion,\n");
     fprintf(stderr, "                        i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
     fprintf(stderr, "                        or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
+    fprintf(stderr, "  --cfg-negative-prompt PROMPT \n");
+    fprintf(stderr, "                        negative prompt to use for guidance. (default: empty)\n");
+    fprintf(stderr, "  --cfg-scale N         strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
+    fprintf(stderr, "  --cfg-smooth-factor N smooth factor between old and new logits (default: %f, 1.0 = no smoothing)\n", params.cfg_smooth_factor);
     fprintf(stderr, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     fprintf(stderr, "  --no-penalize-nl      do not penalize newline token\n");
@@ -535,7 +557,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
     return res;
 }
 
-std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) {
+struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
     auto lparams = llama_context_default_params();
 
     lparams.n_ctx        = params.n_ctx;
@@ -551,6 +573,12 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     lparams.logits_all   = params.perplexity;
     lparams.embedding    = params.embedding;
 
+    return lparams;
+}
+
+std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) {
+    auto lparams = llama_context_params_from_gpt_params(params);
+
     llama_model * model  = llama_load_model_from_file(params.model.c_str(), lparams);
     if (model == NULL) {
         fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
index 96f2228f8677b65e4aa3a39898059045e5993c21..6315df9613445697e63db3d492732a17c8d9c91a 100644 (file)
@@ -48,6 +48,12 @@ struct gpt_params {
     float   mirostat_tau      = 5.00f; // target entropy
     float   mirostat_eta      = 0.10f; // learning rate
 
+    // Classifier-Free Guidance
+    // https://arxiv.org/abs/2306.17806
+    std::string cfg_negative_prompt;       // string to help guidance
+    float       cfg_scale         = 1.f;   // How strong is guidance
+    float       cfg_smooth_factor = 1.f;   // Smooth factor between old and new logits
+
     std::string model             = "models/7B/ggml-model.bin"; // model path
     std::string model_alias       = "unknown"; // model alias
     std::string prompt            = "";
@@ -99,6 +105,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
 //
 
 std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params);
+struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
 
 //
 // Console utils
index 07d8fc6ac078188ce5838e4a9d61461f8a14af12..2248c245875b04b88ed84c17e0a6b50ba9825134 100644 (file)
@@ -109,10 +109,16 @@ int main(int argc, char ** argv) {
 
     llama_model * model;
     llama_context * ctx;
+    llama_context * ctx_guidance = NULL;
     g_ctx = &ctx;
 
     // load the model and apply lora adapter, if any
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    if (params.cfg_scale > 1.f) {
+        struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
+        ctx_guidance = llama_new_context_with_model(model, lparams);
+    }
+
     if (model == NULL) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
         return 1;
@@ -183,15 +189,28 @@ int main(int argc, char ** argv) {
     // tokenize the prompt
     std::vector<llama_token> embd_inp;
 
-    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
-        // Add a space in front of the first character to match OG llama tokenizer behavior
-        params.prompt.insert(0, 1, ' ');
+    // Add a space in front of the first character to match OG llama tokenizer behavior
+    params.prompt.insert(0, 1, ' ');
 
+    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
         embd_inp = ::llama_tokenize(ctx, params.prompt, true);
     } else {
         embd_inp = session_tokens;
     }
 
+    // Tokenize negative prompt
+    std::vector<llama_token> guidance_inp;
+    int guidance_offset = 0;
+    int original_prompt_len = 0;
+    if (ctx_guidance) {
+        params.cfg_negative_prompt.insert(0, 1, ' ');
+        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true);
+
+        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
+        original_prompt_len = original_inp.size();
+        guidance_offset = (int)guidance_inp.size() - original_prompt_len;
+    }
+
     const int n_ctx = llama_n_ctx(ctx);
 
     if ((int) embd_inp.size() > n_ctx - 4) {
@@ -258,6 +277,16 @@ int main(int argc, char ** argv) {
         for (int i = 0; i < (int) embd_inp.size(); i++) {
             fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
         }
+
+        if (ctx_guidance) {
+            fprintf(stderr, "\n");
+            fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
+            fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
+            for (int i = 0; i < (int) guidance_inp.size(); i++) {
+                fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]));
+            }
+        }
+
         if (params.n_keep > 0) {
         fprintf(stderr, "%s: static prompt based on n_keep: '", __func__);
             for (int i = 0; i < params.n_keep; i++) {
@@ -334,11 +363,13 @@ int main(int argc, char ** argv) {
     int n_remain           = params.n_predict;
     int n_consumed         = 0;
     int n_session_consumed = 0;
+    int n_past_guidance    = 0;
 
     // the first thing we will do is to output the prompt, so set color accordingly
     console_set_color(con_st, CONSOLE_COLOR_PROMPT);
 
     std::vector<llama_token> embd;
+    std::vector<llama_token> embd_guidance;
 
     // do one empty run to warm up the model
     {
@@ -367,11 +398,12 @@ int main(int argc, char ** argv) {
             // if we run out of context:
             // - take the n_keep first tokens from the original prompt (via n_past)
             // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
-            if (n_past + (int) embd.size() > n_ctx) {
+            if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
                 const int n_left = n_past - params.n_keep;
 
                 // always keep the first token - BOS
                 n_past = std::max(1, params.n_keep);
+                n_past_guidance = std::max(1, params.n_keep + guidance_offset);
 
                 // insert n_left/2 tokens at the start of embd from last_n_tokens
                 embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
@@ -412,6 +444,48 @@ int main(int argc, char ** argv) {
 
             // evaluate tokens in batches
             // embd is typically prepared beforehand to fit within a batch, but not always
+
+            if (ctx_guidance) {
+                int input_size = 0;
+                llama_token* input_buf = NULL;
+
+                if (n_past_guidance < (int) guidance_inp.size()) {
+                    // Guidance context should have the same data with these modifications:
+                    //
+                    // * Replace the initial prompt
+                    // * Shift everything by guidance_offset
+                    embd_guidance = guidance_inp;
+                    if (embd.begin() + original_prompt_len < embd.end()) {
+                        embd_guidance.insert(
+                            embd_guidance.end(),
+                            embd.begin() + original_prompt_len,
+                            embd.end()
+                        );
+                    }
+
+                    input_buf = embd_guidance.data();
+                    input_size = embd_guidance.size();
+                    //fprintf(stderr, "\n---------------------\n");
+                    //for (int i = 0; i < (int) embd_guidance.size(); i++) {
+                        //fprintf(stderr, "%s", llama_token_to_str(ctx, embd_guidance[i]));
+                    //}
+                    //fprintf(stderr, "\n---------------------\n");
+                } else {
+                    input_buf = embd.data();
+                    input_size = embd.size();
+                }
+
+                for (int i = 0; i < input_size; i += params.n_batch) {
+                    int n_eval = std::min(input_size - i, params.n_batch);
+                    if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) {
+                        fprintf(stderr, "%s : failed to eval\n", __func__);
+                        return 1;
+                    }
+
+                    n_past_guidance += n_eval;
+                }
+            }
+
             for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
                 int n_eval = (int) embd.size() - i;
                 if (n_eval > params.n_batch) {
@@ -431,6 +505,7 @@ int main(int argc, char ** argv) {
         }
 
         embd.clear();
+        embd_guidance.clear();
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
             // out of user input, sample next token
@@ -473,6 +548,10 @@ int main(int argc, char ** argv) {
 
                 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
+                if (ctx_guidance) {
+                    llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale, params.cfg_smooth_factor);
+                }
+
                 // Apply penalties
                 float nl_logit = logits[llama_token_nl()];
                 auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
@@ -668,6 +747,7 @@ int main(int argc, char ** argv) {
     }
 
     llama_print_timings(ctx);
+    if (ctx_guidance) { llama_free(ctx_guidance); }
     llama_free(ctx);
     llama_free_model(model);
 
index 08ec21ab631a81e7924599afc0a889906668795c..2d09d6ce7661913d4ff1b8671ba7b98061bcfc2a 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -2167,6 +2167,62 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
     }
 }
 
+static void llama_log_softmax(float * array, size_t size) {
+    float max_l = *std::max_element(array, array + size);
+    float sum = 0.f;
+    for (size_t i = 0; i < size; ++i) {
+        float p = expf(array[i] - max_l);
+        sum += p;
+        array[i] = p;
+    }
+
+    for (size_t i = 0; i < size; ++i) {
+        array[i] = logf(array[i] / sum);
+    }
+}
+
+void llama_sample_classifier_free_guidance(
+          struct llama_context * ctx,
+        llama_token_data_array * candidates,
+          struct llama_context * guidance_ctx,
+                         float   scale,
+                         float   smooth_factor) {
+    int64_t t_start_sample_us = t_start_sample_us = ggml_time_us();
+
+    assert(ctx);
+    auto n_vocab = llama_n_vocab(ctx);
+    assert(n_vocab == (int)candidates->size);
+    assert(!candidates->sorted);
+
+    std::vector<float> logits_base;
+    logits_base.reserve(candidates->size);
+    for (size_t i = 0; i < candidates->size; ++i) {
+        logits_base.push_back(candidates->data[i].logit);
+    }
+    llama_log_softmax(logits_base.data(), candidates->size);
+
+    float* logits_guidance = llama_get_logits(guidance_ctx);
+    llama_log_softmax(logits_guidance, n_vocab);
+
+    for (int i = 0; i < n_vocab; ++i) {
+        float logit_guidance = logits_guidance[i];
+        float logit_base = logits_base[i];
+        logits_guidance[i] = scale * (logit_base - logit_guidance) + logit_guidance;
+    }
+
+    llama_log_softmax(logits_guidance, n_vocab);
+
+    for (int i = 0; i < n_vocab; ++i) {
+        float logit_base = logits_base[i];
+        float logit_guidance = logits_guidance[i];
+
+        candidates->data[i].logit = smooth_factor * logit_guidance + (1.f - smooth_factor) * logit_base;
+    }
+
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
 
 llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
     assert(ctx);
diff --git a/llama.h b/llama.h
index 686463aa25af8b09b9de04c3079c85eca1ce0074..4596b1ed4dedfe074761c76da343c8b0ff0147fd 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -309,6 +309,18 @@ extern "C" {
     /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
     LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
 
+    /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
+    /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
+    /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
+    /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
+    /// @params smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only original logits.
+    LLAMA_API void llama_sample_classifier_free_guidance(
+              struct llama_context * ctx,
+            llama_token_data_array * candidates,
+              struct llama_context * guidance_ctx,
+                             float   scale,
+                             float   smooth_factor);
+
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
     LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);