]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add repetition penalty (#20)
authorbeiller <redacted>
Sun, 12 Mar 2023 09:27:42 +0000 (05:27 -0400)
committerGitHub <redacted>
Sun, 12 Mar 2023 09:27:42 +0000 (11:27 +0200)
* Adding repeat penalization

* Update utils.h

* Update utils.cpp

* Numeric fix

Should probably still scale by temp even if penalized

* Update comments, more proper application

I see that numbers can go negative so a fix from a referenced commit

* Minor formatting

---------

Co-authored-by: Georgi Gerganov <redacted>
main.cpp
utils.cpp
utils.h

index 2f47480698f1e932ecec27fb7236fdea62c13ca0..f02b5ddbde94dae213e28d6211c1305d7501449d 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -792,7 +792,7 @@ int main(int argc, char ** argv) {
         printf("%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
     }
     printf("\n");
-    printf("sampling parameters: temp = %f, top_k = %d, top_p = %f\n", params.temp, params.top_k, params.top_p);
+    printf("sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
     printf("\n\n");
 
     std::vector<gpt_vocab::id> embd;
@@ -801,6 +801,10 @@ int main(int argc, char ** argv) {
     size_t mem_per_token = 0;
     llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
 
+    int last_n_size = params.repeat_last_n;
+    std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
+    std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+
     for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
         // predict
         if (embd.size() > 0) {
@@ -821,6 +825,7 @@ int main(int argc, char ** argv) {
             // sample next token
             const float top_p = params.top_p;
             const float temp  = params.temp;
+            const float repeat_penalty = params.repeat_penalty;
 
             const int n_vocab = model.hparams.n_vocab;
 
@@ -829,7 +834,10 @@ int main(int argc, char ** argv) {
             {
                 const int64_t t_start_sample_us = ggml_time_us();
 
-                id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_p, temp, rng);
+                id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_p, temp, rng);
+
+                last_n_tokens.erase(last_n_tokens.begin());
+                last_n_tokens.push_back(id);
 
                 t_sample_us += ggml_time_us() - t_start_sample_us;
             }
@@ -840,6 +848,8 @@ int main(int argc, char ** argv) {
             // if here, it means we are still processing the input prompt
             for (int k = i; k < embd_inp.size(); k++) {
                 embd.push_back(embd_inp[k]);
+                last_n_tokens.erase(last_n_tokens.begin());
+                last_n_tokens.push_back(embd_inp[k]);
                 if (embd.size() > params.n_batch) {
                     break;
                 }
index abb34756ac026a60890bc2fe4299cd6df9ef3f83..49023bd7b8626ef9aa91ff17dde1417566fa44d8 100644 (file)
--- a/utils.cpp
+++ b/utils.cpp
@@ -23,6 +23,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.top_p = std::stof(argv[++i]);
         } else if (arg == "--temp") {
             params.temp = std::stof(argv[++i]);
+        } else if (arg == "--repeat_last_n") {
+            params.repeat_last_n = std::stoi(argv[++i]);
+        } else if (arg == "--repeat_penalty") {
+            params.repeat_penalty = std::stof(argv[++i]);
         } else if (arg == "-b" || arg == "--batch_size") {
             params.n_batch = std::stoi(argv[++i]);
         } else if (arg == "-m" || arg == "--model") {
@@ -52,6 +56,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d)\n", params.n_predict);
     fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k);
     fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", params.top_p);
+    fprintf(stderr, "  --repeat_last_n N     last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
+    fprintf(stderr, "  --repeat_penalty N    penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty);
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", params.temp);
     fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
@@ -372,6 +378,8 @@ gpt_vocab::id gpt_sample_top_k_top_p(
 gpt_vocab::id llama_sample_top_p(
         const gpt_vocab & vocab,
         const float * logits,
+        std::vector<gpt_vocab::id> & last_n_tokens,
+        double repeat_penalty,
         double top_p,
         double temp,
         std::mt19937 & rng) {
@@ -383,7 +391,18 @@ gpt_vocab::id llama_sample_top_p(
     {
         const double scale = 1.0/temp;
         for (int i = 0; i < n_logits; ++i) {
-            logits_id.push_back(std::make_pair(logits[i]*scale, i));
+            // repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
+            // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
+            if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
+                // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
+                if (logits[i] < 0.0) {
+                    logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i));
+                } else {
+                    logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i));
+                }                
+            } else {
+                logits_id.push_back(std::make_pair(logits[i]*scale, i));
+            }
         }
     }
 
diff --git a/utils.h b/utils.h
index bbe8fe823d01e9869b8b6ab6bbae43271156fbd3..e331904baa33baff456ec78ffabd6f88caec68cc 100644 (file)
--- a/utils.h
+++ b/utils.h
@@ -16,11 +16,13 @@ struct gpt_params {
     int32_t seed      = -1; // RNG seed
     int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
     int32_t n_predict = 128; // new tokens to predict
+    int32_t repeat_last_n = 64;  // last n tokens to penalize
 
     // sampling parameters
     int32_t top_k = 40; // unused
     float   top_p = 0.95f;
     float   temp  = 0.80f;
+    float   repeat_penalty  = 1.30f;
 
     int32_t n_batch = 8; // batch size for prompt processing
 
@@ -89,6 +91,8 @@ gpt_vocab::id gpt_sample_top_k_top_p(
 gpt_vocab::id llama_sample_top_p(
         const gpt_vocab & vocab,
         const float * logits,
+        std::vector<gpt_vocab::id> & last_n_tokens,
+        double repeat_penalty,
         double top_p,
         double temp,
         std::mt19937 & rng);