]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : add `n_discard` parameter (#6300)
authorJan Boon <redacted>
Tue, 26 Mar 2024 08:47:43 +0000 (16:47 +0800)
committerGitHub <redacted>
Tue, 26 Mar 2024 08:47:43 +0000 (10:47 +0200)
examples/server/server.cpp

index c4c545c3e0ac4e42c82ff9052d609b7e984152fd..526de596e34c0a15e05c9ab700f89b01b7c373a5 100644 (file)
@@ -99,6 +99,7 @@ struct slot_params {
 
     uint32_t seed      = -1; // RNG seed
     int32_t  n_keep    =  0; // number of tokens to keep from initial prompt
+    int32_t  n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
     int32_t  n_predict = -1; // new tokens to predict
 
     std::vector<std::string> antiprompt;
@@ -846,6 +847,7 @@ struct server_context {
         slot.sparams.mirostat_eta      = json_value(data, "mirostat_eta",      default_sparams.mirostat_eta);
         slot.sparams.penalize_nl       = json_value(data, "penalize_nl",       default_sparams.penalize_nl);
         slot.params.n_keep             = json_value(data, "n_keep",            slot.params.n_keep);
+        slot.params.n_discard          = json_value(data, "n_discard",         default_params.n_discard);
         slot.params.seed               = json_value(data, "seed",              default_params.seed);
         slot.sparams.n_probs           = json_value(data, "n_probs",           default_sparams.n_probs);
         slot.sparams.min_keep          = json_value(data, "min_keep",          default_sparams.min_keep);
@@ -1253,6 +1255,7 @@ struct server_context {
             {"stop",                      slot.params.antiprompt},
             {"n_predict",                 slot.params.n_predict}, // TODO: fix duplicate key n_predict
             {"n_keep",                    slot.params.n_keep},
+            {"n_discard",                 slot.params.n_discard},
             {"ignore_eos",                ignore_eos},
             {"stream",                    slot.params.stream},
             {"logit_bias",                slot.sparams.logit_bias},
@@ -1696,7 +1699,7 @@ struct server_context {
                     // Shift context
                     const int n_keep    = slot.params.n_keep + add_bos_token;
                     const int n_left    = (int) system_tokens.size() + slot.n_past - n_keep;
-                    const int n_discard = n_left / 2;
+                    const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
 
                     LOG_INFO("slot context shift", {
                         {"id_slot",         slot.id},