options.push_back({ "main", " --cfg-negative-prompt-file FNAME",
"negative prompt file to use for guidance" });
options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale });
-
+ options.push_back({ "main", " --chat-template JINJA_TEMPLATE",
+ "set custom jinja chat template (default: template taken from model's metadata)\n"
+ "only commonly used templates are accepted:\n"
+ "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
options.push_back({ "grammar" });
options.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() });
options.push_back({ "*", " --grammar-file FNAME", "file to read grammar from" });
return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
}
+//
+// Chat template utils
+//
+
bool llama_chat_verify_template(const std::string & tmpl) {
llama_chat_message chat[] = {{"user", "test"}};
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
return res >= 0;
}
+std::string llama_chat_apply_template(const struct llama_model * model,
+ const std::string & tmpl,
+ const std::vector<llama_chat_msg> & msgs,
+ bool add_ass) {
+ int alloc_size = 0;
+ std::vector<llama_chat_message> chat;
+ for (auto & msg : msgs) {
+ chat.push_back({msg.role.c_str(), msg.content.c_str()});
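+ // reserve roughly 25% headroom over the raw text length for the extra tokens added by the template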
+ alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
+ }
+
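+ // an empty template means: fall back to the chat template stored in the model's metadata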
+ const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+ std::vector<char> buf(alloc_size);
+
+ // run the first time to get the total output length
+ int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+
+ // if it turns out that our buffer is too small, we resize it
+ if ((size_t) res > buf.size()) {
+ buf.resize(res);
+ res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+ }
+
+ std::string formatted_chat(buf.data(), res);
+ return formatted_chat;
+}
+
+std::string llama_chat_format_single(const struct llama_model * model,
+ const std::string & tmpl,
+ const std::vector<llama_chat_msg> & past_msg,
+ const llama_chat_msg & new_msg,
+ bool add_ass) {
+ auto fmt_past_msg = llama_chat_apply_template(model, tmpl, past_msg, false);
+ std::vector<llama_chat_msg> chat_new(past_msg);
+ chat_new.push_back(new_msg);
+ auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
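+ // the formatted new message is whatever the full formatting added beyond the previously formatted history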
+ auto formatted = fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
+ return formatted;
+}
+
+std::string llama_chat_format_example(const struct llama_model * model,
+ const std::string & tmpl) {
+ std::vector<llama_chat_msg> msgs = {
+ {"system", "You are a helpful assistant"},
+ {"user", "Hello"},
+ {"assistant", "Hi there"},
+ {"user", "How are you?"},
+ };
+ return llama_chat_apply_template(model, tmpl, msgs, true);
+}
+
//
// KV cache utils
//
// Chat template utils
//
+// same as llama_chat_message, but uses std::string
+struct llama_chat_msg {
+ std::string role;
+ std::string content;
+};
+
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool llama_chat_verify_template(const std::string & tmpl);
+// C++ wrapper for llama_chat_apply_template
+std::string llama_chat_apply_template(const struct llama_model * model,
+ const std::string & tmpl,
+ const std::vector<llama_chat_msg> & chat,
+ bool add_ass);
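+// example (assuming a loaded model; an empty template string selects the model's own template):
+//   std::vector<llama_chat_msg> msgs = {{"system", "You are a helpful assistant"}, {"user", "Hello"}};
+//   std::string prompt = llama_chat_apply_template(model, "", msgs, true);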
+
+// Format a single message, taking into account the position of that message in the chat history
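+// Only the portion of the formatted output that corresponds to new_msg is returned.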
+std::string llama_chat_format_single(const struct llama_model * model,
+ const std::string & tmpl,
+ const std::vector<llama_chat_msg> & past_msg,
+ const llama_chat_msg & new_msg,
+ bool add_ass);
+
+// Returns an example of a formatted chat
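+// (a fixed system/user/assistant/user conversation, useful for showing which chat template is in use)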
+std::string llama_chat_format_example(const struct llama_model * model,
+ const std::string & tmpl);
+
//
// KV cache utils
//
static std::vector<llama_token> * g_output_tokens;
static bool is_interacting = false;
-static bool file_exists(const std::string &path) {
+static bool file_exists(const std::string & path) {
std::ifstream f(path.c_str());
return f.good();
}
-static bool file_is_empty(const std::string &path) {
+static bool file_is_empty(const std::string & path) {
std::ifstream f;
f.exceptions(std::ifstream::failbit | std::ifstream::badbit);
f.open(path.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
LOG_TEE("%s", text);
}
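+// format the new message with the chat template, append it to chat_msgs and return the formatted text;
+// the assistant prefix (add_ass) is only added when the new message comes from the user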
+static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
+ llama_chat_msg new_msg{role, content};
+ auto formatted = llama_chat_format_single(
+ model, g_params->chat_template, chat_msgs, new_msg, role == "user");
+ chat_msgs.push_back({role, content});
+ return formatted;
+}
+
int main(int argc, char ** argv) {
gpt_params params;
g_params = &params;
llama_model * model;
llama_context * ctx;
llama_context * ctx_guidance = NULL;
+ std::vector<llama_chat_msg> chat_msgs; // chat history used when applying the chat template in conversation mode
g_model = &model;
g_ctx = &ctx;
__func__, n_ctx_train, n_ctx);
}
+ LOG_TEE("%s: chat template example: %s\n", __func__, llama_chat_format_example(model, params.chat_template).c_str());
+
// print system information
{
LOG_TEE("\n");
std::vector<llama_token> embd_inp;
- if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
- LOG("tokenize the prompt\n");
- embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
- } else {
- LOG("use session tokens\n");
- embd_inp = session_tokens;
- }
+ {
+ auto prompt = params.conversation
+ ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
+ : params.prompt;
+ if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
+ LOG("tokenize the prompt\n");
+ embd_inp = ::llama_tokenize(ctx, prompt, true, true);
+ } else {
+ LOG("use session tokens\n");
+ embd_inp = session_tokens;
+ }
- LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
- LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+ LOG("prompt: \"%s\"\n", log_tostr(prompt));
+ LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+ }
// Should not run without any tokens
if (embd_inp.empty()) {
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
std::ostringstream output_ss; g_output_ss = &output_ss;
+ std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode
// the first thing we will do is to output the prompt, so set color accordingly
console::set_display(console::prompt);
is_antiprompt = true;
}
+ chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
is_interacting = true;
printf("\n");
}
}
+ // if the current token is not EOG, add it to the current assistant message
+ if (params.conversation) {
+ auto id = llama_sampling_last(ctx_sampling);
+ assistant_ss << llama_token_to_piece(ctx, id, false);
+ }
+
if (n_past > 0 && is_interacting) {
LOG("waiting for user input\n");
string_process_escapes(buffer);
}
+ std::string user_inp = params.conversation
+ ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
+ : std::move(buffer);
+ // TODO: one inconvenience of the current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
- const auto line_inp = ::llama_tokenize(ctx, buffer, false, false);
+ const auto line_inp = ::llama_tokenize(ctx, user_inp, false, params.conversation);
const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
output_ss << llama_token_to_piece(ctx, token);
}
+ // reset assistant message
+ assistant_ss.str("");
+
n_remain -= line_inp.size();
LOG("n_remain: %d\n", n_remain);
} else {
// print sample chat example to make it clear which template is used
{
- json chat;
- chat.push_back({{"role", "system"}, {"content", "You are a helpful assistant"}});
- chat.push_back({{"role", "user"}, {"content", "Hello"}});
- chat.push_back({{"role", "assistant"}, {"content", "Hi there"}});
- chat.push_back({{"role", "user"}, {"content", "How are you?"}});
-
- const std::string chat_example = format_chat(ctx_server.model, params.chat_template, chat);
-
LOG_INFO("chat template", {
- {"chat_example", chat_example},
- {"built_in", params.chat_template.empty()},
+ {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
+ {"built_in", params.chat_template.empty()},
});
}
// Format given chat. If tmpl is empty, we take the template from model metadata
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
- size_t alloc_size = 0;
- // vector holding all allocated string to be passed to llama_chat_apply_template
- std::vector<std::string> str(messages.size() * 2);
- std::vector<llama_chat_message> chat(messages.size());
+ std::vector<llama_chat_msg> chat;
for (size_t i = 0; i < messages.size(); ++i) {
const auto & curr_msg = messages[i];
- str[i*2 + 0] = json_value(curr_msg, "role", std::string(""));
- str[i*2 + 1] = json_value(curr_msg, "content", std::string(""));
- alloc_size += str[i*2 + 1].length();
- chat[i].role = str[i*2 + 0].c_str();
- chat[i].content = str[i*2 + 1].c_str();
+ std::string role = json_value(curr_msg, "role", std::string(""));
+ std::string content = json_value(curr_msg, "content", std::string(""));
+ chat.push_back({role, content});
}
- const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
- std::vector<char> buf(alloc_size * 2);
-
- // run the first time to get the total output length
- int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
-
- // if it turns out that our buffer is too small, we resize it
- if ((size_t) res > buf.size()) {
- buf.resize(res);
- res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
- }
-
- const std::string formatted_chat(buf.data(), res);
-
+ auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true);
LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
-
return formatted_chat;
}
if (add_ass) {
ss << "<|im_start|>assistant\n";
}
- } else if (tmpl == "llama2" || tmpl.find("[INST]") != std::string::npos) {
+ } else if (tmpl == "llama2" || tmpl == "mistral" || tmpl.find("[INST]") != std::string::npos) {
// llama2 template and its variants
// [variant] support system message
- bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos;
+ bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos || tmpl == "mistral";
// [variant] space before + after response
bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
// [variant] add BOS inside history
#include <cassert>
#include "llama.h"
+#include "common.h"
int main(void) {
llama_chat_message conversation[] = {
std::cout << output << "\n-------------------------\n";
assert(output == expected);
}
+
+ // test llama_chat_format_single
+ std::cout << "\n\n=== llama_chat_format_single ===\n\n";
+ std::vector<llama_chat_msg> chat2;
+ chat2.push_back({"system", "You are a helpful assistant"});
+ chat2.push_back({"user", "Hello"});
+ chat2.push_back({"assistant", "I am assistant"});
+ llama_chat_msg new_msg{"user", "How are you"};
+
+ auto fmt_single = [&](std::string tmpl) {
+ auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true);
+ std::cout << "fmt_single(" << tmpl << ")\n" << output << "\n-------------------------\n";
+ return output;
+ };
+ assert(fmt_single("chatml") == "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
+ assert(fmt_single("llama2") == "[INST] How are you [/INST]");
+ assert(fmt_single("gemma") == "<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
+ assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
+
return 0;
}