using json = nlohmann::json;
inline static json oaicompat_completion_params_parse(
- const json &body /* openai api json semantics */)
+ const json &body, /* openai api json semantics */
+ const std::string &chat_template)
{
json llama_params;
+ std::string formatted_prompt = chat_template == "chatml"
+ ? format_chatml(body["messages"]) // OpenAI 'messages' to chatml (with <|im_start|>,...)
+ : format_llama2(body["messages"]); // OpenAI 'messages' to llama2 (with [INST],...)
llama_params["__oaicompat"] = true;
// https://platform.openai.com/docs/api-reference/chat/create
llama_sampling_params default_sparams;
llama_params["model"] = json_value(body, "model", std::string("unknown"));
- llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
+ llama_params["prompt"] = formatted_prompt;
llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
llama_params["temperature"] = json_value(body, "temperature", 0.0);
llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);
std::string hostname = "127.0.0.1";
std::vector<std::string> api_keys;
std::string public_path = "examples/server/public";
+ std::string chat_template = "chatml";
int32_t port = 8080;
int32_t read_timeout = 600;
int32_t write_timeout = 600;
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
printf(" -gan N, --grp-attn-n N set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
printf(" -gaw N, --grp-attn-w N set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
+ printf(" --chat-template FORMAT_NAME");
+ printf(" set chat template, possible valus is: llama2, chatml (default %s)", sparams.chat_template.c_str());
printf("\n");
}
log_set_target(stdout);
LOG_INFO("logging to file is disabled.", {});
}
+ else if (arg == "--chat-template")
+ {
+ if (++i >= argc)
+ {
+ invalid_param = true;
+ break;
+ }
+ std::string value(argv[i]);
+ if (value != "chatml" && value != "llama2") {
+ fprintf(stderr, "error: chat template can be \"llama2\" or \"chatml\", but got: %s\n", value.c_str());
+ invalid_param = true;
+ break;
+ }
+ sparams.chat_template = value;
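+ // note: the chosen template is applied per request in oaicompat_completion_params_parse()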
+ }
else if (arg == "--override-kv")
{
if (++i >= argc) {
// TODO: add mount point without "/v1" prefix -- how?
- svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
+ svr.Post("/v1/chat/completions", [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
return;
}
- json data = oaicompat_completion_params_parse(json::parse(req.body));
+ json data = oaicompat_completion_params_parse(json::parse(req.body), sparams.chat_template);
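+ // req.body is expected to follow the OpenAI chat schema, e.g.:
+ //   {"messages": [{"role": "user", "content": "Hello"}], "temperature": 0.7}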
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
: default_value;
}
+inline std::string format_llama2(std::vector<json> messages)
+{
+ std::ostringstream output;
+ bool is_inside_turn = false;
+
+ for (auto it = messages.begin(); it != messages.end(); ++it) {
+ if (!is_inside_turn) {
+ output << "[INST] ";
+ }
+ std::string role = json_value(*it, "role", std::string("user"));
+ std::string content = json_value(*it, "content", std::string(""));
+ if (role == "system") {
+ output << "<<SYS>>\n" << content << "\n<<SYS>>\n\n";
+ is_inside_turn = true;
+ } else if (role == "user") {
+ output << content << " [/INST]";
+ is_inside_turn = true;
+ } else {
+ output << " " << content << " </s>";
+ is_inside_turn = false;
+ }
+ }
+
+ LOG_VERBOSE("format_llama2", {{"text", output.str()}});
+
+ return output.str();
+}
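+// Example: given the (hypothetical) conversation
+//   [{"role": "system", "content": "Be brief."}, {"role": "user", "content": "Hi"}]
+// format_llama2 produces:
+//   "[INST] <<SYS>>\nBe brief.\n<</SYS>>\n\nHi [/INST]"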
+
inline std::string format_chatml(std::vector<json> messages)
{
std::ostringstream chatml_msgs;
chatml_msgs << "<|im_start|>assistant" << '\n';
+ LOG_VERBOSE("format_chatml", {{"text", chatml_msgs.str()}});
+
return chatml_msgs.str();
}
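+
+// Example: the same conversation renders as one "<|im_start|>{role}\n{content}<|im_end|>\n"
+// block per message (loop elided in this hunk), followed by the trailing assistant header:
+//   "<|im_start|>system\nBe brief.<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n"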