server : allow to specify custom prompt for penalty calculation (#3727)
author     Alexey Parfenov <redacted>
           Sat, 23 Dec 2023 09:31:49 +0000 (09:31 +0000)
committer  GitHub <redacted>
           Sat, 23 Dec 2023 09:31:49 +0000 (11:31 +0200)
common/sampling.cpp
common/sampling.h
examples/server/README.md
examples/server/server.cpp

diff --git a/common/sampling.cpp b/common/sampling.cpp
index 5b15204be88c48c6c9527782dac5201d8f452cd7..8e45909f1faf2042ff30efcb7be0e200f5f676ed 100644
@@ -203,12 +203,14 @@ static llama_token llama_sampling_sample_impl(
     }
 
     // apply penalties
-    if (!prev.empty()) {
+    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
+    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
+    if (penalty_tokens_used_size) {
         const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
 
         llama_sample_repetition_penalties(ctx_main, &cur_p,
-                prev.data() + prev.size() - penalty_last_n,
-                penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
+                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
 
         if (!penalize_nl) {
             for (size_t idx = 0; idx < cur_p.size; idx++) {
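For orientation between the hunks: the change above only swaps which token list feeds llama_sample_repetition_penalties and clamps it to the configured window. A minimal self-contained sketch of that selection logic (a hypothetical helper, not part of the library API) could look like this:

```cpp
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

using llama_token = int32_t;

// Hypothetical helper mirroring the hunk above: pick the token range that the
// repetition/frequency/presence penalties will inspect. The source is either
// the caller-supplied penalty prompt or the rolling history `prev`, truncated
// to the last `penalty_last_n` entries.
static std::pair<const llama_token *, int> penalty_window(
        const std::vector<llama_token> & penalty_prompt_tokens,
        bool                             use_penalty_prompt_tokens,
        const std::vector<llama_token> & prev,
        int                              penalty_last_n) {
    const auto & src  = use_penalty_prompt_tokens ? penalty_prompt_tokens : prev;
    const int    used = std::min((int) src.size(), std::max(penalty_last_n, 0));
    return { src.data() + src.size() - used, used };
}
```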
diff --git a/common/sampling.h b/common/sampling.h
index fdfa9eed1467b1fba8165e1f4034d5f87537bd22..f16ef97e34a10ca08b23acb64eec8ae5e68c443c 100644
@@ -36,6 +36,9 @@ typedef struct llama_sampling_params {
     float       cfg_scale     = 1.f; // how strong is guidance
 
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
+
+    std::vector<llama_token> penalty_prompt_tokens;
+    bool                     use_penalty_prompt_tokens = false;
 } llama_sampling_params;
 
 // general sampler context
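Because the new fields are plain members of llama_sampling_params, any caller of the common sampling code can opt in without going through the server. A hedged sketch, assuming the in-tree common.h / sampling.h headers and the existing llama_tokenize overload that takes a llama_model pointer:

```cpp
#include <string>
#include <vector>

#include "common.h"   // llama_tokenize overload taking a const llama_model *
#include "sampling.h" // llama_sampling_params

// Sketch only: build sampling params that penalize repetitions against a
// fixed text instead of the rolling generation history.
llama_sampling_params make_penalty_params(const llama_model * model, const std::string & penalty_text) {
    llama_sampling_params sparams;
    sparams.penalty_prompt_tokens     = llama_tokenize(model, penalty_text, /*add_bos=*/false);
    sparams.use_penalty_prompt_tokens = true;
    return sparams;
}
```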
diff --git a/examples/server/README.md b/examples/server/README.md
index 0751b9612f17a386bc4881127579e8a06b6e5f59..f1e586a1c103a23ba68d0410dc38c97eeaa109e9 100644
@@ -148,6 +148,8 @@ node index.js
 
     `frequency_penalty`: Repeat alpha frequency penalty (default: 0.0, 0.0 = disabled);
 
+    `penalty_prompt`: This will replace the `prompt` for the purpose of the penalty evaluation. Can be either `null`, a string or an array of numbers representing tokens (default: `null` = use the original `prompt`).
+
     `mirostat`: Enable Mirostat sampling, controlling perplexity during text generation (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0).
 
     `mirostat_tau`: Set the Mirostat target entropy, parameter tau (default: 5.0).
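To illustrate the parameter documented above, here is a hedged client-side sketch that builds a /completion request body with nlohmann::json (the JSON library the server example itself uses); the surrounding fields and the token ids in the array form are illustrative, not prescribed by the commit:

```cpp
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // String form: the text is tokenized server-side and replaces `prompt`
    // for penalty purposes only.
    json req_string = {
        {"prompt",         "Building a website can be done in 10 simple steps:"},
        {"n_predict",      64},
        {"repeat_penalty", 1.2},
        {"penalty_prompt", "10 simple steps"},
    };

    // Array form: pre-tokenized ids (hypothetical values); ids outside the
    // model vocabulary are dropped by the server.
    json req_tokens = req_string;
    req_tokens["penalty_prompt"] = json::array({13, 529, 881});

    // Either body would be POSTed to the server's /completion endpoint.
    return 0;
}
```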
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 04038530f94da5a0cde82e5979173fb66df1c066..72dfe452c2d7ae8f446b6cb28c6fc56937eb1357 100644
@@ -761,6 +761,42 @@ struct llama_server_context
             slot->prompt = "";
         }
 
+        slot->sparams.penalty_prompt_tokens.clear();
+        slot->sparams.use_penalty_prompt_tokens = false;
+        const auto &penalty_prompt = data.find("penalty_prompt");
+        if (penalty_prompt != data.end())
+        {
+            if (penalty_prompt->is_string())
+            {
+                const auto penalty_prompt_string = penalty_prompt->get<std::string>();
+                auto penalty_tokens = llama_tokenize(model, penalty_prompt_string, false);
+                slot->sparams.penalty_prompt_tokens.swap(penalty_tokens);
+                if (slot->params.n_predict > 0)
+                {
+                    slot->sparams.penalty_prompt_tokens.reserve(slot->sparams.penalty_prompt_tokens.size() + slot->params.n_predict);
+                }
+                slot->sparams.use_penalty_prompt_tokens = true;
+            }
+            else if (penalty_prompt->is_array())
+            {
+                const auto n_tokens = penalty_prompt->size();
+                slot->sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot->params.n_predict));
+                const int n_vocab = llama_n_vocab(model);
+                for (const auto &penalty_token : *penalty_prompt)
+                {
+                    if (penalty_token.is_number_integer())
+                    {
+                        const auto tok = penalty_token.get<llama_token>();
+                        if (tok >= 0 && tok < n_vocab)
+                        {
+                            slot->sparams.penalty_prompt_tokens.push_back(tok);
+                        }
+                    }
+                }
+                slot->sparams.use_penalty_prompt_tokens = true;
+            }
+        }
+
         slot->sparams.logit_bias.clear();
 
         if (json_value(data, "ignore_eos", false))
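Restated outside the request-parsing code, the array branch above amounts to the following sanitization step (a sketch, not the actual server function): keep only JSON integers that are valid ids for the loaded vocabulary.

```cpp
#include <cstdint>
#include <vector>

#include <nlohmann/json.hpp>

using json        = nlohmann::json;
using llama_token = int32_t;

// Sketch of the filtering done by the array branch above: non-integer
// entries and out-of-range token ids are silently skipped.
static std::vector<llama_token> sanitize_penalty_tokens(const json & arr, int n_vocab) {
    std::vector<llama_token> out;
    out.reserve(arr.size());
    for (const auto & el : arr) {
        if (el.is_number_integer()) {
            const auto tok = el.get<llama_token>();
            if (tok >= 0 && tok < n_vocab) {
                out.push_back(tok);
            }
        }
    }
    return out;
}
```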
@@ -992,6 +1028,12 @@ struct llama_server_context
         slot.generated_text += token_str;
         slot.has_next_token = true;
 
+        if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1)
+        {
+            // we can change penalty_prompt_tokens because it is always created from scratch each request
+            slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
+        }
+
         // check if there is incomplete UTF-8 character at the end
         bool incomplete = false;
         for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i)
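The push_back above means the penalty list is seeded from the custom penalty prompt and then grows with every generated token, so later sampling steps also penalize repetitions of the model's own fresh output. A toy sketch of that per-request behavior, with hypothetical types:

```cpp
#include <cstdint>
#include <vector>

using llama_token = int32_t;

// Sketch: per-request state. penalty_tokens starts as the tokenized
// `penalty_prompt` and grows by one entry per generated token.
struct request_state {
    std::vector<llama_token> penalty_tokens;
};

static void on_token_generated(request_state & st, llama_token tok) {
    if (tok != -1) {
        st.penalty_tokens.push_back(tok); // mirrors the push_back in the hunk above
    }
}
```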
@@ -1183,6 +1225,8 @@ struct llama_server_context
             {"repeat_penalty",    slot.sparams.penalty_repeat},
             {"presence_penalty",  slot.sparams.penalty_present},
             {"frequency_penalty", slot.sparams.penalty_freq},
+            {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
+            {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
             {"mirostat",          slot.sparams.mirostat},
             {"mirostat_tau",      slot.sparams.mirostat_tau},
             {"mirostat_eta",      slot.sparams.mirostat_eta},