]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
simple-chat : fix BOS being added to each message (#11278)
authorGeorgi Gerganov <redacted>
Sun, 19 Jan 2025 16:12:09 +0000 (18:12 +0200)
committerGitHub <redacted>
Sun, 19 Jan 2025 16:12:09 +0000 (18:12 +0200)
examples/simple-chat/simple-chat.cpp

index e8eda9c22328817828c3abf113569e106e22f5d0..26422601d5a3eae1506034b788d64aa676258bb9 100644 (file)
@@ -95,11 +95,11 @@ int main(int argc, char ** argv) {
     llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
 
     // helper function to evaluate a prompt and generate a response
-    auto generate = [&](const std::string & prompt) {
+    auto generate = [&](const std::string & prompt, bool is_first) {
         std::string response;
 
         // tokenize the prompt
-        const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+        const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
         std::vector<llama_token> prompt_tokens(n_prompt_tokens);
         if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) {
             GGML_ABORT("failed to tokenize the prompt\n");
@@ -180,7 +180,7 @@ int main(int argc, char ** argv) {
 
         // generate a response
         printf("\033[33m");
-        std::string response = generate(prompt);
+        std::string response = generate(prompt, prev_len == 0);
         printf("\n\033[0m");
 
         // add the response to the messages