return 0;
}
- std::string read_all(const std::string & filename){
- open(filename, "r");
- lock();
- if (!file) {
- printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
- return "";
- }
-
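+ // Reads the whole file, which must already be open, into a string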
+ std::string to_string() {
fseek(file, 0, SEEK_END);
- size_t size = ftell(file);
+ const size_t size = ftell(file);
fseek(file, 0, SEEK_SET);
-
std::string out;
out.resize(size);
- size_t read_size = fread(&out[0], 1, size, file);
+ const size_t read_size = fread(&out[0], 1, size, file);
if (read_size != size) {
- printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
- return "";
+ printe("Error reading file: %s", strerror(errno));
}
+
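+ // On a short read the error is logged but the buffer is still returned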
return out;
}
// Reads the contents of a chat template file into a string
static std::string read_chat_template_file(const std::string & chat_template_file) {
- if(chat_template_file.empty()){
- return "";
- }
-
File file;
- std::string chat_template = "";
- chat_template = file.read_all(chat_template_file);
- if(chat_template.empty()){
+ if (!file.open(chat_template_file, "r")) {
printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
return "";
}
- return chat_template;
+
+ return file.to_string();
+}
+
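+// Handles a single user turn: formats the prompt, generates a response and
+// records it. Returns 0 to continue the loop, 1 on error, and 2 once a
+// one-shot prompt supplied through opt.user has been answered.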
+static int process_user_message(const Opt & opt, const std::string & user_input, LlamaData & llama_data,
+ const common_chat_templates_ptr & chat_templates, int & prev_len,
+ const bool stdout_a_terminal) {
+ add_message("user", opt.user.empty() ? user_input : opt.user, llama_data);
+ int new_len;
+ if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) < 0) {
+ return 1;
+ }
+
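+ // Only the newly formatted tail of the conversation becomes the prompt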
+ std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
+ std::string response;
+ if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
+ return 1;
+ }
+
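+ // A prompt given on the command line is answered exactly once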
+ if (!opt.user.empty()) {
+ return 2;
+ }
+
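+ // Store the reply and move prev_len past the newly formatted history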
+ add_message("assistant", response, llama_data);
+ if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) < 0) {
+ return 1;
+ }
+
+ return 0;
}
// Main chat loop
-static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
+static int chat_loop(LlamaData & llama_data, const Opt & opt) {
int prev_len = 0;
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
-
- std::string chat_template = "";
- if(!chat_template_file.empty()){
- chat_template = read_chat_template_file(chat_template_file);
+ std::string chat_template;
+ if (!opt.chat_template_file.empty()) {
+ chat_template = read_chat_template_file(opt.chat_template_file);
}
- auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template.empty() ? nullptr : chat_template);
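+ // An empty string lets common_chat_templates_init fall back to the model's built-in template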
+ common_chat_templates_ptr chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template);
static const bool stdout_a_terminal = is_stdout_a_terminal();
while (true) {
// Get user input
std::string user_input;
- if (get_user_input(user_input, user) == 1) {
+ if (get_user_input(user_input, opt.user) == 1) {
return 0;
}
- add_message("user", user.empty() ? user_input : user, llama_data);
- int new_len;
- if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
- return 1;
- }
-
- std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
- std::string response;
- if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
+ const int ret = process_user_message(opt, user_input, llama_data, chat_templates, prev_len, stdout_a_terminal);
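+ // 1 signals a hard error; 2 means a one-shot prompt has been answered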
+ if (ret == 1) {
return 1;
- }
-
- if (!user.empty()) {
+ } else if (ret == 2) {
break;
}
-
- add_message("assistant", response, llama_data);
- if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
- return 1;
- }
}
return 0;
}
- if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
+ if (chat_loop(llama_data, opt)) {
return 1;
}