#include "log.h"
#include "sampling.h"
#include "speculative.h"
+#include "mtmd.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
int id_target = -1;
// used by SERVER_TASK_TYPE_INFERENCE
- slot_params params;
- llama_tokens prompt_tokens;
+ slot_params params;
+ server_tokens prompt_tokens;
int id_selected_slot = -1;
// used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
llama_context * ctx = nullptr;
llama_context * ctx_dft = nullptr;
+ // multimodal
+ mtmd_context * mctx = nullptr;
+
common_speculative * spec = nullptr;
std::vector<common_adapter_lora_info> lora;
int32_t n_prompt_tokens_processed = 0;
// input prompt tokens
- llama_tokens prompt_tokens;
+ server_tokens prompt_tokens;
size_t last_nl_pos = 0;
std::string generated_text;
llama_tokens generated_tokens;
- llama_tokens cache_tokens;
+ server_tokens cache_tokens;
std::vector<completion_token_output> generated_token_probs;
{"is_processing", is_processing()},
{"non_causal", is_non_causal()},
{"params", params.to_json()},
- {"prompt", common_detokenize(ctx, prompt_tokens)},
+ {"prompt", prompt_tokens.detokenize(ctx, true)},
{"next_token",
{
{"has_next_token", has_next_token},
llama_model * model = nullptr;
llama_context * ctx = nullptr;
+ // multimodal
+ mtmd_context * mctx = nullptr;
+
const llama_vocab * vocab = nullptr;
llama_model * model_dft = nullptr;
llama_context_params cparams_dft;
- llama_batch batch = {};
+ llama_batch batch;
bool clean_kv_cache = true;
bool add_bos_token = true;
common_chat_templates_ptr chat_templates;
~server_context() {
+ mtmd_free(mctx);
+
// Clear any sampling context
for (server_slot & slot : slots) {
common_sampler_free(slot.smpl);
chat_templates = common_chat_templates_init(model, "chatml");
}
+ std::string & mmproj_path = params_base.mmproj.path;
+ if (!mmproj_path.empty()) {
+ mtmd_context_params mparams = mtmd_context_params_default();
+ mparams.use_gpu = params_base.mmproj_use_gpu;
+ mparams.print_timings = false;
+ mparams.n_threads = params_base.cpuparams.n_threads;
+ mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+ mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
+ if (mctx == nullptr) {
+ SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
+ return false;
+ }
+ SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str());
+
+ if (params_base.ctx_shift) {
+ params_base.ctx_shift = false;
+ SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled");
+ }
+
+ if (params_base.n_cache_reuse) {
+ params_base.n_cache_reuse = 0;
+ SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
+ }
+
+ if (!params_base.speculative.model.path.empty()) {
+            SRV_ERR("%s\n", "speculative decoding is not supported by multimodal");
+ return false;
+ }
+ }
+
return true;
}
slot.ctx = ctx;
slot.n_ctx = n_ctx_slot;
slot.n_predict = params_base.n_predict;
+ slot.mctx = mctx;
+ slot.cache_tokens.has_mtmd = mctx != nullptr;
if (model_dft) {
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
// note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
{
const int32_t n_batch = llama_n_batch(ctx);
-
- // only a single seq_id per token is needed
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
}
}
            // length of the common prefix between the current slot's cached prompt and the input prompt
- int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens);
+ int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens);
            // fraction of the common prefix length compared to the current slot's prompt length
float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());
return ret;
}
- bool can_be_detokenized(const struct llama_context * ctx, const std::vector<llama_token> & tokens) {
- const llama_model * model = llama_get_model(ctx);
- const llama_vocab * vocab = llama_model_get_vocab(model);
- const int32_t n_vocab = llama_vocab_n_tokens(vocab);
- for (const auto & token : tokens) {
- if (token < 0 || token >= n_vocab) {
- return false;
- }
- }
- return true;
- }
-
bool launch_slot_with_task(server_slot & slot, server_task && task) {
slot.reset();
slot.id_task = task.id;
slot.lora = slot.params.lora;
}
- bool can_detokenize = can_be_detokenized(ctx, slot.prompt_tokens);
- if (!can_detokenize) {
+ if (!slot.prompt_tokens.validate(ctx)) {
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
return false;
}
queue_results.send(std::move(res));
}
+ // if multimodal is enabled, send an error and return false
+ bool ensure_no_mtmd(const int id_task) {
+ if (mctx) {
+ send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED);
+ return false;
+ }
+ return true;
+ }
+
void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
auto res = std::make_unique<server_task_result_cmpl_partial>();
res->content = std::move(slot.generated_text);
res->tokens = std::move(slot.generated_tokens);
res->timings = slot.get_timings();
- res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
+ res->prompt = slot.prompt_tokens.detokenize(ctx, true);
res->response_fields = std::move(slot.params.response_fields);
res->truncated = slot.truncated;
} break;
case SERVER_TASK_TYPE_SLOT_SAVE:
{
+ if (!ensure_no_mtmd(task.id)) {
+ break;
+ }
+
int id_slot = task.slot_action.slot_id;
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
std::string filename = task.slot_action.filename;
std::string filepath = task.slot_action.filepath;
- const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
+ const llama_tokens & tokens = slot->cache_tokens.get_text_tokens();
+ const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count);
const int64_t t_end = ggml_time_us();
const double t_save_ms = (t_end - t_start) / 1000.0;
} break;
case SERVER_TASK_TYPE_SLOT_RESTORE:
{
+ if (!ensure_no_mtmd(task.id)) break;
int id_slot = task.slot_action.slot_id;
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
std::string filename = task.slot_action.filename;
std::string filepath = task.slot_action.filepath;
- slot->cache_tokens.resize(slot->n_ctx);
+ llama_tokens tokens;
+ tokens.resize(slot->n_ctx);
size_t token_count = 0;
- size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
+ size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
if (nread == 0) {
- slot->cache_tokens.resize(0);
+                slot->cache_tokens.clear(); // the KV cache may already have been invalidated
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
break;
}
- slot->cache_tokens.resize(token_count);
+ tokens.resize(token_count);
+ slot->cache_tokens.clear();
+ slot->cache_tokens.insert(tokens);
const int64_t t_end = ggml_time_us();
const double t_restore_ms = (t_end - t_start) / 1000.0;
} break;
case SERVER_TASK_TYPE_SLOT_ERASE:
{
+ if (!ensure_no_mtmd(task.id)) break;
int id_slot = task.slot_action.slot_id;
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
res->id = task.id;
queue_results.send(std::move(res));
} break;
+
}
}
continue;
}
+ if (mctx) {
+ // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded
+                // we don't support ctx_shift because an image chunk may contain multiple tokens
+ GGML_ABORT("not supported by multimodal");
+ }
+
// Shift context
const int n_keep = slot.params.n_keep + add_bos_token;
const int n_left = slot.n_past - n_keep;
llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
if (slot.params.cache_prompt) {
- for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
- slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
+ llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
+ for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
+ new_tokens[i - n_discard] = new_tokens[i];
}
- slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+ new_tokens.resize(slot.cache_tokens.size() - n_discard);
+ slot.cache_tokens.clear();
+ slot.cache_tokens.insert(new_tokens);
}
slot.n_past -= n_discard;
SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
// print prompt tokens (for debugging)
- if (1) {
+ /*if (1) {
// first 16 tokens (avoid flooding logs)
for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
for (int i = 0; i < (int) prompt_tokens.size(); i++) {
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
}
- }
+ }*/
// empty prompt passed -> release the slot and send empty response
if (prompt_tokens.empty()) {
// if input prompt is too big, truncate it
if (slot.n_prompt_tokens >= slot.n_ctx) {
+ if (mctx) {
+ // we should never reach this
+ GGML_ABORT("not supported by multimodal");
+ }
const int n_left = slot.n_ctx - slot.params.n_keep;
const int n_block_size = n_left / 2;
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+ const llama_tokens & curr_tokens = slot.prompt_tokens.get_text_tokens();
llama_tokens new_tokens(
- prompt_tokens.begin(),
- prompt_tokens.begin() + slot.params.n_keep);
+ curr_tokens.begin(),
+ curr_tokens.begin() + slot.params.n_keep);
new_tokens.insert(
new_tokens.end(),
- prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
- prompt_tokens.end());
+ curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
+ curr_tokens.end());
- prompt_tokens = std::move(new_tokens);
+ prompt_tokens.clear();
+ prompt_tokens.insert(new_tokens);
slot.truncated = true;
slot.n_prompt_tokens = prompt_tokens.size();
if (slot.params.cache_prompt) {
// reuse any previously computed tokens that are common with the new prompt
- slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);
+ slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);
// reuse chunks from the cached prompt by shifting their KV cache in the new position
if (params_base.n_cache_reuse > 0) {
size_t head_c = slot.n_past; // cache
size_t head_p = slot.n_past; // current prompt
+ if (mctx) {
+ // we should never reach this
+ GGML_ABORT("not supported by multimodal");
+ }
+
SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
while (head_c < slot.cache_tokens.size() &&
llama_kv_self_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift);
for (size_t i = 0; i < n_match; i++) {
- slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+ slot.cache_tokens.set_token(head_p + i, slot.cache_tokens[head_c + i]);
slot.n_past++;
}
// remove the non-common part from the cache
slot.cache_tokens.resize(slot.n_past);
+ // check if we should process the image
+ if (slot.n_past < slot.n_prompt_tokens
+ && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
+ // process the image
+ int32_t new_n_past;
+ int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
+ int32_t n_pos = new_n_past - slot.n_past;
+
+ if (res != 0) {
+ SLT_ERR(slot, "failed to process image, res = %d\n", res);
+ slot.release();
+ send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
+ continue;
+ }
+
+ if (slot.params.cache_prompt) {
+ const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past);
+ slot.cache_tokens.push_back(chunk.get()); // copy
+ }
+
+ slot.n_past += n_pos;
+ slot.n_prompt_tokens_processed += n_pos;
+ }
+
// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+ // get next token to process
+ llama_token cur_tok = slot.prompt_tokens[slot.n_past];
+ if (cur_tok == LLAMA_TOKEN_NULL) {
+                        break; // end of text chunk, the image chunk is processed separately
+ }
+
// without pooling, we want to output the embeddings for all the tokens in the batch
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
- common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
-
+ common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
if (slot.params.cache_prompt) {
- slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
+ slot.cache_tokens.push_back(cur_tok);
}
slot.n_prompt_tokens_processed++;
slot.n_past++;
}
+ // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
+
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
// entire prompt has been processed
slot.state = SLOT_STATE_DONE_PROMPT;
GGML_ASSERT(batch.n_tokens > 0);
+ GGML_ASSERT((size_t) slot.n_prompt_tokens == slot.prompt_tokens.size());
common_sampler_reset(slot.smpl);
// Process all prompt tokens through sampler system
for (int i = 0; i < slot.n_prompt_tokens; ++i) {
- common_sampler_accept(slot.smpl, prompt_tokens[i], false);
+ llama_token id = slot.prompt_tokens[i];
+ if (id != LLAMA_TOKEN_NULL) {
+ common_sampler_accept(slot.smpl, id, false);
+ }
}
// extract the logits only for the last token
continue;
}
+ if (mctx) {
+ // we should never reach this, as speculative is automatically disabled if mmproj is loaded
+ GGML_ABORT("not supported by multimodal");
+ }
+
// determine the max draft that fits the current slot state
int n_draft_max = slot.params.speculative.n_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
params_spec.p_min = slot.params.speculative.p_min;
- llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
+ const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
+ llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
// keep track of total number of tokens generated in the draft
slot.n_draft_total += draft.size();
slot.n_draft_accepted += ids.size() - 1;
slot.cache_tokens.push_back(id);
- slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
+ slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1);
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel },
{ "model_path", ctx_server.params_base.model.path },
+ { "modalities", json{{"vision", ctx_server.mctx != nullptr}} }, // TODO: add more in the future
{ "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
{ "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
{ "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
server_task_type type,
json & data,
+ const std::vector<raw_buffer> & files,
const std::function<bool()> & is_connection_closed,
httplib::Response & res,
- oaicompat_type oaicompat) {
+ oaicompat_type oaicompat) -> void {
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
if (ctx_server.params_base.embedding) {
// TODO: this log can become very long, put it behind a flag or think about a more compact format
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
- std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
- tasks.reserve(tokenized_prompts.size());
- for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+ // process files
+ mtmd::bitmaps bitmaps;
+ const bool has_mtmd = ctx_server.mctx != nullptr;
+ {
+ if (!has_mtmd && !files.empty()) {
+ throw std::runtime_error("This server does not support multimodal");
+ }
+ for (auto & file : files) {
+ mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size()));
+ if (!bmp.ptr) {
+ throw std::runtime_error("Failed to load image");
+ }
+ // calculate bitmap hash (for KV caching)
+ std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3);
+ bmp.set_id(hash.c_str());
+ bitmaps.entries.push_back(std::move(bmp));
+ }
+ }
+
+ // process prompt
+ std::vector<server_tokens> inputs;
+ if (oaicompat && !prompt.is_string()) {
+ throw std::runtime_error("prompt must be a string");
+ }
+
+ if (oaicompat && has_mtmd) {
+ // multimodal
+ std::string prompt_str = prompt.get<std::string>();
+ mtmd_input_text inp_txt = {
+ prompt_str.c_str(),
+ /* add_special */ true,
+ /* parse_special */ true,
+ };
+ mtmd::input_chunks chunks(mtmd_input_chunks_init());
+ auto bitmaps_c_ptr = bitmaps.c_ptr();
+ int32_t tokenized = mtmd_tokenize(ctx_server.mctx,
+ chunks.ptr.get(),
+ &inp_txt,
+ bitmaps_c_ptr.data(),
+ bitmaps_c_ptr.size());
+ if (tokenized != 0) {
+ throw std::runtime_error("Failed to tokenize prompt");
+ }
+
+ server_tokens tmp(chunks, true);
+ inputs.push_back(std::move(tmp));
+ } else {
+ // non-multimodal version
+ auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
+ for (auto & p : tokenized_prompts) {
+ auto tmp = server_tokens(p, ctx_server.mctx != nullptr);
+ inputs.push_back(std::move(tmp));
+ }
+ }
+
+ tasks.reserve(inputs.size());
+ for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type);
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
- task.prompt_tokens = std::move(tokenized_prompts[i]);
+ task.prompt_tokens = std::move(inputs[i]);
task.params = server_task::params_from_json_cmpl(
ctx_server.ctx,
ctx_server.params_base,
const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
- return handle_completions_impl(
+ std::vector<raw_buffer> files; // dummy
+ handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
data,
+ files,
req.is_connection_closed,
res,
OAICOMPAT_TYPE_NONE);
const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
json data = oaicompat_completion_params_parse(json::parse(req.body));
- return handle_completions_impl(
+ std::vector<raw_buffer> files; // dummy
+ handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
data,
+ files,
req.is_connection_closed,
res,
OAICOMPAT_TYPE_COMPLETION);
tokenized_prompts[0]
);
- return handle_completions_impl(
+ std::vector<raw_buffer> files; // dummy
+ handle_completions_impl(
SERVER_TASK_TYPE_INFILL,
data,
+ files,
req.is_connection_closed,
res,
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
}
auto body = json::parse(req.body);
- json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
+ std::vector<raw_buffer> files;
+ json data = oaicompat_completion_params_parse(
+ body,
+ params.use_jinja,
+ params.reasoning_format,
+ ctx_server.chat_templates.get(),
+ ctx_server.mctx,
+ files);
- return handle_completions_impl(
+ handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
data,
+ files,
req.is_connection_closed,
res,
OAICOMPAT_TYPE_CHAT);
// same with handle_chat_completions, but without inference part
const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) {
auto body = json::parse(req.body);
- json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
+ std::vector<raw_buffer> files; // dummy, unused
+ json data = oaicompat_completion_params_parse(
+ body,
+ params.use_jinja,
+ params.reasoning_format,
+ ctx_server.chat_templates.get(),
+ ctx_server.mctx,
+ files);
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
};
}
}
- std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
+ auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
for (const auto & tokens : tokenized_prompts) {
// this check is necessary for models that do not add BOS token to the input
if (tokens.empty()) {
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
- task.prompt_tokens = std::move(tokenized_prompts[i]);
+ task.prompt_tokens = server_tokens(tokenized_prompts[i], ctx_server.mctx != nullptr);
// OAI-compat
task.params.oaicompat = oaicompat;
std::unordered_set<int> task_ids;
{
std::vector<server_task> tasks;
- std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
+ auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
tasks.reserve(tokenized_docs.size());
for (size_t i = 0; i < tokenized_docs.size(); i++) {
+ auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
- task.prompt_tokens = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
+ task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr);
tasks.push_back(std::move(task));
}
#include "common.h"
#include "log.h"
#include "llama.h"
+#include "arg.h" // common_remote_get_content
#include "base64.hpp"
+#include "mtmd.h"
// increase max payload length to allow use of larger context size
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
#include <string>
#include <vector>
#include <memory>
+#include <cinttypes>
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo"
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+using raw_buffer = std::vector<uint8_t>;
+
template <typename T>
static T json_value(const json & body, const std::string & key, const T & default_value) {
// Fallback null to default value
return (isalnum(c) || (c == '+') || (c == '/'));
}
-static inline std::vector<uint8_t> base64_decode(const std::string & encoded_string) {
+static inline raw_buffer base64_decode(const std::string & encoded_string) {
int i = 0;
int j = 0;
int in_ = 0;
uint8_t char_array_4[4];
uint8_t char_array_3[3];
- std::vector<uint8_t> ret;
+ raw_buffer ret;
while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
char_array_4[i++] = encoded_string[in_]; in_++;
const json & body, /* openai api json semantics */
bool use_jinja,
common_reasoning_format reasoning_format,
- const struct common_chat_templates * tmpls)
+ const struct common_chat_templates * tmpls,
+ bool allow_non_text,
+ std::vector<raw_buffer> & out_files)
{
json llama_params;
}
}
+ // get input files
+ if (!body.contains("messages")) {
+ throw std::runtime_error("'messages' is required");
+ }
+ json messages = body.at("messages");
+ if (!messages.is_array()) {
+ throw std::runtime_error("Expected 'messages' to be an array");
+ }
+ for (auto & msg : messages) {
+ json & content = msg.at("content");
+ if (content.is_string()) {
+ continue;
+ }
+
+ if (!content.is_array()) {
+ throw std::runtime_error("Expected 'content' to be a string or an array");
+ }
+
+ for (auto & p : content) {
+ std::string type = json_value(p, "type", std::string());
+ json image_url = json_value(p, "image_url", json::object());
+ if (type == "image_url") {
+ if (!allow_non_text) {
+ throw std::runtime_error("image input is not supported by this server");
+ }
+
+ std::string url = json_value(image_url, "url", std::string());
+ if (string_starts_with(url, "http")) {
+ // download remote image
+ // TODO @ngxson : maybe make these params configurable
+ common_remote_params params;
+ params.headers.push_back("User-Agent: llama.cpp/" + build_info);
+ params.max_size = 1024 * 1024 * 10; // 10MB
+ params.timeout = 10; // seconds
+ SRV_INF("downloading image from '%s'\n", url.c_str());
+ auto res = common_remote_get_content(url, params);
+ if (200 <= res.first && res.first < 300) {
+                        SRV_INF("downloaded %zu bytes\n", res.second.size());
+ raw_buffer data;
+ data.insert(data.end(), res.second.begin(), res.second.end());
+ out_files.push_back(data);
+ } else {
+ throw std::runtime_error("Failed to download image");
+ }
+
+ } else {
+ // try to decode base64 image
+ std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
+ if (parts.size() != 2) {
+ throw std::runtime_error("Invalid image_url.url value");
+ } else if (!string_starts_with(parts[0], "data:image/")) {
+ throw std::runtime_error("Invalid image_url.url format: " + parts[0]);
+ } else if (!string_ends_with(parts[0], "base64")) {
+ throw std::runtime_error("image_url.url must be base64 encoded");
+ } else {
+ auto base64_data = parts[1];
+ auto decoded_data = base64_decode(base64_data);
+ out_files.push_back(decoded_data);
+ }
+ }
+
+ // replace this chunk with a marker
+ p["type"] = "text";
+ p["text"] = MTMD_DEFAULT_IMAGE_MARKER;
+ p.erase("image_url");
+ }
+ }
+ }
+
common_chat_templates_inputs inputs;
- inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages"));
+ inputs.messages = common_chat_msgs_parse_oaicompat(messages);
inputs.tools = common_chat_tools_parse_oaicompat(tools);
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
return lora;
}
+
+//
+// utils for interacting with libmtmd
+// (may need to refactor in near future)
+//
+
+/**
+ * server_tokens is a helper to manage the input tokens and images for the server.
+ * it is structured this way to simplify the logic of KV cache management.
+ * (an illustrative usage sketch follows the struct definition below)
+ */
+struct server_tokens {
+ bool has_mtmd = false;
+
+private: // disallow direct access to these members, so they cannot go out of sync
+
+ // map a **start** position in tokens to the image chunk
+ std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_image;
+
+ // list of tokens
+    // it can include LLAMA_TOKEN_NULL, which is used to indicate a position that is not a text token (i.e. part of an image chunk)
+    // an mtmd_input_chunk can occupy multiple tokens, one llama_token per **position**
+ // important: for models using mrope, an image can contain multiple tokens but will use only one **position**
+ llama_tokens tokens;
+
+ // for ex. with input of 5 text tokens and 2 images:
+ // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
+ // pos 0 1 2 3 4 5 6 7 8 9
+ // map_pos_to_image will contain: {5, img0}, {8, img1}
+
+public:
+ server_tokens() = default;
+ ~server_tokens() = default;
+
+ // Prevent copying
+ server_tokens(const server_tokens&) = delete;
+ server_tokens& operator=(const server_tokens&) = delete;
+
+ // Allow moving (usually implicitly generated if members are movable)
+ server_tokens(server_tokens&&) = default;
+ server_tokens& operator=(server_tokens&&) = default;
+
+ // Allow accessing elements using [] operator
+ llama_token operator[](size_t index) { return tokens[index]; }
+ const llama_token& operator[](size_t index) const { return tokens[index]; }
+
+ server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) {
+ for (size_t i = 0; i < mtmd_chunks.size(); ++i) {
+ push_back(mtmd_chunks[i]);
+ }
+ }
+
+ server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
+
+ // for debugging
+ std::string str() const {
+ std::ostringstream oss;
+ oss << "tokens: ";
+ for (const auto & t : tokens) {
+ if (t == LLAMA_TOKEN_NULL) {
+ oss << "<embd> ";
+ } else {
+ oss << t << " ";
+ }
+ }
+ oss << "\n";
+ oss << "image pos: ";
+ for (const auto & it : map_pos_to_image) {
+ oss << it.first << ", ";
+ }
+ return oss.str();
+ }
+
+ const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const {
+ auto it = map_pos_to_image.find(pos);
+ if (it != map_pos_to_image.end()) {
+ return it->second;
+ } else {
+ throw std::runtime_error("Chunk not found");
+ }
+ }
+
+ void push_back(llama_token tok) {
+ if (tok == LLAMA_TOKEN_NULL) {
+ throw std::runtime_error("Invalid token");
+ }
+ tokens.emplace_back(tok);
+ }
+
+ // will create a copy of the chunk if it contains non-text data
+ void push_back(const mtmd_input_chunk * chunk) {
+ auto type = mtmd_input_chunk_get_type(chunk);
+ if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+ GGML_ASSERT(has_mtmd);
+ auto img_tokens = mtmd_input_chunk_get_tokens_image(chunk);
+ const int n_pos = mtmd_image_tokens_get_n_pos(img_tokens);
+ llama_pos start_pos = tokens.size();
+ for (int i = 0; i < n_pos; ++i) {
+ tokens.emplace_back(LLAMA_TOKEN_NULL);
+ }
+ mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
+ map_pos_to_image[start_pos] = std::move(new_chunk);
+ } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+ size_t n_tokens;
+ auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
+ for (size_t i = 0; i < n_tokens; ++i) {
+ push_back(text_tokens[i]);
+ }
+ } else {
+ GGML_ABORT("Invalid chunk type");
+ }
+ }
+
+ // for compatibility with context shift and prompt truncation
+ void insert(const llama_tokens & inp_tokens) {
+ GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
+ tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end());
+ }
+
+ // for compatibility with speculative decoding, ctx shift, slot save/load
+ const llama_tokens & get_text_tokens() const {
+ GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
+ return tokens;
+ }
+
+ // for compatibility with speculative decoding
+ void set_token(llama_pos pos, llama_token id) {
+ GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
+ tokens[pos] = id;
+ }
+
+ size_t size() const {
+ return tokens.size();
+ }
+
+ bool empty() const {
+ return tokens.empty();
+ }
+
+ void clear() {
+ tokens.clear();
+ }
+
+ void resize(size_t n) {
+ GGML_ASSERT(n <= tokens.size());
+ if (has_mtmd) {
+ // we throw an error if we try to remove a token in the middle of an image
+ // for ex. with input of 5 text tokens and 2 images:
+ // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
+ // n 1 2 3 4 5 6 7 8 9 10
+ // allowed to resize ^ ^
+ // disallowed to resize ^ ^ ^
+ if (n > 0) {
+ llama_token last_token = tokens[n - 1];
+ // make sure we never remove tokens in the middle of an image
+ if (last_token == LLAMA_TOKEN_NULL) {
+ find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk
+ }
+ }
+ // remove all image chunks that are not used anymore
+ for (auto it = map_pos_to_image.begin(); it != map_pos_to_image.end(); ) {
+ llama_pos pos = it->first;
+ if (pos >= (llama_pos)n) {
+ it = map_pos_to_image.erase(it);
+ } else {
+ ++it;
+ }
+ }
+ }
+ tokens.resize(n);
+ }
+
+ std::string detokenize(const llama_context * ctx, bool special) const {
+ llama_tokens text_tokens;
+ text_tokens.reserve(tokens.size());
+ for (const auto & t : tokens) {
+ if (t != LLAMA_TOKEN_NULL) {
+ text_tokens.push_back(t);
+ }
+ }
+ return common_detokenize(ctx, text_tokens, special);
+ }
+
+ size_t get_common_prefix(const server_tokens & b) const {
+ size_t max_idx = std::min(tokens.size(), b.tokens.size());
+ for (size_t i = 0; i < max_idx; ++i) {
+ auto & ai = tokens[i];
+ auto & bi = b.tokens[i];
+
+ if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
+ GGML_ASSERT(has_mtmd);
+ const auto & a_chunk = find_chunk(i);
+ const auto & b_chunk = b.find_chunk(i);
+ GGML_ASSERT(a_chunk && b_chunk);
+ const auto * a_img = mtmd_input_chunk_get_tokens_image(a_chunk.get());
+ const auto * b_img = mtmd_input_chunk_get_tokens_image(b_chunk.get());
+ std::string ai_id = mtmd_image_tokens_get_id(a_img);
+ std::string bi_id = mtmd_image_tokens_get_id(b_img);
+ size_t a_pos = mtmd_image_tokens_get_n_pos(a_img);
+ size_t b_pos = mtmd_image_tokens_get_n_pos(b_img);
+ if (ai_id == bi_id && a_pos == b_pos) {
+ GGML_ASSERT(a_pos > 0 && "Invalid image token"); // should never happen
+ i += a_pos - 1; // will be +1 by the for loop
+ continue;
+ } else {
+ return i;
+ }
+ } else if (ai == bi) {
+ continue;
+ } else {
+ return i;
+ }
+ }
+ return max_idx; // all tokens are equal
+ }
+
+ // make sure all text tokens are within the vocab range
+ bool validate(const struct llama_context * ctx) const {
+ const llama_model * model = llama_get_model(ctx);
+ const llama_vocab * vocab = llama_model_get_vocab(model);
+ const int32_t n_vocab = llama_vocab_n_tokens(vocab);
+
+ for (size_t i = 0; i < tokens.size(); ++i) {
+ auto & t = tokens[i];
+ if (t == LLAMA_TOKEN_NULL) {
+ try {
+ const auto & chunk = find_chunk(i);
+ const auto * img_tokens = mtmd_input_chunk_get_tokens_image(chunk.get());
+ size_t n_pos = mtmd_image_tokens_get_n_pos(img_tokens);
+ i += n_pos - 1; // will be +1 by the for loop
+ } catch (const std::exception & e) {
+ return false;
+ }
+ } else if (t < 0 || t >= n_vocab) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ // encode and decode the image chunk
+ int32_t process_chunk(
+ llama_context * ctx,
+ mtmd_context * mctx,
+ llama_pos n_past,
+ int32_t seq_id,
+ llama_pos & n_pos_out) {
+ auto it = map_pos_to_image.find(n_past);
+ if (it == map_pos_to_image.end()) {
+ throw std::runtime_error("Chunk not found");
+ }
+ SRV_INF("%s\n", "processing image...");
+ int32_t n_batch = llama_n_batch(ctx);
+ int64_t t0 = ggml_time_ms();
+ llama_pos new_n_past = n_past;
+ int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
+ it->second.get(), // chunk
+ n_past,
+ seq_id,
+ n_batch,
+ true, // logits last
+ &new_n_past);
+ SRV_INF("image processed in %" PRId64 " ms\n", ggml_time_ms() - t0);
+ if (result != 0) {
+            LOG_ERR("mtmd_helper_eval_chunk_single failed with status %d\n", result);
+ n_pos_out = n_past;
+ return result;
+ }
+ n_pos_out = new_n_past;
+ return 0;
+ }
+};
+
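+// The sketch below is illustrative only and is not used by the server. It shows how
+// server_tokens is intended to drive prompt cache reuse for a plain text prompt
+// (has_mtmd = false); the function name and the token id values are hypothetical.
+[[maybe_unused]] static void server_tokens_usage_sketch() {
+    // hypothetical token ids; the first three positions are shared between the two prompts
+    llama_tokens cached_ids = { 1, 15043, 29871, 3186 };
+    llama_tokens prompt_ids = { 1, 15043, 29871, 12711 };
+
+    server_tokens cache (cached_ids, /* has_mtmd */ false);
+    server_tokens prompt(prompt_ids, /* has_mtmd */ false);
+
+    // 3 positions can be reused from the KV cache; only the tail needs re-evaluation
+    const size_t n_past = cache.get_common_prefix(prompt); // == 3
+
+    // drop the non-common tail from the cache, then append the newly evaluated token
+    cache.resize(n_past);
+    cache.push_back(prompt[n_past]);
+}
+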
+// Computes FNV-1a hash of the data
+static std::string fnv_hash(const uint8_t * data, size_t len) {
+ const uint64_t fnv_prime = 0x100000001b3ULL;
+ uint64_t hash = 0xcbf29ce484222325ULL;
+
+ for (size_t i = 0; i < len; ++i) {
+ hash ^= data[i];
+ hash *= fnv_prime;
+ }
+ return std::to_string(hash);
+}