]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
starcoder : add repeat penalty (#311)
authorthe-crypt-keeper <redacted>
Sun, 2 Jul 2023 14:52:52 +0000 (10:52 -0400)
committerGitHub <redacted>
Sun, 2 Jul 2023 14:52:52 +0000 (17:52 +0300)
* implement repeat penalty processing for starcoder

* show effective parameters at starcoder startup

---------

Co-authored-by: Mike Ravkine <redacted>
examples/common.cpp
examples/common.h
examples/starcoder/main.cpp

index fe00278c2584e1c49d243420d3f69bab8c2262ca..7b01089b0e04105ed0975d3a5c7ab0290305b0b1 100644 (file)
@@ -39,6 +39,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.top_p = std::stof(argv[++i]);
         } else if (arg == "--temp") {
             params.temp = std::stof(argv[++i]);
+        } else if (arg == "--repeat-last-n") {
+            params.repeat_last_n = std::stof(argv[++i]);
+        } else if (arg == "--repeat-penalty") {
+            params.repeat_penalty = std::stof(argv[++i]);            
         } else if (arg == "-b" || arg == "--batch_size") {
             params.n_batch = std::stoi(argv[++i]);
         } else if (arg == "-m" || arg == "--model") {
@@ -90,6 +94,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k);
     fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", params.top_p);
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", params.temp);
+    fprintf(stderr, "  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
+    fprintf(stderr, "  --repeat-penalty N    penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty);    
     fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
index 7e9b867d3d6f1f8af3a00140a153aa321c68d72f..12b2b339d670533476d5c3e8acfab140db04c0bb 100644 (file)
@@ -23,6 +23,8 @@ struct gpt_params {
     int32_t top_k = 40;
     float   top_p = 0.9f;
     float   temp  = 0.9f;
+    int32_t repeat_last_n  = 64;
+    float   repeat_penalty = 1.00f;
 
     int32_t n_batch = 8; // batch size for prompt processing
 
index 2016f8974c8dc8bf5979cd5d69e2e8a5687c0cd7..5c6065980faba76f4dde08796c82488b8a4adda1 100644 (file)
@@ -782,6 +782,16 @@ int main(int argc, char ** argv) {
         test_gpt_tokenizer(vocab, params.token_test);
     }
 
+    if (params.repeat_last_n == -1) {
+        params.repeat_last_n = model.hparams.n_ctx;
+    }
+    printf("\n");
+    printf("%s: temp           = %.3f\n", __func__, params.temp);
+    printf("%s: top_k          = %d\n",   __func__, params.top_k);
+    printf("%s: top_p          = %.3f\n", __func__, params.top_p);
+    printf("%s: repeat_last_n  = %d\n",   __func__, params.repeat_last_n);
+    printf("%s: repeat_penalty = %.3f\n", __func__, params.repeat_penalty);
+    
     int n_past = 0;
 
     int64_t t_sample_us  = 0;
@@ -789,6 +799,9 @@ int main(int argc, char ** argv) {
 
     std::vector<float> logits;
 
+    std::vector<int32_t> last_n_tokens(model.hparams.n_ctx);
+    std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+    
     // tokenize the prompt
     std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
 
@@ -847,17 +860,23 @@ int main(int argc, char ** argv) {
             {
                 const int64_t t_start_sample_us = ggml_time_us();
 
-                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
-
+                id = gpt_sample_top_k_top_p_repeat(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens.data(), last_n_tokens.size(), top_k, top_p, temp, params.repeat_last_n, params.repeat_penalty, rng);
                 t_sample_us += ggml_time_us() - t_start_sample_us;
             }
 
             // add it to the context
             embd.push_back(id);
+
+            last_n_tokens.erase(last_n_tokens.begin());
+            last_n_tokens.push_back(id);            
         } else {
             // if here, it means we are still processing the input prompt
             for (int k = i; k < embd_inp.size(); k++) {
                 embd.push_back(embd_inp[k]);
+
+                last_n_tokens.erase(last_n_tokens.begin());
+                last_n_tokens.push_back(embd_inp[k]); 
+
                 if (embd.size() >= params.n_batch) {
                     break;
                 }