]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
examples : do not assume BOS when shifting context (#5622)
authorJared Van Bortel <redacted>
Wed, 21 Feb 2024 15:33:54 +0000 (10:33 -0500)
committerGitHub <redacted>
Wed, 21 Feb 2024 15:33:54 +0000 (10:33 -0500)
examples/main/main.cpp
examples/server/server.cpp

index f5d2f48935eb6ed3b7d5be3cbf41ecff9bc18dea..7555dffe441f0aea3847b904f0a70ef417542942 100644 (file)
@@ -334,6 +334,8 @@ int main(int argc, char ** argv) {
     // number of tokens to keep when resetting context
     if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml) {
         params.n_keep = (int)embd_inp.size();
+    } else {
+        params.n_keep += add_bos; // always keep the BOS token
     }
 
     // prefix & suffix for instruct mode
@@ -383,8 +385,8 @@ int main(int argc, char ** argv) {
             }
         }
 
-        if (params.n_keep > 0) {
-        LOG_TEE("%s: static prompt based on n_keep: '", __func__);
+        if (params.n_keep > add_bos) {
+            LOG_TEE("%s: static prompt based on n_keep: '", __func__);
             for (int i = 0; i < params.n_keep; i++) {
                 LOG_TEE("%s", llama_token_to_piece(ctx, embd_inp[i]).c_str());
             }
@@ -540,14 +542,14 @@ int main(int argc, char ** argv) {
                         break;
                     }
 
-                    const int n_left    = n_past - params.n_keep - 1;
+                    const int n_left    = n_past - params.n_keep;
                     const int n_discard = n_left/2;
 
                     LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                             n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                    llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
-                    llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+                    llama_kv_cache_seq_rm   (ctx, 0, params.n_keep            , params.n_keep + n_discard);
+                    llama_kv_cache_seq_shift(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
 
                     n_past -= n_discard;
 
index 1c44795128fc155a86fc9c31fffaad2f5153a5c0..c84719a0d15d0739f450c21614f643010c02c333 100644 (file)
@@ -1487,14 +1487,15 @@ struct llama_server_context
                 if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx)
                 {
                     // Shift context
-                    const int n_left    = system_tokens.size() + slot.n_past - slot.params.n_keep - 1;
+                    const int n_keep    = slot.params.n_keep + add_bos_token;
+                    const int n_left    = system_tokens.size() + slot.n_past - n_keep;
                     const int n_discard = n_left / 2;
 
-                    LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard);
-                    llama_kv_cache_seq_rm   (ctx, slot.id, slot.params.n_keep + 1            , slot.params.n_keep + n_discard + 1);
-                    llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, system_tokens.size() + slot.n_past, -n_discard);
+                    LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, n_keep, n_left, n_discard);
+                    llama_kv_cache_seq_rm   (ctx, slot.id, n_keep            , n_keep + n_discard);
+                    llama_kv_cache_seq_shift(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
 
-                    for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
+                    for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++)
                     {
                         slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
                     }
@@ -1507,7 +1508,7 @@ struct llama_server_context
 
                     LOG_VERBOSE("context shift", {
                         { "n_ctx", n_ctx },
-                        { "n_keep", params.n_keep },
+                        { "n_keep", n_keep },
                         { "n_left", n_left },
                     });
                 }