]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : fix system prompt cli (#5516)
authorRőczey Barnabás <redacted>
Fri, 16 Feb 2024 10:00:56 +0000 (11:00 +0100)
committerGitHub <redacted>
Fri, 16 Feb 2024 10:00:56 +0000 (12:00 +0200)
examples/server/server.cpp

index 912c750cc6223ca7c6e3e6b43e1f039d8233f605..0cb802ce851adbbbeff324042519dc4e8df68d3b 100644 (file)
@@ -436,10 +436,6 @@ struct llama_server_context
         default_generation_settings_for_props["seed"] = -1;
 
         batch = llama_batch_init(n_ctx, 0, params.n_parallel);
-
-        // empty system prompt
-        system_prompt = "";
-        system_tokens.clear();
     }
 
     std::vector<llama_token> tokenize(const json & json_prompt, bool add_bos) const
@@ -765,27 +761,30 @@ struct llama_server_context
     }
 
     void update_system_prompt() {
-        system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
+        kv_cache_clear();
+        system_tokens.clear();
 
-        llama_batch_clear(batch);
+        if (!system_prompt.empty()) {
+            system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
 
-        kv_cache_clear();
+            llama_batch_clear(batch);
 
-        for (int i = 0; i < (int) system_tokens.size(); ++i)
-        {
-            llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
-        }
+            for (int i = 0; i < (int)system_tokens.size(); ++i)
+            {
+                llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
+            }
 
-        if (llama_decode(ctx, batch) != 0)
-        {
-            LOG_TEE("%s: llama_decode() failed\n", __func__);
-            return;
-        }
+            if (llama_decode(ctx, batch) != 0)
+            {
+                LOG_TEE("%s: llama_decode() failed\n", __func__);
+                return;
+            }
 
-        // assign the system KV cache to all parallel sequences
-        for (int32_t i = 1; i < params.n_parallel; ++i)
-        {
-            llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
+            // assign the system KV cache to all parallel sequences
+            for (int32_t i = 1; i < params.n_parallel; ++i)
+            {
+                llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
+            }
         }
 
         LOG_TEE("system prompt updated\n");
@@ -807,10 +806,8 @@ struct llama_server_context
         name_user      = sys_props.value("anti_prompt", "");
         name_assistant = sys_props.value("assistant_name", "");
 
-        if (slots.size() > 0)
-        {
-            notify_system_prompt_changed();
-        }
+
+        notify_system_prompt_changed();
     }
 
     static size_t find_stopping_strings(const std::string &text, const size_t last_token_size,