git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : fix crash when prompt exceeds context size (#3996)
author Alexey Parfenov <redacted>
Sat, 11 Nov 2023 05:48:21 +0000 (05:48 +0000)
committer GitHub <redacted>
Sat, 11 Nov 2023 05:48:21 +0000 (23:48 -0600)
examples/server/server.cpp

index cbf36ad6752b671cb0740ad785a1d3fd91f2f0cc..46862a84b99da7093369257cd86f3be5eb1faaea 100644 (file)
@@ -1557,6 +1557,35 @@ struct llama_server_context
 
                     slot.num_prompt_tokens = prompt_tokens.size();
 
+                    if (slot.params.n_keep < 0)
+                    {
+                        slot.params.n_keep = slot.num_prompt_tokens;
+                    }
+                    slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
+
+                    // if input prompt is too big, truncate it
+                    if (slot.num_prompt_tokens >= slot.n_ctx)
+                    {
+                        const int n_left = slot.n_ctx - slot.params.n_keep;
+                        const int n_block_size = n_left / 2;
+                        const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+
+                        std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
+                        new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
+
+                        LOG_VERBOSE("input truncated", {
+                            {"n_ctx",  slot.n_ctx},
+                            {"n_keep", slot.params.n_keep},
+                            {"n_left", n_left},
+                            {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+                        });
+                        slot.truncated = true;
+                        prompt_tokens = new_tokens;
+
+                        slot.num_prompt_tokens = prompt_tokens.size();
+                        GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
+                    }
+
                     if (!slot.params.cache_prompt)
                     {
                         llama_sampling_reset(slot.ctx_sampling);
@@ -1566,35 +1595,6 @@ struct llama_server_context
                     }
                     else
                     {
-                        if (slot.params.n_keep < 0)
-                        {
-                            slot.params.n_keep = slot.num_prompt_tokens;
-                        }
-                        slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
-
-                        // if input prompt is too big, truncate it
-                        if (slot.num_prompt_tokens >= slot.n_ctx)
-                        {
-                            const int n_left = slot.n_ctx - slot.params.n_keep;
-                            const int n_block_size = n_left / 2;
-                            const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
-
-                            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
-                            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
-
-                            LOG_VERBOSE("input truncated", {
-                                                            {"n_ctx",  slot.n_ctx},
-                                                            {"n_keep", slot.params.n_keep},
-                                                            {"n_left", n_left},
-                                                            {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
-                                                        });
-                            slot.truncated = true;
-                            prompt_tokens = new_tokens;
-
-                            slot.num_prompt_tokens = prompt_tokens.size();
-                            GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
-                        }
-
                         // push the prompt into the sampling context (do not apply grammar)
                         for (auto &token : prompt_tokens)
                         {
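
For reference, a minimal standalone sketch of the block-wise truncation that the added hunk performs. The slot fields, the n_ctx - 4 cap, and the erased_blocks arithmetic are taken from the diff above; the concrete numbers (n_ctx = 16, n_keep = 4, a 24-token prompt) are hypothetical and chosen only to illustrate how the head is kept, whole blocks are dropped, and the tail is preserved.

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical values, only to exercise the truncation path
    const int n_ctx  = 16;   // stands in for slot.n_ctx
    int       n_keep = 4;    // stands in for slot.params.n_keep
    std::vector<int> prompt_tokens(24);
    for (int i = 0; i < (int) prompt_tokens.size(); ++i) prompt_tokens[i] = i;

    // n_keep < 0 means "keep the whole prompt", capped at n_ctx - 4
    if (n_keep < 0) n_keep = (int) prompt_tokens.size();
    n_keep = std::min(n_ctx - 4, n_keep);

    if ((int) prompt_tokens.size() >= n_ctx) {
        const int n_left        = n_ctx - n_keep;      // 12
        const int n_block_size  = n_left / 2;          // 6
        const int erased_blocks = ((int) prompt_tokens.size() - n_keep - n_block_size) / n_block_size; // (24-4-6)/6 = 2

        // keep the first n_keep tokens, drop erased_blocks * n_block_size
        // tokens after them, keep the remaining tail
        std::vector<int> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + n_keep);
        new_tokens.insert(new_tokens.end(),
                          prompt_tokens.begin() + n_keep + erased_blocks * n_block_size,
                          prompt_tokens.end());

        prompt_tokens = new_tokens;
        assert((int) prompt_tokens.size() < n_ctx);    // 4 + 8 = 12 < 16
    }

    for (int t : prompt_tokens) printf("%d ", t);
    printf("\n");  // prints: 0 1 2 3 16 17 18 19 20 21 22 23
    return 0;
}

The commit moves exactly this logic out of the cache_prompt-only branch so it runs for every request, which is why an oversized prompt no longer reaches the GGML_ASSERT-guarded decode path.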