]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
examples : fix add_special conditions (#11311)
authorGeorgi Gerganov <redacted>
Mon, 20 Jan 2025 14:36:08 +0000 (16:36 +0200)
committerGitHub <redacted>
Mon, 20 Jan 2025 14:36:08 +0000 (16:36 +0200)
examples/run/run.cpp
examples/simple-chat/simple-chat.cpp

index dd9ea79e86adbd45077cb3dd9978d08e9d5363fc..d04108e7183654c7ce4fbb732800b84ef907e3ca 100644 (file)
@@ -729,10 +729,12 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) {
 
 // Function to tokenize the prompt
 static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
-                           std::vector<llama_token> & prompt_tokens) {
-    const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+                           std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
+    const bool is_first = llama_get_kv_cache_used_cells(llama_data.context.get()) == 0;
+
+    const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
     prompt_tokens.resize(n_prompt_tokens);
-    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
+    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
                        true) < 0) {
         printe("failed to tokenize the prompt\n");
         return -1;
@@ -778,7 +780,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
     const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get());
 
     std::vector<llama_token> tokens;
-    if (tokenize_prompt(vocab, prompt, tokens) < 0) {
+    if (tokenize_prompt(vocab, prompt, tokens, llama_data) < 0) {
         return 1;
     }
 
index 26422601d5a3eae1506034b788d64aa676258bb9..212b3fd79a609d0fe00b9b64c0e44fb13d2f2dc8 100644 (file)
@@ -95,13 +95,15 @@ int main(int argc, char ** argv) {
     llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
 
     // helper function to evaluate a prompt and generate a response
-    auto generate = [&](const std::string & prompt, bool is_first) {
+    auto generate = [&](const std::string & prompt) {
         std::string response;
 
+        const bool is_first = llama_get_kv_cache_used_cells(ctx) == 0;
+
         // tokenize the prompt
         const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
         std::vector<llama_token> prompt_tokens(n_prompt_tokens);
-        if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) {
+        if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0) {
             GGML_ABORT("failed to tokenize the prompt\n");
         }
 
@@ -180,7 +182,7 @@ int main(int argc, char ** argv) {
 
         // generate a response
         printf("\033[33m");
-        std::string response = generate(prompt, prev_len == 0);
+        std::string response = generate(prompt);
         printf("\n\033[0m");
 
         // add the response to the messages