llama_batch batch;
int n_batch;
+ std::vector<mtmd_bitmap> bitmaps;
+
// note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
// so here we don't need to keep track of chat history
common_chat_templates_ptr tmpls;
antiprompt_tokens.begin()
);
}
+
+ bool load_image(const std::string & fname) {
+ mtmd_bitmap bitmap;
+ if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
+ return false;
+ }
+ bitmaps.push_back(std::move(bitmap));
+ return true;
+ }
};
static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
llama_tokens generated_tokens;
for (int i = 0; i < n_predict; i++) {
if (i > n_predict || !g_is_generating || g_is_interrupted) {
- printf("\n");
+ LOG("\n");
break;
}
common_sampler_accept(smpl, token_id, true);
if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
- printf("\n");
+ LOG("\n");
break; // end of generation
}
- printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
+ LOG("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
fflush(stdout);
if (g_is_interrupted) {
- printf("\n");
+ LOG("\n");
break;
}
return 0;
}
-static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
- std::vector<mtmd_bitmap> bitmaps;
-
+static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) {
common_chat_templates_inputs tmpl_inputs;
tmpl_inputs.messages = {msg};
tmpl_inputs.add_generation_prompt = true;
auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
- for (auto & fname : images_fname) {
- mtmd_bitmap bitmap;
- if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
- LOG_ERR("Unable to load image %s\n", fname.c_str());
- return 2; // image not found
- }
- bitmaps.push_back(std::move(bitmap));
- }
-
mtmd_input_text text;
text.text = formatted_chat.prompt;
text.add_special = add_bos;
if (g_is_interrupted) return 0;
- int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
+ int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, ctx.bitmaps);
if (res != 0) {
LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
return 1;
}
+ ctx.bitmaps.clear();
+
if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
LOG_ERR("Unable to eval prompt\n");
return 1;
ctx.n_past += mtmd_helper_get_n_pos(chunks);
+ LOG("\n");
+
return 0;
}
}
mtmd_cli_context ctx(params);
- printf("%s: %s\n", __func__, params.model.path.c_str());
+ LOG("%s: loading model: %s\n", __func__, params.model.path.c_str());
bool is_single_turn = !params.prompt.empty() && !params.image.empty();
common_chat_msg msg;
msg.role = "user";
msg.content = params.prompt;
- if (eval_message(ctx, msg, params.image, true)) {
+ for (const auto & image : params.image) {
+ if (!ctx.load_image(image)) {
+ return 1; // error is already printed by libmtmd
+ }
+ }
+ if (eval_message(ctx, msg, true)) {
return 1;
}
if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
LOG("\n");
bool is_first_msg = true;
- std::vector<std::string> images_fname;
std::string content;
while (!g_is_interrupted) {
continue;
}
g_is_generating = true;
- if (line.find("/image") == 0) {
+ if (line == "/image" || line.find("/image ") == 0) {
+ if (line.size() < 8) {
+ LOG_ERR("ERR: Missing image filename\n");
+ continue;
+ }
std::string image = line.substr(7);
- images_fname.push_back(string_strip(image));
- content += "<__image__>";
+ if (ctx.load_image(image)) {
+ LOG("Image %s loaded\n", image.c_str());
+ content += "<__image__>";
+ }
+ // else, error is already printed by libmtmd
continue;
} else {
content += line;
common_chat_msg msg;
msg.role = "user";
msg.content = content;
- int ret = eval_message(ctx, msg, images_fname, is_first_msg);
- if (g_is_interrupted) break;
- if (ret == 2) {
- // non-fatal error
- images_fname.clear();
- content.clear();
- continue;
- }
+ int ret = eval_message(ctx, msg, is_first_msg);
if (ret) {
return 1;
}
+ if (g_is_interrupted) break;
if (generate_response(ctx, smpl, n_predict)) {
return 1;
}
- images_fname.clear();
content.clear();
is_first_msg = false;
}