set(TARGET_SRCS
server.cpp
- utils.hpp
server-http.cpp
server-http.h
+ server-task.cpp
+ server-task.h
+ server-queue.cpp
+ server-queue.h
+ server-common.cpp
+ server-common.h
)
set(PUBLIC_ASSETS
index.html.gz
--- /dev/null
+#include "common.h"
+#include "log.h"
+#include "llama.h"
+#include "mtmd.h"
+#include "mtmd-helper.h"
+#include "chat.h"
+#include "arg.h" // for common_remote_get_content; TODO: use download.h only
+#include "base64.hpp"
+
+#include "server-common.h"
+
+#include <random>
+#include <sstream>
+
+json format_error_response(const std::string & message, const enum error_type type) {
+ std::string type_str;
+ int code = 500;
+ switch (type) {
+ case ERROR_TYPE_INVALID_REQUEST:
+ type_str = "invalid_request_error";
+ code = 400;
+ break;
+ case ERROR_TYPE_AUTHENTICATION:
+ type_str = "authentication_error";
+ code = 401;
+ break;
+ case ERROR_TYPE_NOT_FOUND:
+ type_str = "not_found_error";
+ code = 404;
+ break;
+ case ERROR_TYPE_SERVER:
+ type_str = "server_error";
+ code = 500;
+ break;
+ case ERROR_TYPE_PERMISSION:
+ type_str = "permission_error";
+ code = 403;
+ break;
+ case ERROR_TYPE_NOT_SUPPORTED:
+ type_str = "not_supported_error";
+ code = 501;
+ break;
+ case ERROR_TYPE_UNAVAILABLE:
+ type_str = "unavailable_error";
+ code = 503;
+ break;
+ case ERROR_TYPE_EXCEED_CONTEXT_SIZE:
+ type_str = "exceed_context_size_error";
+ code = 400;
+ break;
+ }
+ return json {
+ {"code", code},
+ {"message", message},
+ {"type", type_str},
+ };
+}
+
+//
+// random string / id
+//
+
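+// note: uses std::mt19937 seeded from std::random_device, so the result is not
+// cryptographically secure; it is only meant to produce unique-enough request ids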
+std::string random_string() {
+ static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
+
+ std::random_device rd;
+ std::mt19937 generator(rd());
+
+ std::string result(32, ' ');
+
+ for (int i = 0; i < 32; ++i) {
+ result[i] = str[generator() % str.size()];
+ }
+
+ return result;
+}
+
+std::string gen_chatcmplid() {
+ return "chatcmpl-" + random_string();
+}
+
+std::string gen_tool_call_id() {
+ return random_string();
+}
+
+//
+// lora utils
+//
+
+bool lora_all_alora(const std::vector<common_adapter_lora_info> & loras) {
+ bool found_alora = false;
+ for (const auto & lora : loras) {
+ if (lora.scale != 0) {
+ if (llama_adapter_get_alora_n_invocation_tokens(lora.ptr) == 0) {
+ return false;
+ }
+ found_alora = true;
+ }
+ }
+ return found_alora;
+}
+
+bool lora_should_clear_cache(
+ const std::vector<common_adapter_lora_info> & current,
+ const std::vector<common_adapter_lora_info> & next) {
+
+ // This should always be called after determining that the two sets are
+ // _not_ equal. This assert is therefore some slightly wasted work and
+ // should be safe to remove as long as this method is called correctly.
+ GGML_ASSERT(!are_lora_equal(current, next));
+
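+ // the cache can be kept only when the currently active set is empty or alora-only
+ // AND the next set is alora-only; in every other case it must be cleared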
+ return (
+ !(lora_get_enabled_ids(current).empty() || lora_all_alora(current)) ||
+ !lora_all_alora(next));
+}
+
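+// parse the "lora" field of a request; expected format (illustrative):
+//   [{"id": 0, "scale": 0.5}, {"id": 1, "scale": 1.0}]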
+std::vector<common_adapter_lora_info> parse_lora_request(
+ const std::vector<common_adapter_lora_info> & lora_base,
+ const json & data) {
+ std::vector<common_adapter_lora_info> lora(lora_base);
+ int max_idx = lora.size();
+
+ // clear existing value
+ for (auto & entry : lora) {
+ entry.scale = 0.0f;
+ }
+
+ // set value
+ for (const auto & entry : data) {
+ int id = json_value(entry, "id", -1);
+ float scale = json_value(entry, "scale", 0.0f);
+ if (0 <= id && id < max_idx) {
+ lora[id].scale = scale;
+ } else {
+ throw std::runtime_error("invalid adapter id");
+ }
+ }
+
+ return lora;
+}
+
+bool are_lora_equal(
+ const std::vector<common_adapter_lora_info> & l1,
+ const std::vector<common_adapter_lora_info> & l2) {
+ if (l1.size() != l2.size()) {
+ return false;
+ }
+ for (size_t i = 0; i < l1.size(); ++i) {
+ // we don't check lora.path to reduce the time complexity
+ if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) {
+ return false;
+ }
+ }
+ return true;
+}
+
+std::vector<size_t> lora_get_enabled_ids(const std::vector<common_adapter_lora_info> & loras) {
+ std::vector<size_t> enabled_ids;
+ for (size_t i = 0; i < loras.size(); ++i) {
+ if (loras[i].scale > 0) {
+ enabled_ids.push_back(i);
+ }
+ }
+ return enabled_ids;
+}
+
+//
+// base64 utils (TODO: use the base64::decode from base64.hpp)
+//
+
+static const std::string base64_chars =
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ "abcdefghijklmnopqrstuvwxyz"
+ "0123456789+/";
+
+static inline bool is_base64(uint8_t c) {
+ return (isalnum(c) || (c == '+') || (c == '/'));
+}
+
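+// decodes standard (non-URL-safe) base64; decoding stops at the first '=' padding
+// byte or at any character outside the base64 alphabet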
+static inline raw_buffer base64_decode(const std::string & encoded_string) {
+ int i = 0;
+ int j = 0;
+ int in_ = 0;
+
+ int in_len = encoded_string.size();
+
+ uint8_t char_array_4[4];
+ uint8_t char_array_3[3];
+
+ raw_buffer ret;
+
+ while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
+ char_array_4[i++] = encoded_string[in_]; in_++;
+ if (i == 4) {
+ for (i = 0; i < 4; i++) {
+ char_array_4[i] = base64_chars.find(char_array_4[i]);
+ }
+
+ char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
+ char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+ char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
+
+ for (i = 0; (i < 3); i++) {
+ ret.push_back(char_array_3[i]);
+ }
+
+ i = 0;
+ }
+ }
+
+ if (i) {
+ for (j = i; j < 4; j++) {
+ char_array_4[j] = 0;
+ }
+
+ for (j = 0; j < 4; j++) {
+ char_array_4[j] = base64_chars.find(char_array_4[j]);
+ }
+
+ char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
+ char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+ char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
+
+ for (j = 0; j < i - 1; j++) {
+ ret.push_back(char_array_3[j]);
+ }
+ }
+
+ return ret;
+}
+
+//
+// server_tokens implementation
+//
+
+server_tokens::server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) {
+ for (size_t i = 0; i < mtmd_chunks.size(); ++i) {
+ push_back(mtmd_chunks[i]);
+ }
+}
+
+server_tokens::server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
+}
+
+llama_pos server_tokens::pos_next() const {
+ if (!has_mtmd) {
+ return tokens.size();
+ }
+
+ llama_pos res = tokens.size();
+
+ for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
+ const auto & chunk = it->second;
+ res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
+ }
+
+ return res;
+}
+
+std::string server_tokens::str() const {
+ std::ostringstream oss;
+ oss << "tokens: ";
+ for (size_t idx = 0; idx < tokens.size(); ++idx) {
+ llama_token t = tokens[idx];
+ oss << "idx:" << idx << " ";
+ if (t == LLAMA_TOKEN_NULL) {
+ oss << "<embd> ";
+ } else {
+ oss << t << " ";
+ }
+ }
+ oss << "\n";
+ oss << "image idx: ";
+ for (const auto & it : map_idx_to_media) {
+ oss << it.first << ", ";
+ }
+ return oss.str();
+}
+
+const mtmd::input_chunk_ptr & server_tokens::find_chunk(size_t idx) const {
+ auto it = map_idx_to_media.find(idx);
+ if (it != map_idx_to_media.end()) {
+ return it->second;
+ }
+ throw std::runtime_error("Chunk not found");
+}
+
+void server_tokens::push_back(llama_token tok) {
+ if (tok == LLAMA_TOKEN_NULL) {
+ throw std::runtime_error("Invalid token");
+ }
+ tokens.emplace_back(tok);
+}
+
+void server_tokens::push_back(const mtmd_input_chunk * chunk) {
+ auto type = mtmd_input_chunk_get_type(chunk);
+ if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
+ GGML_ASSERT(has_mtmd);
+ const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
+ size_t start_idx = tokens.size();
+ for (size_t i = 0; i < n_tokens; ++i) {
+ tokens.emplace_back(LLAMA_TOKEN_NULL);
+ }
+ mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
+ map_idx_to_media[start_idx] = std::move(new_chunk);
+ } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+ size_t n_tokens;
+ const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
+ for (size_t i = 0; i < n_tokens; ++i) {
+ push_back(text_tokens[i]);
+ }
+ } else {
+ GGML_ABORT("Invalid chunk type");
+ }
+}
+
+void server_tokens::push_back(server_tokens & tokens) {
+ size_t start_idx = size();
+ for (size_t i = 0; i < tokens.size(); i++) {
+ push_back(tokens[i]);
+ }
+ if (tokens.has_mtmd) {
+ // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd.
+ // We could also just check, but this will prevent silently dropping MTMD data.
+ GGML_ASSERT(has_mtmd);
+ for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ++it) {
+ auto * chunk = it->second.get();
+ mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
+ map_idx_to_media[start_idx + it->first] = std::move(new_chunk);
+ }
+ }
+}
+
+void server_tokens::insert(const llama_tokens & inp_tokens) {
+ GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
+ tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end());
+}
+
+const llama_tokens & server_tokens::get_text_tokens() const {
+ GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
+ return tokens;
+}
+
+void server_tokens::set_token(llama_pos pos, llama_token id) {
+ GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
+ tokens[pos] = id;
+}
+
+void server_tokens::keep_first(size_t n) {
+ GGML_ASSERT(n <= tokens.size());
+ if (has_mtmd) {
+ if (n == tokens.size()) {
+ return; // nothing to do
+ }
+ // we throw an error if we try to remove a token in the middle of an image
+ // for ex. with input of 5 text tokens and 2 images:
+ // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
+ // n 1 2 3 4 5 6 7 8 9 10
+ // allowed to resize ^ ^
+ // disallowed to resize ^ ^ ^
+ if (n > 0) {
+ // make sure we never remove tokens in the middle of an image
+ // note that the case where we keep a full image at the end is allowed:
+ // tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] != LLAMA_TOKEN_NULL
+ if (tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] == LLAMA_TOKEN_NULL) {
+ find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk
+ }
+ }
+ // remove all image chunks that are not used anymore
+ for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) {
+ size_t idx = it->first;
+ if (idx >= n) {
+ it = map_idx_to_media.erase(it);
+ } else {
+ ++it;
+ }
+ }
+ }
+ tokens.resize(n);
+}
+
+std::string server_tokens::detokenize(const llama_context * ctx, bool special) const {
+ llama_tokens text_tokens;
+ text_tokens.reserve(tokens.size());
+ for (const auto & t : tokens) {
+ if (t != LLAMA_TOKEN_NULL) {
+ text_tokens.push_back(t);
+ }
+ }
+ return common_detokenize(ctx, text_tokens, special);
+}
+
+size_t server_tokens::get_common_prefix(const server_tokens & b) const {
+ const size_t max_idx = std::min(tokens.size(), b.tokens.size());
+
+ if (!has_mtmd) {
+ for (size_t i = 0; i < max_idx; ++i) {
+ if (tokens[i] == b.tokens[i]) {
+ continue;
+ }
+
+ return i;
+ }
+
+ return max_idx;
+ }
+
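+ // with mtmd enabled, media chunks are compared by chunk id and token count;
+ // a matching chunk is skipped over as a whole and counted into the common prefix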
+ for (size_t i = 0; i < max_idx; ++i) {
+ const llama_token ai = tokens[i];
+ const llama_token bi = b.tokens[i];
+
+ if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
+ const auto & a_chunk = find_chunk(i);
+ const auto & b_chunk = b.find_chunk(i);
+
+ GGML_ASSERT(a_chunk && b_chunk);
+
+ const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
+ const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());
+
+ const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get());
+ const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get());
+
+ if (id_ai == id_bi && n_tok_a == n_tok_b) {
+ GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen
+ i += n_tok_a - 1; // will be +1 by the for loop
+ continue;
+ }
+
+ return i;
+ }
+
+ if (ai == bi) {
+ continue;
+ }
+
+ return i;
+ }
+
+ return max_idx; // all tokens are equal
+}
+
+bool server_tokens::validate(const struct llama_context * ctx) const {
+ const llama_model * model = llama_get_model(ctx);
+ const llama_vocab * vocab = llama_model_get_vocab(model);
+ const int32_t n_vocab = llama_vocab_n_tokens(vocab);
+
+ for (size_t i = 0; i < tokens.size(); ++i) {
+ const auto & t = tokens[i];
+ if (t == LLAMA_TOKEN_NULL) {
+ try {
+ const auto & chunk = find_chunk(i);
+ size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get());
+ i += n_tokens - 1; // will be +1 by the for loop
+ } catch (const std::exception & e) {
+ return false;
+ }
+ } else if (t < 0 || t >= n_vocab) {
+ return false;
+ }
+ }
+ return true;
+}
+
+int32_t server_tokens::process_chunk(
+ llama_context * ctx,
+ mtmd_context * mctx,
+ size_t idx,
+ llama_pos pos,
+ int32_t seq_id,
+ size_t & n_tokens_out) const {
+ const auto & chunk = find_chunk(idx);
+ const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
+ ? "image" : "audio";
+ SRV_INF("processing %s...\n", name);
+ int32_t n_batch = llama_n_batch(ctx);
+ int64_t t0 = ggml_time_ms();
+ llama_pos new_n_past; // unused for now
+ int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
+ chunk.get(),
+ pos,
+ seq_id,
+ n_batch,
+ true, // logits last
+ &new_n_past);
+ SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
+ if (result != 0) {
+ LOG_ERR("mtmd_helper_eval failed with status %d\n", result);
+ n_tokens_out = 0;
+ return result;
+ }
+ n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get());
+ return 0;
+}
+
+//
+// tokenizer and input processing utils
+//
+
+bool json_is_array_of_numbers(const json & data) {
+ if (data.is_array()) {
+ for (const auto & e : data) {
+ if (!e.is_number_integer()) {
+ return false;
+ }
+ }
+ return true;
+ }
+ return false;
+}
+
+bool json_is_array_of_mixed_numbers_strings(const json & data) {
+ bool seen_string = false;
+ bool seen_number = false;
+ if (data.is_array()) {
+ for (const auto & e : data) {
+ seen_string |= e.is_string();
+ seen_number |= e.is_number_integer();
+ if (seen_number && seen_string) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+bool json_is_array_and_contains_numbers(const json & data) {
+ if (data.is_array()) {
+ for (const auto & e : data) {
+ if (e.is_number_integer()) {
+ return true;
+ }
+ }
+ return false;
+ }
+ return false;
+}
+
+json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
+ json result = json::object();
+
+ for (const std::string & path : paths) {
+ json current = js;
+ const auto keys = string_split<std::string>(path, /*separator*/ '/');
+ bool valid_path = true;
+ for (const std::string & k : keys) {
+ if (valid_path && current.is_object() && current.contains(k)) {
+ current = current[k];
+ } else {
+ valid_path = false;
+ }
+ }
+ if (valid_path) {
+ result[path] = current;
+ }
+ }
+ return result;
+}
+
+llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
+ // If `add_special` is true, special tokens (e.g. BOS) are only added when json_prompt
+ // is a string, or when the first element of the json_prompt array is a string.
+ llama_tokens prompt_tokens;
+
+ if (json_prompt.is_array()) {
+ bool first = true;
+ for (const auto & p : json_prompt) {
+ if (p.is_string()) {
+ auto s = p.template get<std::string>();
+
+ llama_tokens toks;
+ if (first) {
+ toks = common_tokenize(vocab, s, add_special, parse_special);
+ first = false;
+ } else {
+ toks = common_tokenize(vocab, s, false, parse_special);
+ }
+
+ prompt_tokens.insert(prompt_tokens.end(), toks.begin(), toks.end());
+ } else {
+ if (first) {
+ first = false;
+ }
+
+ prompt_tokens.push_back(p.template get<llama_token>());
+ }
+ }
+ } else {
+ auto s = json_prompt.template get<std::string>();
+ prompt_tokens = common_tokenize(vocab, s, add_special, parse_special);
+ }
+
+ return prompt_tokens;
+}
+
+size_t validate_utf8(const std::string& text) {
+ size_t len = text.size();
+ if (len == 0) return 0;
+
+ // Check the last few bytes to see if a multi-byte character is cut off
+ for (size_t i = 1; i <= 4 && i <= len; ++i) {
+ unsigned char c = text[len - i];
+ // Check for start of a multi-byte sequence from the end
+ if ((c & 0xE0) == 0xC0) {
+ // 2-byte character start: 110xxxxx
+ // Needs at least 2 bytes
+ if (i < 2) return len - i;
+ } else if ((c & 0xF0) == 0xE0) {
+ // 3-byte character start: 1110xxxx
+ // Needs at least 3 bytes
+ if (i < 3) return len - i;
+ } else if ((c & 0xF8) == 0xF0) {
+ // 4-byte character start: 11110xxx
+ // Needs at least 4 bytes
+ if (i < 4) return len - i;
+ }
+ }
+
+ // If no cut-off multi-byte character is found, return full length
+ return len;
+}
+
+// Computes FNV-1a hash of the data
+static std::string fnv_hash(const uint8_t * data, size_t len) {
+ const uint64_t fnv_prime = 0x100000001b3ULL;
+ uint64_t hash = 0xcbf29ce484222325ULL;
+
+ for (size_t i = 0; i < len; ++i) {
+ hash ^= data[i];
+ hash *= fnv_prime;
+ }
+ return std::to_string(hash);
+}
+
+server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files) {
+ mtmd::bitmaps bitmaps;
+ for (auto & file : files) {
+ mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size()));
+ if (!bmp.ptr) {
+ throw std::runtime_error("Failed to load image or audio file");
+ }
+ // calculate bitmap hash (for KV caching)
+ std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
+ bmp.set_id(hash.c_str());
+ bitmaps.entries.push_back(std::move(bmp));
+ }
+ // process the multimodal prompt
+ mtmd_input_text inp_txt = {
+ prompt.c_str(),
+ /* add_special */ true,
+ /* parse_special */ true,
+ };
+ mtmd::input_chunks chunks(mtmd_input_chunks_init());
+ auto bitmaps_c_ptr = bitmaps.c_ptr();
+ int32_t tokenized = mtmd_tokenize(mctx,
+ chunks.ptr.get(),
+ &inp_txt,
+ bitmaps_c_ptr.data(),
+ bitmaps_c_ptr.size());
+ if (tokenized != 0) {
+ throw std::runtime_error("Failed to tokenize prompt");
+ }
+ auto result = server_tokens(chunks, true);
+ return result;
+}
+
+/**
+ * tokenize a single "prompt" element (string, token array, mixed array, or multimodal object)
+ * use tokenize_input_prompts() instead if the input could be an array of prompts.
+ * this supports these cases:
+ * - "prompt": "string"
+ * - "prompt": [12, 34, 56]
+ * - "prompt": [12, 34, "string", 56, 78]
+ * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
+ */
+static server_tokens tokenize_input_subprompt(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) {
+ constexpr char JSON_STRING_PROMPT_KEY[] = "prompt_string";
+ constexpr char JSON_MTMD_DATA_KEY[] = "multimodal_data";
+ const bool has_mtmd = mctx != nullptr;
+ if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
+ // string or mixed
+ llama_tokens tmp = tokenize_mixed(vocab, json_prompt, add_special, parse_special);
+ return server_tokens(tmp, false);
+ } else if (json_is_array_of_numbers(json_prompt)) {
+ // array of tokens
+ llama_tokens tmp = json_prompt.get<llama_tokens>();
+ return server_tokens(tmp, false);
+ } else if (json_prompt.contains(JSON_STRING_PROMPT_KEY)) {
+ // JSON object with prompt key.
+ if (json_prompt.contains(JSON_MTMD_DATA_KEY)) {
+ if (!has_mtmd) {
+ throw std::runtime_error("Multimodal data provided, but model does not support multimodal requests.");
+ }
+
+ // JSON object with prompt and multimodal key.
+ std::vector<raw_buffer> files;
+ for (const auto & entry : json_prompt.at(JSON_MTMD_DATA_KEY)) {
+ files.push_back(base64_decode(entry));
+ }
+ return process_mtmd_prompt(mctx, json_prompt.at(JSON_STRING_PROMPT_KEY), files);
+ } else {
+ // Not multimodal, but contains a subobject.
+ llama_tokens tmp = tokenize_mixed(vocab, json_prompt.at(JSON_STRING_PROMPT_KEY), add_special, parse_special);
+ return server_tokens(tmp, false);
+ }
+ } else {
+ throw std::runtime_error("\"prompt\" elements must be a string, a list of tokens, a JSON object containing a prompt string, or a list of mixed strings & tokens.");
+ }
+}
+
+std::vector<server_tokens> tokenize_input_prompts(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) {
+ std::vector<server_tokens> result;
+ if (json_prompt.is_array() && !json_is_array_and_contains_numbers(json_prompt)) {
+ result.reserve(json_prompt.size());
+ for (const auto & p : json_prompt) {
+ result.push_back(tokenize_input_subprompt(vocab, mctx, p, add_special, parse_special));
+ }
+ } else {
+ result.push_back(tokenize_input_subprompt(vocab, mctx, json_prompt, add_special, parse_special));
+ }
+ if (result.empty()) {
+ throw std::runtime_error("\"prompt\" must not be empty");
+ }
+ return result;
+}
+
+
+//
+// OAI utils
+//
+
+// used by /completions endpoint
+json oaicompat_completion_params_parse(const json & body) {
+ json llama_params;
+
+ if (!body.contains("prompt")) {
+ throw std::runtime_error("\"prompt\" is required");
+ }
+
+ // Handle "stop" field
+ if (body.contains("stop") && body.at("stop").is_string()) {
+ llama_params["stop"] = json::array({body.at("stop").get<std::string>()});
+ } else {
+ llama_params["stop"] = json_value(body, "stop", json::array());
+ }
+
+ // Handle "n" field
+ int n_choices = json_value(body, "n", 1);
+ if (n_choices != 1) {
+ throw std::runtime_error("Only one completion choice is allowed");
+ }
+
+ // Handle "echo" field
+ if (json_value(body, "echo", false)) {
+ throw std::runtime_error("Only no echo is supported");
+ }
+
+ // Params supported by OAI but unsupported by llama.cpp
+ static const std::vector<std::string> unsupported_params { "best_of", "suffix" };
+ for (const auto & param : unsupported_params) {
+ if (body.contains(param)) {
+ throw std::runtime_error("Unsupported param: " + param);
+ }
+ }
+
+ // Copy remaining properties to llama_params
+ for (const auto & item : body.items()) {
+ // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
+ if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
+ llama_params[item.key()] = item.value();
+ }
+ }
+
+ return llama_params;
+}
+
+// used by /chat/completions endpoint
+json oaicompat_chat_params_parse(
+ json & body, /* openai api json semantics */
+ const oaicompat_parser_options & opt,
+ std::vector<raw_buffer> & out_files)
+{
+ json llama_params;
+
+ auto tools = json_value(body, "tools", json());
+ auto has_tools = tools.is_array() && !tools.empty();
+ auto stream = json_value(body, "stream", false);
+ auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
+
+ if (!opt.use_jinja) {
+ if (has_tools) {
+ throw std::runtime_error("tools param requires --jinja flag");
+ }
+ if (tool_choice != "auto") {
+ throw std::runtime_error("tool_choice param requires --jinja flag");
+ }
+ }
+
+ // Handle "stop" field
+ if (body.contains("stop") && body.at("stop").is_string()) {
+ llama_params["stop"] = json::array({body.at("stop").get<std::string>()});
+ } else {
+ llama_params["stop"] = json_value(body, "stop", json::array());
+ }
+
+ auto json_schema = json_value(body, "json_schema", json());
+ auto grammar = json_value(body, "grammar", std::string());
+ if (!json_schema.is_null() && !grammar.empty()) {
+ throw std::runtime_error("Cannot use both json_schema and grammar");
+ }
+
+ // Handle "response_format" field
+ if (body.contains("response_format")) {
+ json response_format = json_value(body, "response_format", json::object());
+ std::string response_type = json_value(response_format, "type", std::string());
+ if (response_type == "json_object") {
+ json_schema = json_value(response_format, "schema", json::object());
+ } else if (response_type == "json_schema") {
+ auto schema_wrapper = json_value(response_format, "json_schema", json::object());
+ json_schema = json_value(schema_wrapper, "schema", json::object());
+ } else if (!response_type.empty() && response_type != "text") {
+ throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
+ }
+ }
+
+ // get input files
+ if (!body.contains("messages")) {
+ throw std::runtime_error("'messages' is required");
+ }
+ json & messages = body.at("messages");
+ if (!messages.is_array()) {
+ throw std::runtime_error("Expected 'messages' to be an array");
+ }
+ for (auto & msg : messages) {
+ std::string role = json_value(msg, "role", std::string());
+ if (role != "assistant" && !msg.contains("content")) {
+ throw std::runtime_error("All non-assistant messages must contain 'content'");
+ }
+ if (role == "assistant") {
+ if (!msg.contains("content") && !msg.contains("tool_calls")) {
+ throw std::runtime_error("Assistant message must contain either 'content' or 'tool_calls'!");
+ }
+ if (!msg.contains("content")) {
+ continue; // avoid errors with no content
+ }
+ }
+ json & content = msg.at("content");
+ if (content.is_string() || content.is_null()) {
+ continue;
+ }
+
+ if (!content.is_array()) {
+ throw std::runtime_error("Expected 'content' to be a string or an array");
+ }
+
+ for (auto & p : content) {
+ std::string type = json_value(p, "type", std::string());
+ if (type == "image_url") {
+ if (!opt.allow_image) {
+ throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
+ }
+
+ json image_url = json_value(p, "image_url", json::object());
+ std::string url = json_value(image_url, "url", std::string());
+ if (string_starts_with(url, "http")) {
+ // download remote image
+ // TODO @ngxson : maybe make these params configurable
+ common_remote_params params;
+ params.headers.push_back("User-Agent: llama.cpp/" + build_info);
+ params.max_size = 1024 * 1024 * 10; // 10MB
+ params.timeout = 10; // seconds
+ SRV_INF("downloading image from '%s'\n", url.c_str());
+ auto res = common_remote_get_content(url, params);
+ if (200 <= res.first && res.first < 300) {
+ SRV_INF("downloaded %ld bytes\n", res.second.size());
+ raw_buffer data;
+ data.insert(data.end(), res.second.begin(), res.second.end());
+ out_files.push_back(data);
+ } else {
+ throw std::runtime_error("Failed to download image");
+ }
+
+ } else {
+ // try to decode base64 image
+ std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
+ if (parts.size() != 2) {
+ throw std::runtime_error("Invalid image_url.url value");
+ } else if (!string_starts_with(parts[0], "data:image/")) {
+ throw std::runtime_error("Invalid image_url.url format: " + parts[0]);
+ } else if (!string_ends_with(parts[0], "base64")) {
+ throw std::runtime_error("image_url.url must be base64 encoded");
+ } else {
+ auto base64_data = parts[1];
+ auto decoded_data = base64_decode(base64_data);
+ out_files.push_back(decoded_data);
+ }
+ }
+
+ // replace this chunk with a marker
+ p["type"] = "text";
+ p["text"] = mtmd_default_marker();
+ p.erase("image_url");
+
+ } else if (type == "input_audio") {
+ if (!opt.allow_audio) {
+ throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
+ }
+
+ json input_audio = json_value(p, "input_audio", json::object());
+ std::string data = json_value(input_audio, "data", std::string());
+ std::string format = json_value(input_audio, "format", std::string());
+ // while we also support flac, we don't allow it here so we match the OAI spec
+ if (format != "wav" && format != "mp3") {
+ throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'");
+ }
+ auto decoded_data = base64_decode(data); // expected to be base64 encoded
+ out_files.push_back(decoded_data);
+
+ // replace this chunk with a marker
+ p["type"] = "text";
+ p["text"] = mtmd_default_marker();
+ p.erase("input_audio");
+
+ } else if (type != "text") {
+ throw std::runtime_error("unsupported content[].type");
+ }
+ }
+ }
+
+ common_chat_templates_inputs inputs;
+ inputs.messages = common_chat_msgs_parse_oaicompat(messages);
+ inputs.tools = common_chat_tools_parse_oaicompat(tools);
+ inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice);
+ inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
+ inputs.grammar = grammar;
+ inputs.use_jinja = opt.use_jinja;
+ inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
+ inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
+ inputs.reasoning_format = opt.reasoning_format;
+ inputs.enable_thinking = opt.enable_thinking;
+ if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
+ if (body.contains("grammar")) {
+ throw std::runtime_error("Cannot use custom grammar constraints with tools.");
+ }
+ llama_params["parse_tool_calls"] = true;
+ }
+
+ // merge the template args provided from command line with the args provided in the user request
+ auto chat_template_kwargs_object = json_value(body, "chat_template_kwargs", json::object());
+ inputs.chat_template_kwargs = opt.chat_template_kwargs;
+ for (const auto & item : chat_template_kwargs_object.items()) {
+ inputs.chat_template_kwargs[item.key()] = item.value().dump();
+ }
+
+ // parse the "enable_thinking" kwarg to override the default value
+ auto enable_thinking_kwarg = json_value(inputs.chat_template_kwargs, "enable_thinking", std::string(""));
+ if (enable_thinking_kwarg == "true") {
+ inputs.enable_thinking = true;
+ } else if (enable_thinking_kwarg == "false") {
+ inputs.enable_thinking = false;
+ } else if (!enable_thinking_kwarg.empty() && enable_thinking_kwarg[0] == '"') {
+ throw std::runtime_error("invalid type for \"enable_thinking\" (expected boolean, got string)");
+ }
+
+ // if the assistant message appears at the end of list, we do not add end-of-turn token
+ // for ex. this can be useful to modify the reasoning process in reasoning models
+ bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && opt.prefill_assistant;
+ common_chat_msg last_message;
+ if (prefill_assistant_message) {
+ last_message = inputs.messages.back();
+ inputs.messages.pop_back();
+
+ /* sanity check, max one assistant message at the end of the list */
+ if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){
+ throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list.");
+ }
+
+ /* TODO: test this properly */
+ inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE;
+
+ if ( inputs.enable_thinking ) {
+ throw std::runtime_error("Assistant response prefill is incompatible with enable_thinking.");
+ }
+
+ inputs.add_generation_prompt = true;
+ }
+
+ // Apply chat template to the list of messages
+ auto chat_params = common_chat_templates_apply(opt.tmpls, inputs);
+
+ /* Append assistant prefilled message */
+ if (prefill_assistant_message) {
+ if (!last_message.content_parts.empty()) {
+ for (auto & p : last_message.content_parts) {
+ chat_params.prompt += p.text;
+ }
+ } else {
+ chat_params.prompt += last_message.content;
+ }
+ }
+
+ llama_params["chat_format"] = static_cast<int>(chat_params.format);
+ llama_params["prompt"] = chat_params.prompt;
+ if (!chat_params.grammar.empty()) {
+ llama_params["grammar"] = chat_params.grammar;
+ }
+ llama_params["grammar_lazy"] = chat_params.grammar_lazy;
+ auto grammar_triggers = json::array();
+ for (const auto & trigger : chat_params.grammar_triggers) {
+ server_grammar_trigger ct(trigger);
+ grammar_triggers.push_back(ct.to_json());
+ }
+ llama_params["grammar_triggers"] = grammar_triggers;
+ llama_params["preserved_tokens"] = chat_params.preserved_tokens;
+ llama_params["thinking_forced_open"] = chat_params.thinking_forced_open;
+ for (const auto & stop : chat_params.additional_stops) {
+ llama_params["stop"].push_back(stop);
+ }
+
+ // Handle "n" field
+ int n_choices = json_value(body, "n", 1);
+ if (n_choices != 1) {
+ throw std::runtime_error("Only one completion choice is allowed");
+ }
+
+ // Handle "logprobs" field
+ // TODO: the response format of this option is not yet OAI-compatible, but it seems that no one really uses it; we may need to fix it in the future
+ if (json_value(body, "logprobs", false)) {
+ if (has_tools && stream) {
+ throw std::runtime_error("logprobs is not supported with tools + stream");
+ }
+ llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
+ } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
+ throw std::runtime_error("top_logprobs requires logprobs to be set to true");
+ }
+
+ // Copy remaining properties to llama_params
+ // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint.
+ // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
+ for (const auto & item : body.items()) {
+ // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
+ if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
+ llama_params[item.key()] = item.value();
+ }
+ }
+
+ return llama_params;
+}
+
+json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64) {
+ json data = json::array();
+ int32_t n_tokens = 0;
+ int i = 0;
+ for (const auto & elem : embeddings) {
+ json embedding_obj;
+
+ if (use_base64) {
+ const auto& vec = json_value(elem, "embedding", json::array()).get<std::vector<float>>();
+ const char* data_ptr = reinterpret_cast<const char*>(vec.data());
+ size_t data_size = vec.size() * sizeof(float);
+ embedding_obj = {
+ {"embedding", base64::encode(data_ptr, data_size)},
+ {"index", i++},
+ {"object", "embedding"},
+ {"encoding_format", "base64"}
+ };
+ } else {
+ embedding_obj = {
+ {"embedding", json_value(elem, "embedding", json::array())},
+ {"index", i++},
+ {"object", "embedding"}
+ };
+ }
+ data.push_back(embedding_obj);
+
+ n_tokens += json_value(elem, "tokens_evaluated", 0);
+ }
+
+ json res = json {
+ {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+ {"object", "list"},
+ {"usage", json {
+ {"prompt_tokens", n_tokens},
+ {"total_tokens", n_tokens}
+ }},
+ {"data", data}
+ };
+
+ return res;
+}
+
+json format_response_rerank(
+ const json & request,
+ const json & ranks,
+ bool is_tei_format,
+ std::vector<std::string> & texts,
+ int top_n) {
+ int32_t n_tokens = 0;
+ bool return_text = is_tei_format && json_value(request, "return_text", false);
+ std::vector<json> elements; // Temporary vector to hold unsorted elements
+ std::string score_label = is_tei_format ? "score" : "relevance_score";
+ for (const auto & rank : ranks) {
+ int index = json_value(rank, "index", 0);
+ json elem = json{
+ {"index", index},
+ {score_label, json_value(rank, "score", 0.0)},
+ };
+ n_tokens += json_value(rank, "tokens_evaluated", 0);
+ if (return_text) {
+ elem["text"] = std::move(texts[index]);
+ }
+ elements.push_back(elem);
+ }
+
+ std::sort(elements.begin(), elements.end(), [score_label](const json& a, const json& b) {
+ return json_value(a, score_label, 0.0) > json_value(b, score_label, 0.0);
+ });
+
+ elements.resize(std::min(top_n, (int)elements.size()));
+ json results = elements;
+
+ if (is_tei_format) return results;
+
+ json res = json{
+ {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+ {"object", "list"},
+ {"usage", json{
+ {"prompt_tokens", n_tokens},
+ {"total_tokens", n_tokens}
+ }},
+ {"results", results}
+ };
+
+ return res;
+}
+
+
+//
+// other utils
+//
+
+std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
+ std::vector<llama_token_data> cur;
+ const auto * logits = llama_get_logits_ith(ctx, idx);
+
+ const llama_model * model = llama_get_model(ctx);
+ const llama_vocab * vocab = llama_model_get_vocab(model);
+
+ const int n_vocab = llama_vocab_n_tokens(vocab);
+
+ cur.resize(n_vocab);
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+ }
+
+ // sort tokens by logits
+ std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
+ return a.logit > b.logit;
+ });
+
+ // apply softmax
+ float max_l = cur[0].logit;
+ float cum_sum = 0.0f;
+ for (size_t i = 0; i < cur.size(); ++i) {
+ float p = expf(cur[i].logit - max_l);
+ cur[i].p = p;
+ cum_sum += p;
+ }
+ for (size_t i = 0; i < cur.size(); ++i) {
+ cur[i].p /= cum_sum;
+ }
+
+ return cur;
+}
+
+std::string safe_json_to_str(const json & data) {
+ return data.dump(-1, ' ', false, json::error_handler_t::replace);
+}
+
+// TODO: reuse llama_detokenize
+template <class Iter>
+static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
+ std::string ret;
+ for (; begin != end; ++begin) {
+ ret += common_token_to_piece(ctx, *begin);
+ }
+
+ return ret;
+}
+
+std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens) {
+ return tokens_to_str(ctx, tokens.begin(), tokens.end());
+}
+
+// format incomplete utf-8 multibyte character for output
+std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) {
+ std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token);
+
+ // if the size is 1 and the high bit is set, this is a partial UTF-8 character
+ // (a piece longer than 1 byte is already a complete, known token)
+ if (out.size() == 1 && (out[0] & 0x80) == 0x80) {
+ std::stringstream ss;
+ ss << std::hex << (out[0] & 0xff);
+ std::string res(ss.str());
+ out = "byte: \\x" + res;
+ }
+
+ return out;
+}
+
+// format server-sent event (SSE), return the formatted string to send
+// note: if data is a json array, it will be sent as multiple events, one per item
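+// example (illustrative): format_sse(json{{"content", "hi"}}) produces the string
+//   "data: {\"content\":\"hi\"}\n\n" (the trailing blank line terminates the event)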
+std::string format_sse(const json & data) {
+ std::ostringstream ss;
+ auto send_single = [&ss](const json & data) {
+ ss << "data: " <<
+ safe_json_to_str(data) <<
+ "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
+ };
+
+ if (data.is_array()) {
+ for (const auto & item : data) {
+ send_single(item);
+ }
+ } else {
+ send_single(data);
+ }
+
+ return ss.str();
+}
+
+bool is_valid_utf8(const std::string & str) {
+ const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
+ const unsigned char* end = bytes + str.length();
+
+ while (bytes < end) {
+ if (*bytes <= 0x7F) {
+ // 1-byte sequence (0xxxxxxx)
+ bytes++;
+ } else if ((*bytes & 0xE0) == 0xC0) {
+ // 2-byte sequence (110xxxxx 10xxxxxx)
+ if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80)
+ return false;
+ bytes += 2;
+ } else if ((*bytes & 0xF0) == 0xE0) {
+ // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx)
+ if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80)
+ return false;
+ bytes += 3;
+ } else if ((*bytes & 0xF8) == 0xF0) {
+ // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx)
+ if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 ||
+ (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80)
+ return false;
+ bytes += 4;
+ } else {
+ // Invalid UTF-8 lead byte
+ return false;
+ }
+ }
+
+ return true;
+}
+
+llama_tokens format_prompt_infill(
+ const llama_vocab * vocab,
+ const json & input_prefix,
+ const json & input_suffix,
+ const json & input_extra,
+ const int n_batch,
+ const int n_predict,
+ const int n_ctx,
+ const bool spm_infill,
+ const llama_tokens & tokens_prompt
+ ) {
+ // TODO: optimize this block by reducing memory allocations and movement
+
+ // use FIM repo-level pattern:
+ // ref: https://arxiv.org/pdf/2409.12186
+ //
+ // [FIM_REP]myproject
+ // [FIM_SEP]filename0
+ // extra chunk 0
+ // [FIM_SEP]filename1
+ // extra chunk 1
+ // ...
+ // [FIM_SEP]filename
+ // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
+ //
+ llama_tokens extra_tokens;
+ extra_tokens.reserve(n_ctx);
+
+ auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false);
+ auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false);
+
+ if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) {
+ // TODO: make project name an input
+ static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false);
+
+ extra_tokens.push_back(llama_vocab_fim_rep(vocab));
+ extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
+ }
+ for (const auto & chunk : input_extra) {
+ // { "text": string, "filename": string }
+ const std::string text = json_value(chunk, "text", std::string());
+ const std::string filename = json_value(chunk, "filename", std::string("tmp"));
+
+ if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
+ const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false);
+
+ extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab));
+ extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
+ } else {
+ // chunk separator in binary form to avoid confusing the AI
+ static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
+ static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false);
+
+ extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
+ }
+
+ const auto chunk_tokens = common_tokenize(vocab, text, false, false);
+ extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
+ }
+
+ if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
+ // TODO: current filename
+ static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false);
+
+ extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab));
+ extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
+ }
+
+ // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
+ const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4));
+ const int n_suffix_take = std::min<int>(tokens_suffix.size(), std::max<int>(0, (n_batch/4) - (2 + tokens_prompt.size())));
+
+ SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take));
+
+ // fill the rest of the context with extra chunks
+ const int n_extra_take = std::min<int>(std::max<int>(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size());
+
+ tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
+ tokens_suffix.resize(n_suffix_take);
+
+ tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab));
+ tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
+ tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab));
+
+ auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
+ auto embd_end = spm_infill ? tokens_prefix : tokens_suffix;
+
+ if (llama_vocab_get_add_bos(vocab)) {
+ embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab));
+ }
+
+ SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
+
+ // put the extra context before the FIM prefix
+ embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());
+
+ embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
+ embd_inp.push_back(llama_vocab_fim_mid(vocab));
+
+ return embd_inp;
+}
+
+server_tokens format_prompt_rerank(
+ const struct llama_model * model,
+ const struct llama_vocab * vocab,
+ mtmd_context * mctx,
+ const std::string & query,
+ const std::string & doc) {
+ server_tokens result = {};
+
+ const char * rerank_prompt = llama_model_chat_template(model, "rerank");
+
+ if (rerank_prompt != nullptr) {
+ std::string prompt = rerank_prompt;
+ string_replace_all(prompt, "{query}" , query);
+ string_replace_all(prompt, "{document}", doc );
+ server_tokens tokens = tokenize_input_subprompt(vocab, mctx, prompt, false, true);
+ result.push_back(tokens);
+ } else {
+ // Get EOS token - use SEP token as fallback if EOS is not available
+ server_tokens query_tokens = tokenize_input_subprompt(vocab, mctx, query, false, false);
+ server_tokens doc_tokens = tokenize_input_subprompt(vocab, mctx, doc, false, false);
+ llama_token eos_token = llama_vocab_eos(vocab);
+ if (eos_token == LLAMA_TOKEN_NULL) {
+ eos_token = llama_vocab_sep(vocab);
+ }
+
+ if (llama_vocab_get_add_bos(vocab)) {
+ result.push_back(llama_vocab_bos(vocab));
+ }
+ result.push_back(query_tokens);
+ if (llama_vocab_get_add_eos(vocab)) {
+ result.push_back(eos_token);
+ }
+ if (llama_vocab_get_add_sep(vocab)) {
+ result.push_back(llama_vocab_sep(vocab));
+ }
+ result.push_back(doc_tokens);
+ if (llama_vocab_get_add_eos(vocab)) {
+ result.push_back(eos_token);
+ }
+ }
+
+ return result;
+}
--- /dev/null
+#pragma once
+
+#include "common.h"
+#include "log.h"
+#include "llama.h"
+#include "chat.h"
+#include "mtmd.h"
+
+#define JSON_ASSERT GGML_ASSERT
+#include <nlohmann/json.hpp>
+
+#include <string>
+#include <vector>
+#include <cinttypes>
+
+#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo"
+
+const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
+
+using json = nlohmann::ordered_json;
+
+#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+
+#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+
+using raw_buffer = std::vector<uint8_t>;
+
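+// get a json value by key, falling back to default_value when the key is missing,
+// null, or of the wrong type
+// example (illustrative): with body = {{"n_predict", 128}},
+//   json_value(body, "n_predict", 64) -> 128; json_value(body, "temperature", 0.8f) -> 0.8f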
+template <typename T>
+static T json_value(const json & body, const std::string & key, const T & default_value) {
+ // Fallback null to default value
+ if (body.contains(key) && !body.at(key).is_null()) {
+ try {
+ return body.at(key);
+ } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const & err) {
+ LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value: %s\n", key.c_str(), json(default_value).type_name(), err.what());
+ return default_value;
+ }
+ } else {
+ return default_value;
+ }
+}
+
+// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
+enum error_type {
+ ERROR_TYPE_INVALID_REQUEST,
+ ERROR_TYPE_AUTHENTICATION,
+ ERROR_TYPE_SERVER,
+ ERROR_TYPE_NOT_FOUND,
+ ERROR_TYPE_PERMISSION,
+ ERROR_TYPE_UNAVAILABLE, // custom error
+ ERROR_TYPE_NOT_SUPPORTED, // custom error
+ ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error
+};
+
+// thin wrapper around common_grammar_trigger with (de)serialization functions
+struct server_grammar_trigger {
+ common_grammar_trigger value;
+
+ server_grammar_trigger() = default;
+ server_grammar_trigger(const common_grammar_trigger & value) : value(value) {}
+ server_grammar_trigger(const json & in) {
+ value.type = (common_grammar_trigger_type) in.at("type").get<int>();
+ value.value = in.at("value").get<std::string>();
+ if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
+ value.token = (llama_token) in.at("token").get<int>();
+ }
+ }
+
+ json to_json() const {
+ json out {
+ {"type", (int) value.type},
+ {"value", value.value},
+ };
+ if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
+ out["token"] = (int) value.token;
+ }
+ return out;
+ }
+};
+
+json format_error_response(const std::string & message, const enum error_type type);
+
+//
+// random string / id
+//
+
+std::string random_string();
+std::string gen_chatcmplid();
+std::string gen_tool_call_id();
+
+//
+// lora utils
+//
+
+// check whether the given lora set has only aloras activated (empty => false)
+bool lora_all_alora(const std::vector<common_adapter_lora_info> & loras);
+
+// if the two sets of loras are different, they require a cache clear, unless the
+// change is only from an empty or alora-only set to an alora-only set.
+bool lora_should_clear_cache(
+ const std::vector<common_adapter_lora_info> & current,
+ const std::vector<common_adapter_lora_info> & next);
+
+std::vector<common_adapter_lora_info> parse_lora_request(
+ const std::vector<common_adapter_lora_info> & lora_base,
+ const json & data);
+
+bool are_lora_equal(
+ const std::vector<common_adapter_lora_info> & l1,
+ const std::vector<common_adapter_lora_info> & l2);
+
+// get the ids of all enabled loras
+std::vector<size_t> lora_get_enabled_ids(const std::vector<common_adapter_lora_info> & loras);
+
+//
+// server_tokens
+//
+
+/**
+ * server_tokens is a helper to manage the input tokens and image for the server.
+ * it is made this way to simplify the logic of KV cache management.
+ */
+struct server_tokens {
+ bool has_mtmd = false;
+
+private: // disallow accessing these members directly, to avoid them getting out of sync
+
+ // map a **start** index in tokens to its media chunk (image or audio)
+ // note: the order needs to be in sync with tokens
+ std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;
+
+ // list of tokens
+ // if the token is LLAMA_TOKEN_NULL, this position is occupied by a media chunk
+ // otherwise, it is a normal text token
+ // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
+ // note(2): for M-RoPE, an image can occupy a different number of pos; do not assume a 1-to-1 mapping between tokens and pos
+ llama_tokens tokens;
+
+ // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
+ // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
+ // idx 0 1 2 3 4 5 6 7 8 9 10
+ // pos 0 1 2 3 4 5 5 5 7 7 7
+ // map_idx_to_media will contain: {5, img0}, {8, img1}
+
+public:
+ server_tokens() = default;
+ ~server_tokens() = default;
+
+ // Prevent copying
+ // TODO: server_tokens should be copyable - remove this:
+ server_tokens(const server_tokens&) = delete;
+ server_tokens& operator=(const server_tokens&) = delete;
+
+ // Allow moving (usually implicitly generated if members are movable)
+ server_tokens(server_tokens&&) = default;
+ server_tokens& operator=(server_tokens&&) = default;
+
+ // Allow accessing elements using [] operator
+ llama_token operator[](size_t index) { return tokens[index]; }
+ const llama_token& operator[](size_t index) const { return tokens[index]; }
+
+ server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd);
+ server_tokens(const llama_tokens & tokens, bool has_mtmd);
+
+ // for debugging
+ std::string str() const;
+
+ llama_pos pos_next() const;
+ const mtmd::input_chunk_ptr & find_chunk(size_t idx) const;
+
+ void push_back(llama_token tok);
+
+ // will create a copy of the chunk if it contains non-text data
+ void push_back(const mtmd_input_chunk * chunk);
+
+ // appends server tokens, updates the media map. copies media chunks.
+ void push_back(server_tokens & tokens);
+
+ // for compatibility with context shift and prompt truncation
+ void insert(const llama_tokens & inp_tokens);
+
+ // for compatibility with speculative decoding, ctx shift, slot save/load
+ const llama_tokens & get_text_tokens() const;
+
+ // for compatibility with speculative decoding
+ void set_token(llama_pos pos, llama_token id);
+
+ size_t size() const { return tokens.size(); }
+
+ bool empty() const { return tokens.empty(); }
+
+ void clear() {
+ map_idx_to_media.clear();
+ tokens.clear();
+ }
+
+ void keep_first(size_t n);
+
+ std::string detokenize(const llama_context * ctx, bool special) const;
+
+ size_t get_common_prefix(const server_tokens & b) const;
+
+ // make sure all text tokens are within the vocab range
+ bool validate(const struct llama_context * ctx) const;
+
+ // encode and decode the image chunk
+ int32_t process_chunk(
+ llama_context * ctx,
+ mtmd_context * mctx,
+ size_t idx,
+ llama_pos pos,
+ int32_t seq_id,
+ size_t & n_tokens_out) const;
+};
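+
+// usage sketch (illustrative, not a prescribed flow):
+//   server_tokens prompt(tokenized_text, /*has_mtmd=*/false);
+//   size_t n_keep = cache_tokens.get_common_prefix(prompt);
+//   cache_tokens.keep_first(n_keep); // reuse the cached prefix, drop the rest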
+
+
+//
+// tokenizer and input processing utils
+//
+
+bool json_is_array_of_numbers(const json & data);
+
+// does the array contain BOTH numbers & strings?
+bool json_is_array_of_mixed_numbers_strings(const json & data);
+
+// does the array contain any individual integers/tokens?
+bool json_is_array_and_contains_numbers(const json & data);
+
+// get nested values by slash-separated paths (e.g. "key1/key2")
+json json_get_nested_values(const std::vector<std::string> & paths, const json & js);
+
+/**
+ * this handles 2 cases:
+ * - only string, example: "string"
+ * - mixed string and tokens, example: [12, 34, "string", 56, 78]
+ */
+llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special);
+
+// return the last index of character that can form a valid string
+// if the last character is potentially cut in half, return the index before the cut
+// if validate_utf8(text) == text.size(), then the whole text is valid utf8
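+// example (illustrative): for "hi" followed by the single byte 0xC3 (start of a
+// 2-byte sequence), this returns 2, excluding the incomplete character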
+size_t validate_utf8(const std::string& text);
+
+// process mtmd prompt, return the server_tokens containing both text tokens and media chunks
+server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files);
+
+/**
+ * break the input "prompt" object into multiple prompts if needed, then tokenize them
+ * this supports these cases:
+ * - "prompt": "string"
+ * - "prompt": [12, 34, 56]
+ * - "prompt": [12, 34, "string", 56, 78]
+ * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
+ * and multiple prompts (multi-tasks):
+ * - "prompt": ["string1", "string2"]
+ * - "prompt": ["string1", [12, 34, 56]]
+ * - "prompt": [[12, 34, 56], [78, 90, 12]]
+ * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}]
+ */
+std::vector<server_tokens> tokenize_input_prompts(
+ const llama_vocab * vocab,
+ mtmd_context * mctx,
+ const json & json_prompt,
+ bool add_special,
+ bool parse_special);
+
+//
+// OAI utils
+//
+
+// used by /completions endpoint
+json oaicompat_completion_params_parse(const json & body);
+
+struct oaicompat_parser_options {
+ bool use_jinja;
+ bool prefill_assistant;
+ common_reasoning_format reasoning_format;
+ std::map<std::string,std::string> chat_template_kwargs;
+ common_chat_templates * tmpls;
+ bool allow_image;
+ bool allow_audio;
+ bool enable_thinking = true;
+};
+
+// used by /chat/completions endpoint
+json oaicompat_chat_params_parse(
+ json & body, /* openai api json semantics */
+ const oaicompat_parser_options & opt,
+ std::vector<raw_buffer> & out_files);
+
+// TODO: move it to server-task.cpp
+json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false);
+
+// TODO: move it to server-task.cpp
+json format_response_rerank(
+ const json & request,
+ const json & ranks,
+ bool is_tei_format,
+ std::vector<std::string> & texts,
+ int top_n);
+
+//
+// other utils
+//
+
+std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx);
+
+std::string safe_json_to_str(const json & data);
+
+std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens);
+
+// format incomplete utf-8 multibyte character for output
+std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token);
+
+// format server-sent event (SSE), return the formatted string to send
+// note: if data is a json array, it will be sent as multiple events, one per item
+std::string format_sse(const json & data);
+
+bool is_valid_utf8(const std::string & str);
+
+//
+// formatting output responses
+// TODO: move these to server-task.cpp
+//
+
+llama_tokens format_prompt_infill(
+ const llama_vocab * vocab,
+ const json & input_prefix,
+ const json & input_suffix,
+ const json & input_extra,
+ const int n_batch,
+ const int n_predict,
+ const int n_ctx,
+ const bool spm_infill,
+ const llama_tokens & tokens_prompt);
+
+// format rerank task: [BOS]query[EOS][SEP]doc[EOS].
+server_tokens format_prompt_rerank(
+ const struct llama_model * model,
+ const struct llama_vocab * vocab,
+ mtmd_context * mctx,
+ const std::string & query,
+ const std::string & doc);
-#include "utils.hpp"
#include "common.h"
#include "server-http.h"
+#include "server-common.h"
#include <cpp-httplib/httplib.h>
--- /dev/null
+#include "server-task.h"
+#include "server-queue.h"
+
+#include "log.h"
+
+#include <chrono>
+
+#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+
+#define RES_INF(fmt, ...) LOG_INF("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define RES_WRN(fmt, ...) LOG_WRN("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define RES_ERR(fmt, ...) LOG_ERR("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define RES_DBG(fmt, ...) LOG_DBG("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+
+//
+// server_queue
+//
+
+int server_queue::post(server_task && task, bool front) {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ GGML_ASSERT(task.id != -1);
+ // if this is a cancel task, make sure to clean up pending tasks
+ if (task.type == SERVER_TASK_TYPE_CANCEL) {
+ cleanup_pending_task(task.id_target);
+ }
+ const int task_id = task.id;
+ QUE_DBG("new task, id = %d, front = %d\n", task_id, front);
+ if (front) {
+ queue_tasks.push_front(std::move(task));
+ } else {
+ queue_tasks.push_back(std::move(task));
+ }
+ condition_tasks.notify_one();
+ return task_id;
+}
+
+int server_queue::post(std::vector<server_task> && tasks, bool front) {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ for (auto & task : tasks) {
+ if (task.id == -1) {
+ task.id = id++;
+ }
+ // if this is a cancel task, make sure to clean up pending tasks
+ if (task.type == SERVER_TASK_TYPE_CANCEL) {
+ cleanup_pending_task(task.id_target);
+ }
+ QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
+ if (front) {
+ queue_tasks.push_front(std::move(task));
+ } else {
+ queue_tasks.push_back(std::move(task));
+ }
+ }
+ condition_tasks.notify_one();
+ return 0;
+}
+
+void server_queue::defer(server_task && task) {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ QUE_DBG("defer task, id = %d\n", task.id);
+ queue_tasks_deferred.push_back(std::move(task));
+ condition_tasks.notify_one();
+}
+
+int server_queue::get_new_id() {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ int new_id = id++;
+ return new_id;
+}
+
+void server_queue::on_new_task(std::function<void(server_task &&)> callback) {
+ callback_new_task = std::move(callback);
+}
+
+void server_queue::on_update_slots(std::function<void(void)> callback) {
+ callback_update_slots = std::move(callback);
+}
+
+void server_queue::pop_deferred_task() {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ if (!queue_tasks_deferred.empty()) {
+ queue_tasks.emplace_front(std::move(queue_tasks_deferred.front()));
+ queue_tasks_deferred.pop_front();
+ }
+ condition_tasks.notify_one();
+}
+
+void server_queue::terminate() {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ running = false;
+ condition_tasks.notify_all();
+}
+
+void server_queue::start_loop() {
+ running = true;
+
+ while (true) {
+ QUE_DBG("%s", "processing new tasks\n");
+
+ while (true) {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ if (!running) {
+ QUE_DBG("%s", "terminate\n");
+ return;
+ }
+ if (queue_tasks.empty()) {
+ lock.unlock();
+ break;
+ }
+ server_task task = std::move(queue_tasks.front());
+ queue_tasks.pop_front();
+ lock.unlock();
+
+ QUE_DBG("processing task, id = %d\n", task.id);
+ callback_new_task(std::move(task));
+ }
+
+ // all tasks in the current loop are processed, slot data is now ready
+ QUE_DBG("%s", "update slots\n");
+
+ callback_update_slots();
+
+ QUE_DBG("%s", "waiting for new tasks\n");
+ {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ if (!running) {
+ QUE_DBG("%s", "terminate\n");
+ return;
+ }
+ if (queue_tasks.empty()) {
+ condition_tasks.wait(lock, [&]{
+ return (!queue_tasks.empty() || !running);
+ });
+ }
+ }
+ }
+}
+
+void server_queue::cleanup_pending_task(int id_target) {
+ // no need to lock because this is called exclusively by post()
+ auto rm_func = [id_target](const server_task & task) {
+ return task.id == id_target;
+ };
+ queue_tasks.erase(
+ std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
+ queue_tasks.end());
+ queue_tasks_deferred.erase(
+ std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
+ queue_tasks_deferred.end());
+}
+
+//
+// server_response
+//
+
+void server_response::add_waiting_task_id(int id_task) {
+ RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
+
+ std::unique_lock<std::mutex> lock(mutex_results);
+ waiting_task_ids.insert(id_task);
+}
+
+void server_response::add_waiting_tasks(const std::vector<server_task> & tasks) {
+ std::unique_lock<std::mutex> lock(mutex_results);
+
+ for (const auto & task : tasks) {
+ RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size());
+ waiting_task_ids.insert(task.id);
+ }
+}
+
+void server_response::remove_waiting_task_id(int id_task) {
+ RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
+
+ std::unique_lock<std::mutex> lock(mutex_results);
+ waiting_task_ids.erase(id_task);
+ // make sure to clean up all pending results
+ queue_results.erase(
+ std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) {
+ return res->id == id_task;
+ }),
+ queue_results.end());
+}
+
+void server_response::remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
+ std::unique_lock<std::mutex> lock(mutex_results);
+
+ for (const auto & id_task : id_tasks) {
+ RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
+ waiting_task_ids.erase(id_task);
+ }
+}
+
+server_task_result_ptr server_response::recv(const std::unordered_set<int> & id_tasks) {
+ while (true) {
+ std::unique_lock<std::mutex> lock(mutex_results);
+ condition_results.wait(lock, [&]{
+ if (!running) {
+ RES_DBG("%s : queue result stop\n", __func__);
+ std::terminate(); // we cannot return here since the caller is HTTP code
+ }
+ return !queue_results.empty();
+ });
+
+ for (size_t i = 0; i < queue_results.size(); i++) {
+ if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
+ server_task_result_ptr res = std::move(queue_results[i]);
+ queue_results.erase(queue_results.begin() + i);
+ return res;
+ }
+ }
+ }
+
+ // should never reach here
+}
+
+server_task_result_ptr server_response::recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
+ while (true) {
+ std::unique_lock<std::mutex> lock(mutex_results);
+
+ for (int i = 0; i < (int) queue_results.size(); i++) {
+ if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
+ server_task_result_ptr res = std::move(queue_results[i]);
+ queue_results.erase(queue_results.begin() + i);
+ return res;
+ }
+ }
+
+ std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
+ if (!running) {
+ RES_DBG("%s : queue result stop\n", __func__);
+ std::terminate(); // we cannot return here since the caller is HTTP code
+ }
+ if (cr_res == std::cv_status::timeout) {
+ return nullptr;
+ }
+ }
+
+ // should never reach here
+}
+
+server_task_result_ptr server_response::recv(int id_task) {
+ std::unordered_set<int> id_tasks = {id_task};
+ return recv(id_tasks);
+}
+
+void server_response::send(server_task_result_ptr && result) {
+ RES_DBG("sending result for task id = %d\n", result->id);
+
+ std::unique_lock<std::mutex> lock(mutex_results);
+ for (const auto & id_task : waiting_task_ids) {
+ if (result->id == id_task) {
+ RES_DBG("task id = %d pushed to result queue\n", result->id);
+
+ queue_results.emplace_back(std::move(result));
+ condition_results.notify_all();
+ return;
+ }
+ }
+}
+
+void server_response::terminate() {
+ running = false;
+ condition_results.notify_all();
+}
--- /dev/null
+#pragma once
+
+#include "server-task.h"
+
+#include <condition_variable>
+#include <deque>
+#include <mutex>
+#include <unordered_set>
+
+struct server_queue {
+private:
+ int id = 0;
+ bool running = false; // set by start_loop(), cleared by terminate()
+
+ // queues
+ std::deque<server_task> queue_tasks;
+ std::deque<server_task> queue_tasks_deferred;
+
+ std::mutex mutex_tasks;
+ std::condition_variable condition_tasks;
+
+ // callback functions
+ std::function<void(server_task &&)> callback_new_task;
+ std::function<void(void)> callback_update_slots;
+
+public:
+ // Add a new task to the end of the queue
+ int post(server_task && task, bool front = false);
+
+ // multi-task version of post()
+ int post(std::vector<server_task> && tasks, bool front = false);
+
+ // Add a new task, but defer until one slot is available
+ void defer(server_task && task);
+
+ // Get the next id for creating a new task
+ int get_new_id();
+
+ // Register function to process a new task
+ void on_new_task(std::function<void(server_task &&)> callback);
+
+ // Register the function to be called when all slot data is ready to be processed
+ void on_update_slots(std::function<void(void)> callback);
+
+ // Call when the state of one slot changes; it will move one task from the deferred queue to the main queue
+ void pop_deferred_task();
+
+ // end the start_loop routine
+ void terminate();
+
+ /**
+ * Main loop consists of these steps:
+ * - Wait until a new task arrives
+ * - Process the task (i.e. maybe copy data into slot)
+ * - Check if multitask is finished
+ * - Update all slots
+ */
+ void start_loop();
+
+ // for metrics
+ size_t queue_tasks_deferred_size() {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ return queue_tasks_deferred.size();
+ }
+
+private:
+ void cleanup_pending_task(int id_target);
+};
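+
+// Minimal wiring sketch (illustrative only; the slot handling inside the callbacks is application code):
+//
+//   server_queue queue;
+//
+//   queue.on_new_task([&](server_task && task) {
+//       // e.g. assign the task to a free slot, or queue.defer(std::move(task)) if none is available
+//   });
+//   queue.on_update_slots([&]() {
+//       // e.g. run one decode step over all active slots
+//   });
+//
+//   std::thread worker([&]() { queue.start_loop(); }); // blocks until terminate()
+//
+//   // from an HTTP thread:
+//   server_task task(SERVER_TASK_TYPE_COMPLETION);
+//   task.id = queue.get_new_id();
+//   queue.post(std::move(task));
+//
+//   // on shutdown:
+//   queue.terminate();
+//   worker.join();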
+
+struct server_response {
+private:
+ bool running = true;
+
+ // for keeping track of all tasks waiting for the result
+ std::unordered_set<int> waiting_task_ids;
+
+ // the main result queue (using ptr for polymorphism)
+ std::vector<server_task_result_ptr> queue_results;
+
+ std::mutex mutex_results;
+ std::condition_variable condition_results;
+
+public:
+ // add the id_task to the list of tasks waiting for response
+ void add_waiting_task_id(int id_task);
+
+ void add_waiting_tasks(const std::vector<server_task> & tasks);
+
+ // when the request is finished, we can remove the task associated with it
+ void remove_waiting_task_id(int id_task);
+
+ // remove multiple tasks from waiting list
+ void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks);
+
+ // This function blocks the thread until there is a response for one of the id_tasks
+ server_task_result_ptr recv(const std::unordered_set<int> & id_tasks);
+
+ // same as recv(), but with a timeout in seconds
+ // if timeout is reached, nullptr is returned
+ server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout);
+
+ // single-task version of recv()
+ server_task_result_ptr recv(int id_task);
+
+ // Send a new result to a waiting id_task
+ void send(server_task_result_ptr && result);
+
+ // terminate the waiting loop
+ void terminate();
+};
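+
+// Typical request flow sketch (illustrative; assumes `queue` and `response` are shared with the worker
+// thread, and that HTTP_POLLING_SECONDS is the polling interval used to detect client disconnects):
+//
+//   server_task task(SERVER_TASK_TYPE_COMPLETION);
+//   task.id = queue.get_new_id();
+//   const int id_task = task.id;
+//
+//   response.add_waiting_task_id(id_task);
+//   queue.post(std::move(task));
+//
+//   server_task_result_ptr res = response.recv_with_timeout({id_task}, HTTP_POLLING_SECONDS);
+//   // res == nullptr means the timeout expired -- check the connection and poll again
+//
+//   response.remove_waiting_task_id(id_task);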
--- /dev/null
+#include "server-common.h"
+#include "server-task.h"
+
+#include "common.h"
+#include "llama.h"
+#include "chat.h"
+#include "sampling.h"
+#include "json-schema-to-grammar.h"
+
+using json = nlohmann::ordered_json;
+
+//
+// task_params
+//
+
+json task_params::format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const {
+ json data = json::array();
+ for (const auto & lb : logit_bias) {
+ data.push_back(json{
+ {"bias", lb.bias},
+ {"token", lb.token},
+ });
+ }
+ return data;
+}
+
+json task_params::to_json(bool only_metrics) const {
+ std::vector<std::string> samplers;
+ samplers.reserve(sampling.samplers.size());
+ for (const auto & sampler : sampling.samplers) {
+ samplers.emplace_back(common_sampler_type_to_str(sampler));
+ }
+
+ json lora = json::array();
+ for (size_t i = 0; i < this->lora.size(); ++i) {
+ lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
+ }
+
+ if (only_metrics) {
+ return json {
+ {"seed", sampling.seed},
+ {"temperature", sampling.temp},
+ {"dynatemp_range", sampling.dynatemp_range},
+ {"dynatemp_exponent", sampling.dynatemp_exponent},
+ {"top_k", sampling.top_k},
+ {"top_p", sampling.top_p},
+ {"min_p", sampling.min_p},
+ {"top_n_sigma", sampling.top_n_sigma},
+ {"xtc_probability", sampling.xtc_probability},
+ {"xtc_threshold", sampling.xtc_threshold},
+ {"typical_p", sampling.typ_p},
+ {"repeat_last_n", sampling.penalty_last_n},
+ {"repeat_penalty", sampling.penalty_repeat},
+ {"presence_penalty", sampling.penalty_present},
+ {"frequency_penalty", sampling.penalty_freq},
+ {"dry_multiplier", sampling.dry_multiplier},
+ {"dry_base", sampling.dry_base},
+ {"dry_allowed_length", sampling.dry_allowed_length},
+ {"dry_penalty_last_n", sampling.dry_penalty_last_n},
+ {"mirostat", sampling.mirostat},
+ {"mirostat_tau", sampling.mirostat_tau},
+ {"mirostat_eta", sampling.mirostat_eta},
+ {"max_tokens", n_predict},
+ {"n_predict", n_predict}, // TODO: deduplicate?
+ {"n_keep", n_keep},
+ {"n_discard", n_discard},
+ {"ignore_eos", sampling.ignore_eos},
+ {"stream", stream},
+ {"n_probs", sampling.n_probs},
+ {"min_keep", sampling.min_keep},
+ {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)},
+ {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)},
+ {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content},
+ {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open},
+ {"samplers", samplers},
+ {"speculative.n_max", speculative.n_max},
+ {"speculative.n_min", speculative.n_min},
+ {"speculative.p_min", speculative.p_min},
+ {"timings_per_token", timings_per_token},
+ {"post_sampling_probs", post_sampling_probs},
+ {"lora", lora},
+ };
+ }
+
+ auto grammar_triggers = json::array();
+ for (const auto & trigger : sampling.grammar_triggers) {
+ server_grammar_trigger ct(trigger);
+ grammar_triggers.push_back(ct.to_json());
+ }
+
+ return json {
+ {"seed", sampling.seed},
+ {"temperature", sampling.temp},
+ {"dynatemp_range", sampling.dynatemp_range},
+ {"dynatemp_exponent", sampling.dynatemp_exponent},
+ {"top_k", sampling.top_k},
+ {"top_p", sampling.top_p},
+ {"min_p", sampling.min_p},
+ {"top_n_sigma", sampling.top_n_sigma},
+ {"xtc_probability", sampling.xtc_probability},
+ {"xtc_threshold", sampling.xtc_threshold},
+ {"typical_p", sampling.typ_p},
+ {"repeat_last_n", sampling.penalty_last_n},
+ {"repeat_penalty", sampling.penalty_repeat},
+ {"presence_penalty", sampling.penalty_present},
+ {"frequency_penalty", sampling.penalty_freq},
+ {"dry_multiplier", sampling.dry_multiplier},
+ {"dry_base", sampling.dry_base},
+ {"dry_allowed_length", sampling.dry_allowed_length},
+ {"dry_penalty_last_n", sampling.dry_penalty_last_n},
+ {"dry_sequence_breakers", sampling.dry_sequence_breakers},
+ {"mirostat", sampling.mirostat},
+ {"mirostat_tau", sampling.mirostat_tau},
+ {"mirostat_eta", sampling.mirostat_eta},
+ {"stop", antiprompt},
+ {"max_tokens", n_predict},
+ {"n_predict", n_predict}, // TODO: deduplicate?
+ {"n_keep", n_keep},
+ {"n_discard", n_discard},
+ {"ignore_eos", sampling.ignore_eos},
+ {"stream", stream},
+ {"logit_bias", format_logit_bias(sampling.logit_bias)},
+ {"n_probs", sampling.n_probs},
+ {"min_keep", sampling.min_keep},
+ {"grammar", sampling.grammar},
+ {"grammar_lazy", sampling.grammar_lazy},
+ {"grammar_triggers", grammar_triggers},
+ {"preserved_tokens", sampling.preserved_tokens},
+ {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)},
+ {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)},
+ {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content},
+ {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open},
+ {"samplers", samplers},
+ {"speculative.n_max", speculative.n_max},
+ {"speculative.n_min", speculative.n_min},
+ {"speculative.p_min", speculative.p_min},
+ {"timings_per_token", timings_per_token},
+ {"post_sampling_probs", post_sampling_probs},
+ {"lora", lora},
+ };
+}
+
+//
+// server_task
+//
+
+task_params server_task::params_from_json_cmpl(
+ const llama_context * ctx,
+ const common_params & params_base,
+ const json & data) {
+ const llama_model * model = llama_get_model(ctx);
+ const llama_vocab * vocab = llama_model_get_vocab(model);
+
+ task_params params;
+
+ // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
+ task_params defaults;
+ defaults.sampling = params_base.sampling;
+ defaults.speculative = params_base.speculative;
+ defaults.n_keep = params_base.n_keep;
+ defaults.n_predict = params_base.n_predict;
+ defaults.antiprompt = params_base.antiprompt;
+
+ // enabling this will output extra debug information in the HTTP responses from the server
+ params.verbose = params_base.verbosity > 9;
+ params.timings_per_token = json_value(data, "timings_per_token", false);
+
+ params.stream = json_value(data, "stream", false);
+ auto stream_opt = json_value(data, "stream_options", json::object());
+ params.include_usage = json_value(stream_opt, "include_usage", false);
+ params.cache_prompt = json_value(data, "cache_prompt", true);
+ params.return_tokens = json_value(data, "return_tokens", false);
+ params.return_progress = json_value(data, "return_progress", false);
+ params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
+ params.n_indent = json_value(data, "n_indent", defaults.n_indent);
+ params.n_keep = json_value(data, "n_keep", defaults.n_keep);
+ params.n_discard = json_value(data, "n_discard", defaults.n_discard);
+ //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
+ params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
+ params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
+
+ params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
+ params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
+ params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
+ params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma);
+ params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
+ params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
+ params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
+ params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
+ params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
+ params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
+ params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
+ params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
+ params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
+ params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
+ params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
+ params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
+ params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
+ params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
+ params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
+ params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
+ params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
+ params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
+ params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
+ params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
+ params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
+
+ params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
+ params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
+ params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
+
+ params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
+ params.speculative.n_min = std::max(params.speculative.n_min, 0);
+ params.speculative.n_max = std::max(params.speculative.n_max, 0);
+
+ // Use OpenAI API logprobs only if n_probs wasn't provided
+ if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
+ params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
+ }
+
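+ // the "lora" field, when present, is expected to look like (illustrative):
+ //   "lora": [{"id": 0, "scale": 0.5}, {"id": 1, "scale": 0.0}]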
+ if (data.contains("lora")) {
+ if (data.at("lora").is_array()) {
+ params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
+ } else {
+ throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
+ }
+ } else {
+ params.lora = params_base.lora_adapters;
+ }
+
+ // TODO: add more sanity checks for the input parameters
+
+ if (params.sampling.penalty_last_n < -1) {
+ throw std::runtime_error("Error: repeat_last_n must be >= -1");
+ }
+
+ if (params.sampling.dry_penalty_last_n < -1) {
+ throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
+ }
+
+ if (params.sampling.penalty_last_n == -1) {
+ // note: should be the slot's context and not the full context, but it's ok
+ params.sampling.penalty_last_n = llama_n_ctx(ctx);
+ }
+
+ if (params.sampling.dry_penalty_last_n == -1) {
+ params.sampling.dry_penalty_last_n = llama_n_ctx(ctx);
+ }
+
+ if (params.sampling.dry_base < 1.0f) {
+ params.sampling.dry_base = defaults.sampling.dry_base;
+ }
+
+ // sequence breakers for DRY
+ {
+ // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
+ // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
+
+ if (data.contains("dry_sequence_breakers")) {
+ params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
+ if (params.sampling.dry_sequence_breakers.empty()) {
+ throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings");
+ }
+ }
+ }
+
+ // process "json_schema" and "grammar"
+ if (data.contains("json_schema") && !data.contains("grammar")) {
+ try {
+ auto schema = json_value(data, "json_schema", json::object());
+ SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
+ params.sampling.grammar = json_schema_to_grammar(schema);
+ SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
+ } catch (const std::exception & e) {
+ throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
+ }
+ } else {
+ params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
+ SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
+ params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
+ SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
+ }
+
+ {
+ auto it = data.find("chat_format");
+ if (it != data.end()) {
+ params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>());
+ SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format));
+ } else {
+ params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
+ }
+ common_reasoning_format reasoning_format = params_base.reasoning_format;
+ if (data.contains("reasoning_format")) {
+ reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get<std::string>());
+ }
+ params.oaicompat_chat_syntax.reasoning_format = reasoning_format;
+ params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
+ params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
+ params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
+ }
+
+ {
+ const auto preserved_tokens = data.find("preserved_tokens");
+ if (preserved_tokens != data.end()) {
+ for (const auto & t : *preserved_tokens) {
+ auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
+ if (ids.size() == 1) {
+ SRV_DBG("Preserved token: %d\n", ids[0]);
+ params.sampling.preserved_tokens.insert(ids[0]);
+ } else {
+ // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
+ SRV_DBG("Not preserved because more than 1 token: %s\n", t.get<std::string>().c_str());
+ }
+ }
+ }
+ const auto grammar_triggers = data.find("grammar_triggers");
+ if (grammar_triggers != data.end()) {
+ for (const auto & t : *grammar_triggers) {
+ server_grammar_trigger ct(t);
+ if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
+ const auto & word = ct.value.value;
+ auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
+ if (ids.size() == 1) {
+ auto token = ids[0];
+ if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) {
+ throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
+ }
+ SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
+ common_grammar_trigger trigger;
+ trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
+ trigger.value = word;
+ trigger.token = token;
+ params.sampling.grammar_triggers.push_back(std::move(trigger));
+ } else {
+ SRV_DBG("Grammar trigger word: `%s`\n", word.c_str());
+ params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
+ }
+ } else {
+ if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
+ SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
+ } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
+ SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
+ } else {
+ throw std::runtime_error("Unknown grammar trigger type");
+ }
+ params.sampling.grammar_triggers.emplace_back(std::move(ct.value));
+ }
+ }
+ }
+ if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) {
+ throw std::runtime_error("Error: no triggers set for lazy grammar!");
+ }
+ }
+
+ {
+ params.sampling.logit_bias.clear();
+
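+ // accepted request forms (illustrative values):
+ //   "logit_bias": [[15043, -1.0], ["Hello", 2.0], [15043, false]]  // [token-id-or-string, bias]; false means never sample
+ //   "logit_bias": {"15043": -1.0, "Hello": false}                  // OAI-style object keyed by token id or string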
+ const auto & logit_bias = data.find("logit_bias");
+ if (logit_bias != data.end() && logit_bias->is_array()) {
+ const int n_vocab = llama_vocab_n_tokens(vocab);
+ for (const auto & el : *logit_bias) {
+ // TODO: we may want to throw errors here, in case "el" is incorrect
+ if (el.is_array() && el.size() == 2) {
+ float bias;
+ if (el[1].is_number()) {
+ bias = el[1].get<float>();
+ } else if (el[1].is_boolean() && !el[1].get<bool>()) {
+ bias = -INFINITY;
+ } else {
+ continue;
+ }
+
+ if (el[0].is_number_integer()) {
+ llama_token tok = el[0].get<llama_token>();
+ if (tok >= 0 && tok < n_vocab) {
+ params.sampling.logit_bias.push_back({tok, bias});
+ }
+ } else if (el[0].is_string()) {
+ auto toks = common_tokenize(vocab, el[0].get<std::string>(), false);
+ for (auto tok : toks) {
+ params.sampling.logit_bias.push_back({tok, bias});
+ }
+ }
+ }
+ }
+ } else if (logit_bias != data.end() && logit_bias->is_object()) {
+ const int n_vocab = llama_vocab_n_tokens(vocab);
+ for (const auto & el : logit_bias->items()) {
+ float bias;
+ const auto & key = el.key();
+ const auto & value = el.value();
+ if (value.is_number()) {
+ bias = value.get<float>();
+ } else if (value.is_boolean() && !value.get<bool>()) {
+ bias = -INFINITY;
+ } else {
+ continue;
+ }
+
+ char *end;
+ llama_token tok = strtol(key.c_str(), &end, 10);
+ if (*end == 0) {
+ if (tok >= 0 && tok < n_vocab) {
+ params.sampling.logit_bias.push_back({tok, bias});
+ }
+ } else {
+ auto toks = common_tokenize(vocab, key, false);
+ for (auto tok : toks) {
+ params.sampling.logit_bias.push_back({tok, bias});
+ }
+ }
+ }
+ }
+
+ params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
+ if (params.sampling.ignore_eos) {
+ params.sampling.logit_bias.insert(
+ params.sampling.logit_bias.end(),
+ defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end());
+ }
+ }
+
+ {
+ params.antiprompt.clear();
+
+ const auto & stop = data.find("stop");
+ if (stop != data.end() && stop->is_array()) {
+ for (const auto & word : *stop) {
+ if (!word.empty()) {
+ params.antiprompt.push_back(word);
+ }
+ }
+ }
+ // set reverse prompt from cli args if not set in the request
+ if (params.antiprompt.empty()) {
+ params.antiprompt = defaults.antiprompt;
+ }
+ }
+
+ {
+ const auto samplers = data.find("samplers");
+ if (samplers != data.end()) {
+ if (samplers->is_array()) {
+ params.sampling.samplers = common_sampler_types_from_names(*samplers, false);
+ } else if (samplers->is_string()){
+ params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
+ }
+ } else {
+ params.sampling.samplers = defaults.sampling.samplers;
+ }
+ }
+
+ std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
+ params.oaicompat_model = json_value(data, "model", model_name);
+
+ return params;
+}
+
+//
+// result_timings
+//
+
+json result_timings::to_json() const {
+ json base = {
+ {"cache_n", cache_n},
+
+ {"prompt_n", prompt_n},
+ {"prompt_ms", prompt_ms},
+ {"prompt_per_token_ms", prompt_per_token_ms},
+ {"prompt_per_second", prompt_per_second},
+
+ {"predicted_n", predicted_n},
+ {"predicted_ms", predicted_ms},
+ {"predicted_per_token_ms", predicted_per_token_ms},
+ {"predicted_per_second", predicted_per_second},
+ };
+
+ if (draft_n > 0) {
+ base["draft_n"] = draft_n;
+ base["draft_n_accepted"] = draft_n_accepted;
+ }
+
+ return base;
+}
+
+//
+// result_prompt_progress
+//
+json result_prompt_progress::to_json() const {
+ return json {
+ {"total", total},
+ {"cache", cache},
+ {"processed", processed},
+ {"time_ms", time_ms},
+ };
+}
+
+static inline std::string stop_type_to_str(stop_type type) {
+ switch (type) {
+ case STOP_TYPE_EOS: return "eos";
+ case STOP_TYPE_WORD: return "word";
+ case STOP_TYPE_LIMIT: return "limit";
+ default: return "none";
+ }
+}
+
+//
+// completion_token_output
+//
+
+json completion_token_output::to_json(bool post_sampling_probs) const {
+ json probs_for_token = json::array();
+ for (const auto & p : probs) {
+ std::string txt(p.txt);
+ txt.resize(validate_utf8(txt));
+ probs_for_token.push_back(json {
+ {"id", p.tok},
+ {"token", txt},
+ {"bytes", str_to_bytes(p.txt)},
+ {
+ post_sampling_probs ? "prob" : "logprob",
+ post_sampling_probs ? p.prob : logarithm(p.prob)
+ },
+ });
+ }
+ return probs_for_token;
+}
+
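+// shape of the produced array (mirrors the fields built below; "prob"/"top_probs" replace
+// "logprob"/"top_logprobs" when post_sampling_probs is true):
+//   [{"id": 15043, "token": "Hello", "bytes": [72, 101, ...], "logprob": -0.12, "top_logprobs": [...]}, ...]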
+json completion_token_output::probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
+ json out = json::array();
+ for (const auto & p : probs) {
+ std::string txt(p.text_to_send);
+ txt.resize(validate_utf8(txt));
+ out.push_back(json {
+ {"id", p.tok},
+ {"token", txt},
+ {"bytes", str_to_bytes(p.text_to_send)},
+ {
+ post_sampling_probs ? "prob" : "logprob",
+ post_sampling_probs ? p.prob : logarithm(p.prob)
+ },
+ {
+ post_sampling_probs ? "top_probs" : "top_logprobs",
+ p.to_json(post_sampling_probs)
+ },
+ });
+ }
+ return out;
+}
+
+float completion_token_output::logarithm(float x) {
+ // nlohmann::json converts -inf to null, so we need to prevent that
+ return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
+}
+
+std::vector<unsigned char> completion_token_output::str_to_bytes(const std::string & str) {
+ std::vector<unsigned char> bytes;
+ for (unsigned char c : str) {
+ bytes.push_back(c);
+ }
+ return bytes;
+}
+
+//
+// server_task_result_cmpl_final
+//
+json server_task_result_cmpl_final::to_json() {
+ switch (oaicompat) {
+ case OAICOMPAT_TYPE_NONE:
+ return to_json_non_oaicompat();
+ case OAICOMPAT_TYPE_COMPLETION:
+ return to_json_oaicompat();
+ case OAICOMPAT_TYPE_CHAT:
+ return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
+ default:
+ GGML_ASSERT(false && "Invalid oaicompat_type");
+ }
+}
+
+json server_task_result_cmpl_final::to_json_non_oaicompat() {
+ json res = json {
+ {"index", index},
+ {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+ {"tokens", stream ? llama_tokens {} : tokens},
+ {"id_slot", id_slot},
+ {"stop", true},
+ {"model", oaicompat_model},
+ {"tokens_predicted", n_decoded},
+ {"tokens_evaluated", n_prompt_tokens},
+ {"generation_settings", generation_params.to_json()},
+ {"prompt", prompt},
+ {"has_new_line", has_new_line},
+ {"truncated", truncated},
+ {"stop_type", stop_type_to_str(stop)},
+ {"stopping_word", stopping_word},
+ {"tokens_cached", n_tokens_cached},
+ {"timings", timings.to_json()},
+ };
+ if (!stream && !probs_output.empty()) {
+ res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
+ }
+ return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
+}
+
+json server_task_result_cmpl_final::to_json_oaicompat() {
+ std::time_t t = std::time(0);
+ json logprobs = json(nullptr); // OAI default to null
+ if (!stream && probs_output.size() > 0) {
+ logprobs = json{
+ {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
+ };
+ }
+ json finish_reason = "length";
+ if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
+ finish_reason = "stop";
+ }
+ json res = json {
+ {"choices", json::array({
+ json{
+ {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+ {"index", index},
+ {"logprobs", logprobs},
+ {"finish_reason", finish_reason},
+ }
+ })},
+ {"created", t},
+ {"model", oaicompat_model},
+ {"system_fingerprint", build_info},
+ {"object", "text_completion"},
+ {"usage", json {
+ {"completion_tokens", n_decoded},
+ {"prompt_tokens", n_prompt_tokens},
+ {"total_tokens", n_decoded + n_prompt_tokens}
+ }},
+ {"id", oaicompat_cmpl_id}
+ };
+
+ // extra fields for debugging purposes
+ if (verbose) {
+ res["__verbose"] = to_json_non_oaicompat();
+ }
+ if (timings.prompt_n >= 0) {
+ res.push_back({"timings", timings.to_json()});
+ }
+
+ return res;
+}
+
+json server_task_result_cmpl_final::to_json_oaicompat_chat() {
+ std::string finish_reason = "length";
+ common_chat_msg msg;
+ if (!oaicompat_msg.empty()) {
+ msg = oaicompat_msg;
+ } else {
+ msg.role = "assistant";
+ msg.content = content;
+ }
+ if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
+ finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
+ }
+
+ json choice {
+ {"finish_reason", finish_reason},
+ {"index", 0},
+ {"message", msg.to_json_oaicompat<json>()},
+ };
+
+ if (!stream && probs_output.size() > 0) {
+ choice["logprobs"] = json{
+ {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
+ };
+ }
+
+ std::time_t t = std::time(0);
+
+ json res = json {
+ {"choices", json::array({choice})},
+ {"created", t},
+ {"model", oaicompat_model},
+ {"system_fingerprint", build_info},
+ {"object", "chat.completion"},
+ {"usage", json {
+ {"completion_tokens", n_decoded},
+ {"prompt_tokens", n_prompt_tokens},
+ {"total_tokens", n_decoded + n_prompt_tokens}
+ }},
+ {"id", oaicompat_cmpl_id}
+ };
+
+ // extra fields for debugging purposes
+ if (verbose) {
+ res["__verbose"] = to_json_non_oaicompat();
+ }
+ if (timings.prompt_n >= 0) {
+ res.push_back({"timings", timings.to_json()});
+ }
+
+ return res;
+}
+
+json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
+ std::time_t t = std::time(0);
+ std::string finish_reason = "length";
+ if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
+ finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls";
+ }
+
+ json deltas = json::array();
+ for (const auto & diff : oaicompat_msg_diffs) {
+ deltas.push_back({
+ {"choices", json::array({
+ json {
+ {"finish_reason", nullptr},
+ {"index", 0},
+ {"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
+ },
+ })},
+ {"created", t},
+ {"id", oaicompat_cmpl_id},
+ {"model", oaicompat_model},
+ {"system_fingerprint", build_info},
+ {"object", "chat.completion.chunk"},
+ });
+ }
+
+ deltas.push_back({
+ {"choices", json::array({
+ json {
+ {"finish_reason", finish_reason},
+ {"index", 0},
+ {"delta", json::object()},
+ },
+ })},
+ {"created", t},
+ {"id", oaicompat_cmpl_id},
+ {"model", oaicompat_model},
+ {"system_fingerprint", build_info},
+ {"object", "chat.completion.chunk"},
+ });
+
+ if (include_usage) {
+ // the OpenAI API spec for chat.completion.chunk objects specifies an empty `choices` array for the last chunk when including usage
+ // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
+ deltas.push_back({
+ {"choices", json::array()},
+ {"created", t},
+ {"id", oaicompat_cmpl_id},
+ {"model", oaicompat_model},
+ {"system_fingerprint", build_info},
+ {"object", "chat.completion.chunk"},
+ {"usage", json {
+ {"completion_tokens", n_decoded},
+ {"prompt_tokens", n_prompt_tokens},
+ {"total_tokens", n_decoded + n_prompt_tokens},
+ }},
+ });
+ }
+
+ if (timings.prompt_n >= 0) {
+ deltas.back().push_back({"timings", timings.to_json()});
+ }
+
+ // extra fields for debugging purposes
+ if (verbose && !deltas.empty()) {
+ deltas.front()["__verbose"] = to_json_non_oaicompat();
+ }
+
+ return deltas;
+}
+
+//
+// server_task_result_cmpl_partial
+//
+json server_task_result_cmpl_partial::to_json() {
+ switch (oaicompat) {
+ case OAICOMPAT_TYPE_NONE:
+ return to_json_non_oaicompat();
+ case OAICOMPAT_TYPE_COMPLETION:
+ return to_json_oaicompat();
+ case OAICOMPAT_TYPE_CHAT:
+ return to_json_oaicompat_chat();
+ default:
+ GGML_ASSERT(false && "Invalid oaicompat_type");
+ }
+}
+
+json server_task_result_cmpl_partial::to_json_non_oaicompat() {
+ // non-OAI-compat JSON
+ json res = json {
+ {"index", index},
+ {"content", content},
+ {"tokens", tokens},
+ {"stop", false},
+ {"id_slot", id_slot},
+ {"tokens_predicted", n_decoded},
+ {"tokens_evaluated", n_prompt_tokens},
+ };
+ // populate the timings object when needed (usually for the last response or with timings_per_token enabled)
+ if (timings.prompt_n > 0) {
+ res.push_back({"timings", timings.to_json()});
+ }
+ if (is_progress) {
+ res.push_back({"prompt_progress", progress.to_json()});
+ }
+ if (!prob_output.probs.empty()) {
+ res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
+ }
+ return res;
+}
+
+json server_task_result_cmpl_partial::to_json_oaicompat() {
+ std::time_t t = std::time(0);
+ json logprobs = json(nullptr); // OAI default to null
+ if (prob_output.probs.size() > 0) {
+ logprobs = json{
+ {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
+ };
+ }
+ json res = json {
+ {"choices", json::array({
+ json{
+ {"text", content},
+ {"index", index},
+ {"logprobs", logprobs},
+ {"finish_reason", nullptr},
+ }
+ })},
+ {"created", t},
+ {"model", oaicompat_model},
+ {"system_fingerprint", build_info},
+ {"object", "text_completion"},
+ {"id", oaicompat_cmpl_id}
+ };
+
+ // extra fields for debugging purposes
+ if (verbose) {
+ res["__verbose"] = to_json_non_oaicompat();
+ }
+ if (timings.prompt_n >= 0) {
+ res.push_back({"timings", timings.to_json()});
+ }
+ if (is_progress) {
+ res.push_back({"prompt_progress", progress.to_json()});
+ }
+
+ return res;
+}
+
+json server_task_result_cmpl_partial::to_json_oaicompat_chat() {
+ bool first = n_decoded == 1;
+ std::time_t t = std::time(0);
+
+ std::vector<json> deltas;
+ auto add_delta = [&](const json & delta) {
+ deltas.push_back({
+ {"choices", json::array({
+ json {
+ {"finish_reason", nullptr},
+ {"index", 0},
+ {"delta", delta},
+ },
+ })},
+ {"created", t},
+ {"id", oaicompat_cmpl_id},
+ {"model", oaicompat_model},
+ {"system_fingerprint", build_info},
+ {"object", "chat.completion.chunk"},
+ });
+ };
+ // We have to send an initial update to conform to OpenAI behavior
+ if (first || is_progress) {
+ add_delta({
+ {"role", "assistant"},
+ {"content", nullptr},
+ });
+ }
+
+ for (const auto & diff : oaicompat_msg_diffs) {
+ add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
+ }
+
+ if (!deltas.empty()) {
+ auto & last_json = deltas[deltas.size() - 1];
+ GGML_ASSERT(last_json.at("choices").size() >= 1);
+
+ if (prob_output.probs.size() > 0) {
+ last_json.at("choices").at(0)["logprobs"] = json {
+ {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
+ };
+ }
+
+ if (timings.prompt_n >= 0) {
+ last_json.push_back({"timings", timings.to_json()});
+ }
+ if (is_progress) {
+ last_json.push_back({"prompt_progress", progress.to_json()});
+ }
+ }
+
+ return deltas;
+}
+
+//
+// server_task_result_embd
+//
+json server_task_result_embd::to_json() {
+ return oaicompat == OAICOMPAT_TYPE_EMBEDDING
+ ? to_json_oaicompat()
+ : to_json_non_oaicompat();
+}
+
+json server_task_result_embd::to_json_non_oaicompat() {
+ return json {
+ {"index", index},
+ {"embedding", embedding},
+ };
+}
+
+json server_task_result_embd::to_json_oaicompat() {
+ return json {
+ {"index", index},
+ {"embedding", embedding[0]},
+ {"tokens_evaluated", n_tokens},
+ };
+}
+
+//
+// server_task_result_rerank
+//
+json server_task_result_rerank::to_json() {
+ return json {
+ {"index", index},
+ {"score", score},
+ {"tokens_evaluated", n_tokens},
+ };
+}
+
+//
+// server_task_result_error
+//
+json server_task_result_error::to_json() {
+ json res = format_error_response(err_msg, err_type);
+ if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
+ res["n_prompt_tokens"] = n_prompt_tokens;
+ res["n_ctx"] = n_ctx;
+ }
+ return res;
+}
+
+//
+// server_task_result_metrics
+//
+json server_task_result_metrics::to_json() {
+ return json {
+ { "idle", n_idle_slots },
+ { "processing", n_processing_slots },
+ { "deferred", n_tasks_deferred },
+ { "t_start", t_start },
+
+ { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total },
+ { "t_tokens_generation_total", t_tokens_generation_total },
+ { "n_tokens_predicted_total", n_tokens_predicted_total },
+ { "t_prompt_processing_total", t_prompt_processing_total },
+
+ { "n_tokens_max", n_tokens_max },
+
+ { "n_prompt_tokens_processed", n_prompt_tokens_processed },
+ { "t_prompt_processing", t_prompt_processing },
+ { "n_tokens_predicted", n_tokens_predicted },
+ { "t_tokens_generation", t_tokens_generation },
+
+ { "n_decode_total", n_decode_total },
+ { "n_busy_slots_total", n_busy_slots_total },
+
+ { "slots", slots_data },
+ };
+}
+
+//
+// server_task_result_slot_save_load
+//
+json server_task_result_slot_save_load::to_json() {
+ if (is_save) {
+ return json {
+ { "id_slot", id_slot },
+ { "filename", filename },
+ { "n_saved", n_tokens },
+ { "n_written", n_bytes },
+ { "timings", {
+ { "save_ms", t_ms }
+ }},
+ };
+ }
+
+ return json {
+ { "id_slot", id_slot },
+ { "filename", filename },
+ { "n_restored", n_tokens },
+ { "n_read", n_bytes },
+ { "timings", {
+ { "restore_ms", t_ms }
+ }},
+ };
+}
+
+//
+// server_task_result_slot_erase
+//
+json server_task_result_slot_erase::to_json() {
+ return json {
+ { "id_slot", id_slot },
+ { "n_erased", n_erased },
+ };
+}
+
+//
+// server_task_result_apply_lora
+//
+
+json server_task_result_apply_lora::to_json() {
+ return json {{ "success", true }};
+}
+
+//
+// server_prompt_cache
+//
+size_t server_prompt_cache::size() const {
+ size_t res = 0;
+
+ for (const auto & state : states) {
+ res += state.size();
+ }
+
+ return res;
+}
+
+size_t server_prompt_cache::n_tokens() const {
+ size_t res = 0;
+
+ for (const auto & state : states) {
+ res += state.n_tokens();
+ }
+
+ return res;
+}
+
+server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) {
+ // first check if the current state is contained fully in the cache
+ for (auto it = states.begin(); it != states.end(); ++it) {
+ const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
+
+ if (cur_lcp_len == (int) prompt.tokens.size()) {
+ SRV_WRN("%s", " - prompt is already in the cache, skipping\n");
+ return nullptr;
+ }
+ }
+
+ // next, remove any cached prompts that are fully contained in the current prompt
+ for (auto it = states.begin(); it != states.end();) {
+ const int len = it->tokens.get_common_prefix(prompt.tokens);
+
+ if (len == (int) it->tokens.size()) {
+ SRV_WRN(" - removing obsolete cached prompt with length %d\n", len);
+
+ it = states.erase(it);
+ } else {
+ ++it;
+ }
+ }
+
+ std::vector<uint8_t> state_data;
+
+ // check if we can allocate enough memory for the new state
+ try {
+ state_data.resize(state_size);
+ } catch (const std::bad_alloc & e) {
+ SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
+
+ limit_size = std::max<size_t>(1, 0.4*size());
+
+ SRV_WRN(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0));
+
+ update();
+
+ return nullptr;
+ }
+
+ // TODO: for some reason we can't copy server_tokens, so we have to do this workaround
+ auto & cur = states.emplace_back();
+ cur = {
+ /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
+ /*.data =*/ std::move(state_data),
+ /*.checkpoints =*/ prompt.checkpoints,
+ };
+
+ return &cur;
+}
+
+bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
+ const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
+
+ float f_keep_best = float(lcp_best) / prompt.tokens.size();
+ float sim_best = float(lcp_best) / tokens_new.size();
+
+ SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
+
+ auto it_best = states.end();
+
+ // find the most similar cached prompt that also preserves the most context
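+ // (for intuition, hypothetical numbers: a cached prompt of 100 tokens sharing an 80-token
+ //  prefix with a 200-token request gives f_keep = 80/100 = 0.80 and sim = 80/200 = 0.40)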
+ for (auto it = states.begin(); it != states.end(); ++it) {
+ const int lcp_cur = it->tokens.get_common_prefix(tokens_new);
+
+ const float f_keep_cur = float(lcp_cur) / it->tokens.size();
+ const float sim_cur = float(lcp_cur) / tokens_new.size();
+
+ // don't trash large prompts
+ if (f_keep_cur < 0.25f) {
+ continue;
+ }
+
+ if (f_keep_best < f_keep_cur && sim_best < sim_cur) {
+ f_keep_best = f_keep_cur;
+ sim_best = sim_cur;
+
+ it_best = it;
+ }
+ }
+
+ if (it_best != states.end()) {
+ SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
+
+ const size_t size = it_best->data.size();
+ const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
+ if (n != size) {
+ SRV_WRN("failed to restore state with size %zu\n", size);
+
+ return false;
+ }
+
+ it_best->data.clear();
+ it_best->data.shrink_to_fit();
+
+ prompt = std::move(*it_best);
+
+ states.erase(it_best);
+ }
+
+ return true;
+}
+
+void server_prompt_cache::update() {
+ if (limit_size > 0) {
+ // always keep at least one state, regardless of the limits
+ while (states.size() > 1 && size() > limit_size) {
+ if (states.empty()) {
+ break;
+ }
+
+ SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
+
+ states.pop_front();
+ }
+ }
+
+ // average size per token
+ const float size_per_token = std::max<float>(1.0f, float(size()) / (std::max<size_t>(1, n_tokens())));
+
+ // dynamically increase the token limit if it can fit in the memory limit
+ const size_t limit_tokens_cur = limit_size > 0 ? std::max<size_t>(limit_tokens, limit_size/size_per_token) : limit_tokens;
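+ // e.g. (hypothetical) 512 MiB of cached state covering 8192 tokens -> ~64 KiB per token;
+ // with a 1 GiB size limit the effective token limit becomes ~16384 even if limit_tokens is smaller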
+
+ if (limit_tokens > 0) {
+ while (states.size() > 1 && n_tokens() > limit_tokens_cur) {
+ if (states.empty()) {
+ break;
+ }
+
+ SRV_WRN(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n",
+ limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0));
+
+ states.pop_front();
+ }
+ }
+
+ SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n",
+ states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur);
+
+ for (const auto & state : states) {
+ SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n",
+ (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
+ }
+}
--- /dev/null
+#pragma once
+
+#include "common.h"
+#include "llama.h"
+
+#include <string>
+#include <unordered_set>
+#include <list>
+
+// TODO: prevent including the whole server-common.h as we only use server_tokens
+#include "server-common.h"
+
+using json = nlohmann::ordered_json;
+
+enum server_task_type {
+ SERVER_TASK_TYPE_COMPLETION,
+ SERVER_TASK_TYPE_EMBEDDING,
+ SERVER_TASK_TYPE_RERANK,
+ SERVER_TASK_TYPE_INFILL,
+ SERVER_TASK_TYPE_CANCEL,
+ SERVER_TASK_TYPE_NEXT_RESPONSE,
+ SERVER_TASK_TYPE_METRICS,
+ SERVER_TASK_TYPE_SLOT_SAVE,
+ SERVER_TASK_TYPE_SLOT_RESTORE,
+ SERVER_TASK_TYPE_SLOT_ERASE,
+ SERVER_TASK_TYPE_SET_LORA,
+};
+
+// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common
+enum oaicompat_type {
+ OAICOMPAT_TYPE_NONE,
+ OAICOMPAT_TYPE_CHAT,
+ OAICOMPAT_TYPE_COMPLETION,
+ OAICOMPAT_TYPE_EMBEDDING,
+};
+
+enum stop_type {
+ STOP_TYPE_NONE,
+ STOP_TYPE_EOS,
+ STOP_TYPE_WORD,
+ STOP_TYPE_LIMIT,
+};
+
+struct task_params {
+ bool stream = true;
+ bool include_usage = false;
+ bool cache_prompt = true; // remember the prompt to avoid reprocessing the whole prompt
+ bool return_tokens = false;
+ bool return_progress = false;
+
+ int32_t n_keep = 0; // number of tokens to keep from initial prompt
+ int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
+ int32_t n_predict = -1; // new tokens to predict
+ int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
+
+ int64_t t_max_prompt_ms = -1; // TODO: implement
+ int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
+
+ std::vector<common_adapter_lora_info> lora;
+
+ std::vector<std::string> antiprompt;
+ std::vector<std::string> response_fields;
+ bool timings_per_token = false;
+ bool post_sampling_probs = false;
+
+ struct common_params_sampling sampling;
+ struct common_params_speculative speculative;
+
+ // OAI-compat fields
+ bool verbose = false;
+ oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+ std::string oaicompat_model;
+ std::string oaicompat_cmpl_id;
+ common_chat_syntax oaicompat_chat_syntax;
+
+ // Embeddings
+ int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
+
+ json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const;
+ json to_json(bool only_metrics = false) const;
+};
+
+struct server_task {
+ int id = -1; // to be filled by server_queue
+ int index = -1; // used when there are multiple prompts (batch request)
+
+ // used by SERVER_TASK_TYPE_CANCEL
+ int id_target = -1;
+ int id_slot = -1;
+
+ // used by SERVER_TASK_TYPE_INFERENCE
+ task_params params;
+ server_tokens tokens;
+
+ server_task_type type;
+
+ // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
+ struct slot_action {
+ int slot_id;
+ std::string filename;
+ std::string filepath;
+ };
+ slot_action slot_action;
+
+ // used by SERVER_TASK_TYPE_METRICS
+ bool metrics_reset_bucket = false;
+
+ // used by SERVER_TASK_TYPE_SET_LORA
+ std::vector<common_adapter_lora_info> set_lora;
+
+ server_task() = default;
+
+ server_task(server_task_type type) : type(type) {}
+
+ int32_t n_tokens() const {
+ return tokens.size();
+ }
+
+ static task_params params_from_json_cmpl(
+ const llama_context * ctx,
+ const common_params & params_base,
+ const json & data);
+
+ // utility function
+ static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
+ std::unordered_set<int> ids(tasks.size());
+ for (size_t i = 0; i < tasks.size(); i++) {
+ ids.insert(tasks[i].id);
+ }
+ return ids;
+ }
+};
+
+struct result_timings {
+ int32_t cache_n = -1;
+
+ int32_t prompt_n = -1;
+ double prompt_ms;
+ double prompt_per_token_ms;
+ double prompt_per_second;
+
+ int32_t predicted_n = -1;
+ double predicted_ms;
+ double predicted_per_token_ms;
+ double predicted_per_second;
+
+ // Optional speculative metrics - only included when > 0
+ int32_t draft_n = 0;
+ int32_t draft_n_accepted = 0;
+
+ json to_json() const;
+};
+
+struct result_prompt_progress {
+ int32_t total = 0;
+ int32_t cache = 0;
+ int32_t processed = 0;
+ int64_t time_ms = 0;
+
+ json to_json() const;
+};
+
+struct server_task_result {
+ int id = -1;
+ int id_slot = -1;
+ virtual bool is_error() {
+ // only used by server_task_result_error
+ return false;
+ }
+ virtual bool is_stop() {
+ // only used by server_task_result_cmpl_*
+ return true;
+ }
+ virtual int get_index() {
+ return -1;
+ }
+ virtual json to_json() = 0;
+ virtual ~server_task_result() = default;
+};
+
+// using unique_ptr for polymorphism of server_task_result
+using server_task_result_ptr = std::unique_ptr<server_task_result>;
+
+struct completion_token_output {
+ llama_token tok;
+ float prob;
+ std::string text_to_send;
+ struct prob_info {
+ llama_token tok;
+ std::string txt;
+ float prob;
+ };
+ std::vector<prob_info> probs;
+
+ json to_json(bool post_sampling_probs) const;
+
+ static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs);
+
+ static float logarithm(float x);
+
+ static std::vector<unsigned char> str_to_bytes(const std::string & str);
+
+};
+
+struct server_task_result_cmpl_final : server_task_result {
+ int index = 0;
+
+ std::string content;
+ llama_tokens tokens;
+
+ bool stream;
+ bool include_usage;
+ result_timings timings;
+ std::string prompt;
+
+ bool truncated;
+ int32_t n_decoded;
+ int32_t n_prompt_tokens;
+ int32_t n_tokens_cached;
+ bool has_new_line;
+ std::string stopping_word;
+ stop_type stop = STOP_TYPE_NONE;
+
+ bool post_sampling_probs;
+ std::vector<completion_token_output> probs_output;
+ std::vector<std::string> response_fields;
+
+ task_params generation_params;
+
+ // OAI-compat fields
+ bool verbose = false;
+ oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+ std::string oaicompat_model;
+ std::string oaicompat_cmpl_id;
+ common_chat_msg oaicompat_msg;
+
+ std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
+
+ virtual int get_index() override {
+ return index;
+ }
+
+ virtual bool is_stop() override {
+ return true; // in stream mode, final responses are considered stop
+ }
+
+ virtual json to_json() override;
+
+ json to_json_non_oaicompat();
+
+ json to_json_oaicompat();
+
+ json to_json_oaicompat_chat();
+
+ json to_json_oaicompat_chat_stream();
+};
+
+struct server_task_result_cmpl_partial : server_task_result {
+ int index = 0;
+
+ std::string content;
+ llama_tokens tokens;
+
+ int32_t n_decoded;
+ int32_t n_prompt_tokens;
+
+ bool post_sampling_probs;
+ bool is_progress = false;
+ completion_token_output prob_output;
+ result_timings timings;
+ result_prompt_progress progress;
+
+ // OAI-compat fields
+ bool verbose = false;
+ oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+ std::string oaicompat_model;
+ std::string oaicompat_cmpl_id;
+ std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
+
+ virtual int get_index() override {
+ return index;
+ }
+
+ virtual bool is_stop() override {
+ return false; // in stream mode, partial responses are not considered stop
+ }
+
+ virtual json to_json() override;
+
+ json to_json_non_oaicompat();
+
+ json to_json_oaicompat();
+
+ json to_json_oaicompat_chat();
+};
+
+struct server_task_result_embd : server_task_result {
+ int index = 0;
+ std::vector<std::vector<float>> embedding;
+
+ int32_t n_tokens;
+
+ // OAI-compat fields
+ oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+
+ virtual int get_index() override {
+ return index;
+ }
+
+ virtual json to_json() override;
+
+ json to_json_non_oaicompat();
+
+ json to_json_oaicompat();
+};
+
+struct server_task_result_rerank : server_task_result {
+ int index = 0;
+ float score = -1e6;
+
+ int32_t n_tokens;
+
+ virtual int get_index() override {
+ return index;
+ }
+
+ virtual json to_json() override;
+};
+
+struct server_task_result_error : server_task_result {
+ int index = 0;
+ error_type err_type = ERROR_TYPE_SERVER;
+ std::string err_msg;
+
+ // for ERROR_TYPE_EXCEED_CONTEXT_SIZE
+ int32_t n_prompt_tokens = 0;
+ int32_t n_ctx = 0;
+
+ virtual bool is_error() override {
+ return true;
+ }
+
+ virtual json to_json() override;
+};
+
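+// snapshot of the server metrics, serialized by to_json() as a flat object
+// ("idle", "processing", "deferred", "t_start", the token/timing counters and their *_total variants, and "slots")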
+struct server_task_result_metrics : server_task_result {
+ int n_idle_slots;
+ int n_processing_slots;
+ int n_tasks_deferred;
+ int64_t t_start;
+
+ // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
+ uint64_t n_prompt_tokens_processed_total = 0;
+ uint64_t t_prompt_processing_total = 0;
+ uint64_t n_tokens_predicted_total = 0;
+ uint64_t t_tokens_generation_total = 0;
+
+ uint64_t n_tokens_max = 0;
+
+ uint64_t n_prompt_tokens_processed = 0;
+ uint64_t t_prompt_processing = 0;
+
+ uint64_t n_tokens_predicted = 0;
+ uint64_t t_tokens_generation = 0;
+
+ uint64_t n_decode_total = 0;
+ uint64_t n_busy_slots_total = 0;
+
+ // we could also use std::vector<server_slot>, but that would require copying the slot objects, which can get quite messy
+ // therefore, we use json to temporarily store the slot.to_json() result
+ json slots_data = json::array();
+
+ virtual json to_json() override;
+};
+
+struct server_task_result_slot_save_load : server_task_result {
+ std::string filename;
+ bool is_save; // true = save, false = load
+
+ size_t n_tokens;
+ size_t n_bytes;
+ double t_ms;
+
+ virtual json to_json() override;
+};
+
+struct server_task_result_slot_erase : server_task_result {
+ size_t n_erased;
+
+ virtual json to_json() override;
+};
+
+struct server_task_result_apply_lora : server_task_result {
+ virtual json to_json() override;
+};
+
+struct server_prompt_checkpoint {
+ llama_pos pos_min;
+ llama_pos pos_max;
+
+ std::vector<uint8_t> data;
+
+ size_t size() const {
+ return data.size();
+ }
+};
+
+struct server_prompt {
+ server_tokens tokens;
+
+ std::vector<uint8_t> data;
+
+ std::list<server_prompt_checkpoint> checkpoints;
+
+ size_t size() const {
+ size_t res = data.size();
+
+ for (const auto & checkpoint : checkpoints) {
+ res += checkpoint.size();
+ }
+
+ return res;
+ }
+
+ int n_tokens() const {
+ return tokens.size();
+ }
+};
+
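+// cache of processed prompts together with their llama_context state
+// (behavior of the implementation moved out of server.cpp):
+//   - alloc()  skips prompts already fully contained in the cache, removes cached prompts that
+//              are a prefix of the new one, and reserves state_size bytes for the new entry
+//   - load()   restores the most similar cached state (if a better match than the current prompt
+//              exists) via llama_state_seq_set_data_ext() and takes that entry out of the cache
+//   - update() evicts the oldest entries until the size/token limits are satisfied,
+//              always keeping at least one entry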
+struct server_prompt_cache {
+ server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) {
+ this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib);
+ this->limit_tokens = limit_tokens;
+ }
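+ // e.g. server_prompt_cache(256, 0): 256 MiB size limit (stored in bytes), no token limit;
+ // a negative limit_size_mib results in limit_size == 0, i.e. no size limit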
+
+ std::list<server_prompt> states;
+
+ // in bytes, 0 = no limit
+ size_t limit_size = 0;
+
+ // in tokens, 0 = no limit
+ size_t limit_tokens = 0;
+
+ size_t size() const;
+
+ size_t n_tokens() const;
+
+ server_prompt * alloc(const server_prompt & prompt, size_t state_size);
+
+ bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);
+
+ void update();
+};
-#include "chat.h"
-#include "utils.hpp"
+#include "server-common.h"
#include "server-http.h"
+#include "server-task.h"
+#include "server-queue.h"
#include "arg.h"
#include "common.h"
-#include "json-schema-to-grammar.h"
#include "llama.h"
#include "log.h"
#include "sampling.h"
#include "speculative.h"
#include "mtmd.h"
+#include "mtmd-helper.h"
#include <atomic>
-#include <chrono>
-#include <condition_variable>
#include <cstddef>
#include <cinttypes>
-#include <deque>
#include <memory>
-#include <mutex>
-#include <list>
#include <signal.h>
#include <thread>
#include <unordered_set>
constexpr int HTTP_POLLING_SECONDS = 1;
-enum stop_type {
- STOP_TYPE_NONE,
- STOP_TYPE_EOS,
- STOP_TYPE_WORD,
- STOP_TYPE_LIMIT,
-};
-
-// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
-enum slot_state {
- SLOT_STATE_IDLE,
- SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
- SLOT_STATE_PROCESSING_PROMPT,
- SLOT_STATE_DONE_PROMPT,
- SLOT_STATE_GENERATING,
-};
-
-enum server_state {
- SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
- SERVER_STATE_READY, // Server is ready and model is loaded
-};
-
-enum server_task_type {
- SERVER_TASK_TYPE_COMPLETION,
- SERVER_TASK_TYPE_EMBEDDING,
- SERVER_TASK_TYPE_RERANK,
- SERVER_TASK_TYPE_INFILL,
- SERVER_TASK_TYPE_CANCEL,
- SERVER_TASK_TYPE_NEXT_RESPONSE,
- SERVER_TASK_TYPE_METRICS,
- SERVER_TASK_TYPE_SLOT_SAVE,
- SERVER_TASK_TYPE_SLOT_RESTORE,
- SERVER_TASK_TYPE_SLOT_ERASE,
- SERVER_TASK_TYPE_SET_LORA,
-};
-
-enum oaicompat_type {
- OAICOMPAT_TYPE_NONE,
- OAICOMPAT_TYPE_CHAT,
- OAICOMPAT_TYPE_COMPLETION,
- OAICOMPAT_TYPE_EMBEDDING,
-};
-
-// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
-enum error_type {
- ERROR_TYPE_INVALID_REQUEST,
- ERROR_TYPE_AUTHENTICATION,
- ERROR_TYPE_SERVER,
- ERROR_TYPE_NOT_FOUND,
- ERROR_TYPE_PERMISSION,
- ERROR_TYPE_UNAVAILABLE, // custom error
- ERROR_TYPE_NOT_SUPPORTED, // custom error
- ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error
-};
-
-static bool server_task_type_need_embd(server_task_type task_type) {
- switch (task_type) {
- case SERVER_TASK_TYPE_EMBEDDING:
- case SERVER_TASK_TYPE_RERANK:
- return true;
- default:
- return false;
- }
-}
-
-static bool server_task_type_need_logits(server_task_type task_type) {
- switch (task_type) {
- case SERVER_TASK_TYPE_COMPLETION:
- case SERVER_TASK_TYPE_INFILL:
- return true;
- default:
- return false;
- }
-}
-
-struct slot_params {
- bool stream = true;
- bool include_usage = false;
- bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
- bool return_tokens = false;
- bool return_progress = false;
-
- int32_t n_keep = 0; // number of tokens to keep from initial prompt
- int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
- int32_t n_predict = -1; // new tokens to predict
- int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
-
- int64_t t_max_prompt_ms = -1; // TODO: implement
- int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
-
- std::vector<common_adapter_lora_info> lora;
-
- std::vector<std::string> antiprompt;
- std::vector<std::string> response_fields;
- bool timings_per_token = false;
- bool post_sampling_probs = false;
-
- struct common_params_sampling sampling;
- struct common_params_speculative speculative;
-
- // OAI-compat fields
- bool verbose = false;
- oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
- std::string oaicompat_model;
- std::string oaicompat_cmpl_id;
- common_chat_syntax oaicompat_chat_syntax;
-
- // Embeddings
- int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
-
- json to_json(bool only_metrics = false) const {
- std::vector<std::string> samplers;
- samplers.reserve(sampling.samplers.size());
- for (const auto & sampler : sampling.samplers) {
- samplers.emplace_back(common_sampler_type_to_str(sampler));
- }
-
- json lora = json::array();
- for (size_t i = 0; i < this->lora.size(); ++i) {
- lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
- }
-
- if (only_metrics) {
- return json {
- {"seed", sampling.seed},
- {"temperature", sampling.temp},
- {"dynatemp_range", sampling.dynatemp_range},
- {"dynatemp_exponent", sampling.dynatemp_exponent},
- {"top_k", sampling.top_k},
- {"top_p", sampling.top_p},
- {"min_p", sampling.min_p},
- {"top_n_sigma", sampling.top_n_sigma},
- {"xtc_probability", sampling.xtc_probability},
- {"xtc_threshold", sampling.xtc_threshold},
- {"typical_p", sampling.typ_p},
- {"repeat_last_n", sampling.penalty_last_n},
- {"repeat_penalty", sampling.penalty_repeat},
- {"presence_penalty", sampling.penalty_present},
- {"frequency_penalty", sampling.penalty_freq},
- {"dry_multiplier", sampling.dry_multiplier},
- {"dry_base", sampling.dry_base},
- {"dry_allowed_length", sampling.dry_allowed_length},
- {"dry_penalty_last_n", sampling.dry_penalty_last_n},
- {"mirostat", sampling.mirostat},
- {"mirostat_tau", sampling.mirostat_tau},
- {"mirostat_eta", sampling.mirostat_eta},
- {"max_tokens", n_predict},
- {"n_predict", n_predict}, // TODO: deduplicate?
- {"n_keep", n_keep},
- {"n_discard", n_discard},
- {"ignore_eos", sampling.ignore_eos},
- {"stream", stream},
- {"n_probs", sampling.n_probs},
- {"min_keep", sampling.min_keep},
- {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)},
- {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)},
- {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content},
- {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open},
- {"samplers", samplers},
- {"speculative.n_max", speculative.n_max},
- {"speculative.n_min", speculative.n_min},
- {"speculative.p_min", speculative.p_min},
- {"timings_per_token", timings_per_token},
- {"post_sampling_probs", post_sampling_probs},
- {"lora", lora},
- };
- }
-
- auto grammar_triggers = json::array();
- for (const auto & trigger : sampling.grammar_triggers) {
- server_grammar_trigger ct(trigger);
- grammar_triggers.push_back(ct.to_json());
- }
-
- return json {
- {"seed", sampling.seed},
- {"temperature", sampling.temp},
- {"dynatemp_range", sampling.dynatemp_range},
- {"dynatemp_exponent", sampling.dynatemp_exponent},
- {"top_k", sampling.top_k},
- {"top_p", sampling.top_p},
- {"min_p", sampling.min_p},
- {"top_n_sigma", sampling.top_n_sigma},
- {"xtc_probability", sampling.xtc_probability},
- {"xtc_threshold", sampling.xtc_threshold},
- {"typical_p", sampling.typ_p},
- {"repeat_last_n", sampling.penalty_last_n},
- {"repeat_penalty", sampling.penalty_repeat},
- {"presence_penalty", sampling.penalty_present},
- {"frequency_penalty", sampling.penalty_freq},
- {"dry_multiplier", sampling.dry_multiplier},
- {"dry_base", sampling.dry_base},
- {"dry_allowed_length", sampling.dry_allowed_length},
- {"dry_penalty_last_n", sampling.dry_penalty_last_n},
- {"dry_sequence_breakers", sampling.dry_sequence_breakers},
- {"mirostat", sampling.mirostat},
- {"mirostat_tau", sampling.mirostat_tau},
- {"mirostat_eta", sampling.mirostat_eta},
- {"stop", antiprompt},
- {"max_tokens", n_predict},
- {"n_predict", n_predict}, // TODO: deduplicate?
- {"n_keep", n_keep},
- {"n_discard", n_discard},
- {"ignore_eos", sampling.ignore_eos},
- {"stream", stream},
- {"logit_bias", format_logit_bias(sampling.logit_bias)},
- {"n_probs", sampling.n_probs},
- {"min_keep", sampling.min_keep},
- {"grammar", sampling.grammar},
- {"grammar_lazy", sampling.grammar_lazy},
- {"grammar_triggers", grammar_triggers},
- {"preserved_tokens", sampling.preserved_tokens},
- {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)},
- {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)},
- {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content},
- {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open},
- {"samplers", samplers},
- {"speculative.n_max", speculative.n_max},
- {"speculative.n_min", speculative.n_min},
- {"speculative.p_min", speculative.p_min},
- {"timings_per_token", timings_per_token},
- {"post_sampling_probs", post_sampling_probs},
- {"lora", lora},
- };
- }
-};
-
-struct server_task {
- int id = -1; // to be filled by server_queue
- int index = -1; // used when there are multiple prompts (batch request)
-
- // used by SERVER_TASK_TYPE_CANCEL
- int id_target = -1;
- int id_slot = -1;
-
- // used by SERVER_TASK_TYPE_INFERENCE
- slot_params params;
- server_tokens tokens;
-
- server_task_type type;
-
- // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
- struct slot_action {
- int slot_id;
- std::string filename;
- std::string filepath;
- };
- slot_action slot_action;
-
- // used by SERVER_TASK_TYPE_METRICS
- bool metrics_reset_bucket = false;
-
- // used by SERVER_TASK_TYPE_SET_LORA
- std::vector<common_adapter_lora_info> set_lora;
-
- server_task() = default;
-
- server_task(server_task_type type) : type(type) {}
-
- int32_t n_tokens() const {
- return tokens.size();
- }
-
- static slot_params params_from_json_cmpl(
- const llama_context * ctx,
- const common_params & params_base,
- const json & data) {
- const llama_model * model = llama_get_model(ctx);
- const llama_vocab * vocab = llama_model_get_vocab(model);
-
- slot_params params;
-
- // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
- slot_params defaults;
- defaults.sampling = params_base.sampling;
- defaults.speculative = params_base.speculative;
- defaults.n_keep = params_base.n_keep;
- defaults.n_predict = params_base.n_predict;
- defaults.antiprompt = params_base.antiprompt;
-
- // enabling this will output extra debug information in the HTTP responses from the server
- params.verbose = params_base.verbosity > 9;
- params.timings_per_token = json_value(data, "timings_per_token", false);
-
- params.stream = json_value(data, "stream", false);
- auto stream_opt = json_value(data, "stream_options", json::object());
- params.include_usage = json_value(stream_opt, "include_usage", false);
- params.cache_prompt = json_value(data, "cache_prompt", true);
- params.return_tokens = json_value(data, "return_tokens", false);
- params.return_progress = json_value(data, "return_progress", false);
- params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
- params.n_indent = json_value(data, "n_indent", defaults.n_indent);
- params.n_keep = json_value(data, "n_keep", defaults.n_keep);
- params.n_discard = json_value(data, "n_discard", defaults.n_discard);
- //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
- params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
- params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
-
- params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
- params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
- params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
- params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma);
- params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
- params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
- params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
- params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
- params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
- params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
- params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
- params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
- params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
- params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
- params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
- params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
- params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
- params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
- params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
- params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
- params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
- params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
- params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
- params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
- params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
-
- params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
- params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
- params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
-
- params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
- params.speculative.n_min = std::max(params.speculative.n_min, 0);
- params.speculative.n_max = std::max(params.speculative.n_max, 0);
-
- // Use OpenAI API logprobs only if n_probs wasn't provided
- if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
- params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
- }
-
- if (data.contains("lora")) {
- if (data.at("lora").is_array()) {
- params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
- } else {
- throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
- }
- } else {
- params.lora = params_base.lora_adapters;
- }
-
- // TODO: add more sanity checks for the input parameters
-
- if (params.sampling.penalty_last_n < -1) {
- throw std::runtime_error("Error: repeat_last_n must be >= -1");
- }
-
- if (params.sampling.dry_penalty_last_n < -1) {
- throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
- }
-
- if (params.sampling.penalty_last_n == -1) {
- // note: should be the slot's context and not the full context, but it's ok
- params.sampling.penalty_last_n = llama_n_ctx(ctx);
- }
-
- if (params.sampling.dry_penalty_last_n == -1) {
- params.sampling.dry_penalty_last_n = llama_n_ctx(ctx);
- }
-
- if (params.sampling.dry_base < 1.0f) {
- params.sampling.dry_base = defaults.sampling.dry_base;
- }
-
- // sequence breakers for DRY
- {
- // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
- // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
-
- if (data.contains("dry_sequence_breakers")) {
- params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
- if (params.sampling.dry_sequence_breakers.empty()) {
- throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings");
- }
- }
- }
-
- // process "json_schema" and "grammar"
- if (data.contains("json_schema") && !data.contains("grammar")) {
- try {
- auto schema = json_value(data, "json_schema", json::object());
- SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
- params.sampling.grammar = json_schema_to_grammar(schema);
- SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
- } catch (const std::exception & e) {
- throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
- }
- } else {
- params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
- SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
- params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
- SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
- }
-
- {
- auto it = data.find("chat_format");
- if (it != data.end()) {
- params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>());
- SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format));
- } else {
- params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
- }
- common_reasoning_format reasoning_format = params_base.reasoning_format;
- if (data.contains("reasoning_format")) {
- reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get<std::string>());
- }
- params.oaicompat_chat_syntax.reasoning_format = reasoning_format;
- params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
- params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
- params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
- }
-
- {
- const auto preserved_tokens = data.find("preserved_tokens");
- if (preserved_tokens != data.end()) {
- for (const auto & t : *preserved_tokens) {
- auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
- if (ids.size() == 1) {
- SRV_DBG("Preserved token: %d\n", ids[0]);
- params.sampling.preserved_tokens.insert(ids[0]);
- } else {
- // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
- SRV_DBG("Not preserved because more than 1 token: %s\n", t.get<std::string>().c_str());
- }
- }
- }
- const auto grammar_triggers = data.find("grammar_triggers");
- if (grammar_triggers != data.end()) {
- for (const auto & t : *grammar_triggers) {
- server_grammar_trigger ct(t);
- if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
- const auto & word = ct.value.value;
- auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
- if (ids.size() == 1) {
- auto token = ids[0];
- if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) {
- throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
- }
- SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
- common_grammar_trigger trigger;
- trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
- trigger.value = word;
- trigger.token = token;
- params.sampling.grammar_triggers.push_back(std::move(trigger));
- } else {
- SRV_DBG("Grammar trigger word: `%s`\n", word.c_str());
- params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
- }
- } else {
- if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
- SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
- } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
- SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
- } else {
- throw std::runtime_error("Unknown grammar trigger type");
- }
- params.sampling.grammar_triggers.emplace_back(std::move(ct.value));
- }
- }
- }
- if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) {
- throw std::runtime_error("Error: no triggers set for lazy grammar!");
- }
- }
-
- {
- params.sampling.logit_bias.clear();
-
- const auto & logit_bias = data.find("logit_bias");
- if (logit_bias != data.end() && logit_bias->is_array()) {
- const int n_vocab = llama_vocab_n_tokens(vocab);
- for (const auto & el : *logit_bias) {
- // TODO: we may want to throw errors here, in case "el" is incorrect
- if (el.is_array() && el.size() == 2) {
- float bias;
- if (el[1].is_number()) {
- bias = el[1].get<float>();
- } else if (el[1].is_boolean() && !el[1].get<bool>()) {
- bias = -INFINITY;
- } else {
- continue;
- }
-
- if (el[0].is_number_integer()) {
- llama_token tok = el[0].get<llama_token>();
- if (tok >= 0 && tok < n_vocab) {
- params.sampling.logit_bias.push_back({tok, bias});
- }
- } else if (el[0].is_string()) {
- auto toks = common_tokenize(vocab, el[0].get<std::string>(), false);
- for (auto tok : toks) {
- params.sampling.logit_bias.push_back({tok, bias});
- }
- }
- }
- }
- } else if (logit_bias != data.end() && logit_bias->is_object()) {
- const int n_vocab = llama_vocab_n_tokens(vocab);
- for (const auto & el : logit_bias->items()) {
- float bias;
- const auto & key = el.key();
- const auto & value = el.value();
- if (value.is_number()) {
- bias = value.get<float>();
- } else if (value.is_boolean() && !value.get<bool>()) {
- bias = -INFINITY;
- } else {
- continue;
- }
-
- char *end;
- llama_token tok = strtol(key.c_str(), &end, 10);
- if (*end == 0) {
- if (tok >= 0 && tok < n_vocab) {
- params.sampling.logit_bias.push_back({tok, bias});
- }
- } else {
- auto toks = common_tokenize(vocab, key, false);
- for (auto tok : toks) {
- params.sampling.logit_bias.push_back({tok, bias});
- }
- }
- }
- }
-
- params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
- if (params.sampling.ignore_eos) {
- params.sampling.logit_bias.insert(
- params.sampling.logit_bias.end(),
- defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end());
- }
- }
-
- {
- params.antiprompt.clear();
-
- const auto & stop = data.find("stop");
- if (stop != data.end() && stop->is_array()) {
- for (const auto & word : *stop) {
- if (!word.empty()) {
- params.antiprompt.push_back(word);
- }
- }
- }
- // set reverse prompt from cli args if not set in the request
- if (params.antiprompt.empty()) {
- params.antiprompt = defaults.antiprompt;
- }
- }
-
- {
- const auto samplers = data.find("samplers");
- if (samplers != data.end()) {
- if (samplers->is_array()) {
- params.sampling.samplers = common_sampler_types_from_names(*samplers, false);
- } else if (samplers->is_string()){
- params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
- }
- } else {
- params.sampling.samplers = defaults.sampling.samplers;
- }
- }
-
- std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
- params.oaicompat_model = json_value(data, "model", model_name);
-
- return params;
- }
-
- // utility function
- static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
- std::unordered_set<int> ids(tasks.size());
- for (size_t i = 0; i < tasks.size(); i++) {
- ids.insert(tasks[i].id);
- }
- return ids;
- }
-};
-
-struct result_timings {
- int32_t cache_n = -1;
-
- int32_t prompt_n = -1;
- double prompt_ms;
- double prompt_per_token_ms;
- double prompt_per_second;
-
- int32_t predicted_n = -1;
- double predicted_ms;
- double predicted_per_token_ms;
- double predicted_per_second;
-
- // Optional speculative metrics - only included when > 0
- int32_t draft_n = 0;
- int32_t draft_n_accepted = 0;
-
- json to_json() const {
- json base = {
- {"cache_n", cache_n},
-
- {"prompt_n", prompt_n},
- {"prompt_ms", prompt_ms},
- {"prompt_per_token_ms", prompt_per_token_ms},
- {"prompt_per_second", prompt_per_second},
-
- {"predicted_n", predicted_n},
- {"predicted_ms", predicted_ms},
- {"predicted_per_token_ms", predicted_per_token_ms},
- {"predicted_per_second", predicted_per_second},
- };
-
- if (draft_n > 0) {
- base["draft_n"] = draft_n;
- base["draft_n_accepted"] = draft_n_accepted;
- }
-
- return base;
- }
-};
-
-struct result_prompt_progress {
- int32_t total = 0;
- int32_t cache = 0;
- int32_t processed = 0;
- int64_t time_ms = 0;
-
- json to_json() const {
- return json {
- {"total", total},
- {"cache", cache},
- {"processed", processed},
- {"time_ms", time_ms},
- };
- }
-};
-
-struct server_task_result {
- int id = -1;
- int id_slot = -1;
- virtual bool is_error() {
- // only used by server_task_result_error
- return false;
- }
- virtual bool is_stop() {
- // only used by server_task_result_cmpl_*
- return true;
- }
- virtual int get_index() {
- return -1;
- }
- virtual json to_json() = 0;
- virtual ~server_task_result() = default;
-};
-
-// using shared_ptr for polymorphism of server_task_result
-using server_task_result_ptr = std::unique_ptr<server_task_result>;
-
-static inline std::string stop_type_to_str(stop_type type) {
- switch (type) {
- case STOP_TYPE_EOS: return "eos";
- case STOP_TYPE_WORD: return "word";
- case STOP_TYPE_LIMIT: return "limit";
- default: return "none";
- }
-}
-
-struct completion_token_output {
- llama_token tok;
- float prob;
- std::string text_to_send;
- struct prob_info {
- llama_token tok;
- std::string txt;
- float prob;
- };
- std::vector<prob_info> probs;
-
- json to_json(bool post_sampling_probs) const {
- json probs_for_token = json::array();
- for (const auto & p : probs) {
- std::string txt(p.txt);
- txt.resize(validate_utf8(txt));
- probs_for_token.push_back(json {
- {"id", p.tok},
- {"token", txt},
- {"bytes", str_to_bytes(p.txt)},
- {
- post_sampling_probs ? "prob" : "logprob",
- post_sampling_probs ? p.prob : logarithm(p.prob)
- },
- });
- }
- return probs_for_token;
- }
-
- static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
- json out = json::array();
- for (const auto & p : probs) {
- std::string txt(p.text_to_send);
- txt.resize(validate_utf8(txt));
- out.push_back(json {
- {"id", p.tok},
- {"token", txt},
- {"bytes", str_to_bytes(p.text_to_send)},
- {
- post_sampling_probs ? "prob" : "logprob",
- post_sampling_probs ? p.prob : logarithm(p.prob)
- },
- {
- post_sampling_probs ? "top_probs" : "top_logprobs",
- p.to_json(post_sampling_probs)
- },
- });
- }
- return out;
- }
-
- static float logarithm(float x) {
- // nlohmann::json converts -inf to null, so we need to prevent that
- return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
- }
-
- static std::vector<unsigned char> str_to_bytes(const std::string & str) {
- std::vector<unsigned char> bytes;
- for (unsigned char c : str) {
- bytes.push_back(c);
- }
- return bytes;
- }
-};
-
-struct server_task_result_cmpl_final : server_task_result {
- int index = 0;
-
- std::string content;
- llama_tokens tokens;
-
- bool stream;
- bool include_usage;
- result_timings timings;
- std::string prompt;
-
- bool truncated;
- int32_t n_decoded;
- int32_t n_prompt_tokens;
- int32_t n_tokens_cached;
- bool has_new_line;
- std::string stopping_word;
- stop_type stop = STOP_TYPE_NONE;
-
- bool post_sampling_probs;
- std::vector<completion_token_output> probs_output;
- std::vector<std::string> response_fields;
-
- slot_params generation_params;
-
- // OAI-compat fields
- bool verbose = false;
- oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
- std::string oaicompat_model;
- std::string oaicompat_cmpl_id;
- common_chat_msg oaicompat_msg;
-
- std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
-
- virtual int get_index() override {
- return index;
- }
-
- virtual bool is_stop() override {
- return true; // in stream mode, final responses are considered stop
- }
-
- virtual json to_json() override {
- switch (oaicompat) {
- case OAICOMPAT_TYPE_NONE:
- return to_json_non_oaicompat();
- case OAICOMPAT_TYPE_COMPLETION:
- return to_json_oaicompat();
- case OAICOMPAT_TYPE_CHAT:
- return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
- default:
- GGML_ASSERT(false && "Invalid oaicompat_type");
- }
- }
-
- json to_json_non_oaicompat() {
- json res = json {
- {"index", index},
- {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
- {"tokens", stream ? llama_tokens {} : tokens},
- {"id_slot", id_slot},
- {"stop", true},
- {"model", oaicompat_model},
- {"tokens_predicted", n_decoded},
- {"tokens_evaluated", n_prompt_tokens},
- {"generation_settings", generation_params.to_json()},
- {"prompt", prompt},
- {"has_new_line", has_new_line},
- {"truncated", truncated},
- {"stop_type", stop_type_to_str(stop)},
- {"stopping_word", stopping_word},
- {"tokens_cached", n_tokens_cached},
- {"timings", timings.to_json()},
- };
- if (!stream && !probs_output.empty()) {
- res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
- }
- return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
- }
-
- json to_json_oaicompat() {
- std::time_t t = std::time(0);
- json logprobs = json(nullptr); // OAI default to null
- if (!stream && probs_output.size() > 0) {
- logprobs = json{
- {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
- };
- }
- json finish_reason = "length";
- if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
- finish_reason = "stop";
- }
- json res = json {
- {"choices", json::array({
- json{
- {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
- {"index", index},
- {"logprobs", logprobs},
- {"finish_reason", finish_reason},
- }
- })},
- {"created", t},
- {"model", oaicompat_model},
- {"system_fingerprint", build_info},
- {"object", "text_completion"},
- {"usage", json {
- {"completion_tokens", n_decoded},
- {"prompt_tokens", n_prompt_tokens},
- {"total_tokens", n_decoded + n_prompt_tokens}
- }},
- {"id", oaicompat_cmpl_id}
- };
-
- // extra fields for debugging purposes
- if (verbose) {
- res["__verbose"] = to_json_non_oaicompat();
- }
- if (timings.prompt_n >= 0) {
- res.push_back({"timings", timings.to_json()});
- }
-
- return res;
- }
-
- json to_json_oaicompat_chat() {
- std::string finish_reason = "length";
- common_chat_msg msg;
- if (!oaicompat_msg.empty()) {
- msg = oaicompat_msg;
- } else {
- msg.role = "assistant";
- msg.content = content;
- }
- if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
- finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
- }
-
- json choice {
- {"finish_reason", finish_reason},
- {"index", 0},
- {"message", msg.to_json_oaicompat<json>()},
- };
-
- if (!stream && probs_output.size() > 0) {
- choice["logprobs"] = json{
- {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
- };
- }
-
- std::time_t t = std::time(0);
-
- json res = json {
- {"choices", json::array({choice})},
- {"created", t},
- {"model", oaicompat_model},
- {"system_fingerprint", build_info},
- {"object", "chat.completion"},
- {"usage", json {
- {"completion_tokens", n_decoded},
- {"prompt_tokens", n_prompt_tokens},
- {"total_tokens", n_decoded + n_prompt_tokens}
- }},
- {"id", oaicompat_cmpl_id}
- };
-
- // extra fields for debugging purposes
- if (verbose) {
- res["__verbose"] = to_json_non_oaicompat();
- }
- if (timings.prompt_n >= 0) {
- res.push_back({"timings", timings.to_json()});
- }
-
- return res;
- }
-
- json to_json_oaicompat_chat_stream() {
- std::time_t t = std::time(0);
- std::string finish_reason = "length";
- if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
- finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls";
- }
-
- json deltas = json::array();
- for (const auto & diff : oaicompat_msg_diffs) {
- deltas.push_back({
- {"choices", json::array({
- json {
- {"finish_reason", nullptr},
- {"index", 0},
- {"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
- },
- })},
- {"created", t},
- {"id", oaicompat_cmpl_id},
- {"model", oaicompat_model},
- {"system_fingerprint", build_info},
- {"object", "chat.completion.chunk"},
- });
- }
-
- deltas.push_back({
- {"choices", json::array({
- json {
- {"finish_reason", finish_reason},
- {"index", 0},
- {"delta", json::object()},
- },
- })},
- {"created", t},
- {"id", oaicompat_cmpl_id},
- {"model", oaicompat_model},
- {"system_fingerprint", build_info},
- {"object", "chat.completion.chunk"},
- });
-
- if (include_usage) {
- // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
- // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
- deltas.push_back({
- {"choices", json::array()},
- {"created", t},
- {"id", oaicompat_cmpl_id},
- {"model", oaicompat_model},
- {"system_fingerprint", build_info},
- {"object", "chat.completion.chunk"},
- {"usage", json {
- {"completion_tokens", n_decoded},
- {"prompt_tokens", n_prompt_tokens},
- {"total_tokens", n_decoded + n_prompt_tokens},
- }},
- });
- }
-
- if (timings.prompt_n >= 0) {
- deltas.back().push_back({"timings", timings.to_json()});
- }
-
- // extra fields for debugging purposes
- if (verbose && !deltas.empty()) {
- deltas.front()["__verbose"] = to_json_non_oaicompat();
- }
-
- return deltas;
- }
-};
-
-struct server_task_result_cmpl_partial : server_task_result {
- int index = 0;
-
- std::string content;
- llama_tokens tokens;
-
- int32_t n_decoded;
- int32_t n_prompt_tokens;
-
- bool post_sampling_probs;
- bool is_progress = false;
- completion_token_output prob_output;
- result_timings timings;
- result_prompt_progress progress;
-
- // OAI-compat fields
- bool verbose = false;
- oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
- std::string oaicompat_model;
- std::string oaicompat_cmpl_id;
- std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
-
- virtual int get_index() override {
- return index;
- }
-
- virtual bool is_stop() override {
- return false; // in stream mode, partial responses are not considered stop
- }
-
- virtual json to_json() override {
- switch (oaicompat) {
- case OAICOMPAT_TYPE_NONE:
- return to_json_non_oaicompat();
- case OAICOMPAT_TYPE_COMPLETION:
- return to_json_oaicompat();
- case OAICOMPAT_TYPE_CHAT:
- return to_json_oaicompat_chat();
- default:
- GGML_ASSERT(false && "Invalid oaicompat_type");
- }
- }
-
- json to_json_non_oaicompat() {
- // non-OAI-compat JSON
- json res = json {
- {"index", index},
- {"content", content},
- {"tokens", tokens},
- {"stop", false},
- {"id_slot", id_slot},
- {"tokens_predicted", n_decoded},
- {"tokens_evaluated", n_prompt_tokens},
- };
- // populate the timings object when needed (usually for the last response or with timings_per_token enabled)
- if (timings.prompt_n > 0) {
- res.push_back({"timings", timings.to_json()});
- }
- if (is_progress) {
- res.push_back({"prompt_progress", progress.to_json()});
- }
- if (!prob_output.probs.empty()) {
- res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
- }
- return res;
- }
-
- json to_json_oaicompat() {
- std::time_t t = std::time(0);
- json logprobs = json(nullptr); // OAI default to null
- if (prob_output.probs.size() > 0) {
- logprobs = json{
- {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
- };
- }
- json res = json {
- {"choices", json::array({
- json{
- {"text", content},
- {"index", index},
- {"logprobs", logprobs},
- {"finish_reason", nullptr},
- }
- })},
- {"created", t},
- {"model", oaicompat_model},
- {"system_fingerprint", build_info},
- {"object", "text_completion"},
- {"id", oaicompat_cmpl_id}
- };
-
- // extra fields for debugging purposes
- if (verbose) {
- res["__verbose"] = to_json_non_oaicompat();
- }
- if (timings.prompt_n >= 0) {
- res.push_back({"timings", timings.to_json()});
- }
- if (is_progress) {
- res.push_back({"prompt_progress", progress.to_json()});
- }
-
- return res;
- }
-
- json to_json_oaicompat_chat() {
- bool first = n_decoded == 1;
- std::time_t t = std::time(0);
- json choices;
-
- std::vector<json> deltas;
- auto add_delta = [&](const json & delta) {
- deltas.push_back({
- {"choices", json::array({
- json {
- {"finish_reason", nullptr},
- {"index", 0},
- {"delta", delta},
- },
- })},
- {"created", t},
- {"id", oaicompat_cmpl_id},
- {"model", oaicompat_model},
- {"system_fingerprint", build_info},
- {"object", "chat.completion.chunk"},
- });
- };
- // We have to send an initial update to conform to openai behavior
- if (first || is_progress) {
- add_delta({
- {"role", "assistant"},
- {"content", nullptr},
- });
- }
-
- for (const auto & diff : oaicompat_msg_diffs) {
- add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
- }
-
- if (!deltas.empty()) {
- auto & last_json = deltas[deltas.size() - 1];
- GGML_ASSERT(last_json.at("choices").size() >= 1);
-
- if (prob_output.probs.size() > 0) {
- last_json.at("choices").at(0)["logprobs"] = json {
- {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
- };
- }
-
- if (timings.prompt_n >= 0) {
- last_json.push_back({"timings", timings.to_json()});
- }
- if (is_progress) {
- last_json.push_back({"prompt_progress", progress.to_json()});
- }
- }
-
- return deltas;
- }
-};
-
-struct server_task_result_embd : server_task_result {
- int index = 0;
- std::vector<std::vector<float>> embedding;
-
- int32_t n_tokens;
-
- // OAI-compat fields
- oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
-
- virtual int get_index() override {
- return index;
- }
-
- virtual json to_json() override {
- return oaicompat == OAICOMPAT_TYPE_EMBEDDING
- ? to_json_oaicompat()
- : to_json_non_oaicompat();
- }
-
- json to_json_non_oaicompat() {
- return json {
- {"index", index},
- {"embedding", embedding},
- };
- }
-
- json to_json_oaicompat() {
- return json {
- {"index", index},
- {"embedding", embedding[0]},
- {"tokens_evaluated", n_tokens},
- };
- }
-};
-
-struct server_task_result_rerank : server_task_result {
- int index = 0;
- float score = -1e6;
-
- int32_t n_tokens;
-
- virtual int get_index() override {
- return index;
- }
-
- virtual json to_json() override {
- return json {
- {"index", index},
- {"score", score},
- {"tokens_evaluated", n_tokens},
- };
- }
-};
-
-// this function maybe used outside of server_task_result_error
-static json format_error_response(const std::string & message, const enum error_type type) {
- std::string type_str;
- int code = 500;
- switch (type) {
- case ERROR_TYPE_INVALID_REQUEST:
- type_str = "invalid_request_error";
- code = 400;
- break;
- case ERROR_TYPE_AUTHENTICATION:
- type_str = "authentication_error";
- code = 401;
- break;
- case ERROR_TYPE_NOT_FOUND:
- type_str = "not_found_error";
- code = 404;
- break;
- case ERROR_TYPE_SERVER:
- type_str = "server_error";
- code = 500;
- break;
- case ERROR_TYPE_PERMISSION:
- type_str = "permission_error";
- code = 403;
- break;
- case ERROR_TYPE_NOT_SUPPORTED:
- type_str = "not_supported_error";
- code = 501;
- break;
- case ERROR_TYPE_UNAVAILABLE:
- type_str = "unavailable_error";
- code = 503;
- break;
- case ERROR_TYPE_EXCEED_CONTEXT_SIZE:
- type_str = "exceed_context_size_error";
- code = 400;
- break;
- }
- return json {
- {"code", code},
- {"message", message},
- {"type", type_str},
- };
-}
-
-struct server_task_result_error : server_task_result {
- int index = 0;
- error_type err_type = ERROR_TYPE_SERVER;
- std::string err_msg;
-
- // for ERROR_TYPE_EXCEED_CONTEXT_SIZE
- int32_t n_prompt_tokens = 0;
- int32_t n_ctx = 0;
-
- virtual bool is_error() override {
- return true;
- }
-
- virtual json to_json() override {
- json res = format_error_response(err_msg, err_type);
- if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
- res["n_prompt_tokens"] = n_prompt_tokens;
- res["n_ctx"] = n_ctx;
- }
- return res;
- }
-};
-
-struct server_task_result_metrics : server_task_result {
- int n_idle_slots;
- int n_processing_slots;
- int n_tasks_deferred;
- int64_t t_start;
-
- // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
- uint64_t n_prompt_tokens_processed_total = 0;
- uint64_t t_prompt_processing_total = 0;
- uint64_t n_tokens_predicted_total = 0;
- uint64_t t_tokens_generation_total = 0;
-
- uint64_t n_tokens_max = 0;
-
- uint64_t n_prompt_tokens_processed = 0;
- uint64_t t_prompt_processing = 0;
-
- uint64_t n_tokens_predicted = 0;
- uint64_t t_tokens_generation = 0;
-
- uint64_t n_decode_total = 0;
- uint64_t n_busy_slots_total = 0;
-
- // while we can also use std::vector<server_slot> this requires copying the slot object which can be quite messy
- // therefore, we use json to temporarily store the slot.to_json() result
- json slots_data = json::array();
-
- virtual json to_json() override {
- return json {
- { "idle", n_idle_slots },
- { "processing", n_processing_slots },
- { "deferred", n_tasks_deferred },
- { "t_start", t_start },
-
- { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total },
- { "t_tokens_generation_total", t_tokens_generation_total },
- { "n_tokens_predicted_total", n_tokens_predicted_total },
- { "t_prompt_processing_total", t_prompt_processing_total },
-
- { "n_tokens_max", n_tokens_max },
-
- { "n_prompt_tokens_processed", n_prompt_tokens_processed },
- { "t_prompt_processing", t_prompt_processing },
- { "n_tokens_predicted", n_tokens_predicted },
- { "t_tokens_generation", t_tokens_generation },
-
- { "n_decode_total", n_decode_total },
- { "n_busy_slots_total", n_busy_slots_total },
-
- { "slots", slots_data },
- };
- }
-};
-
-struct server_task_result_slot_save_load : server_task_result {
- std::string filename;
- bool is_save; // true = save, false = load
-
- size_t n_tokens;
- size_t n_bytes;
- double t_ms;
-
- virtual json to_json() override {
- if (is_save) {
- return json {
- { "id_slot", id_slot },
- { "filename", filename },
- { "n_saved", n_tokens },
- { "n_written", n_bytes },
- { "timings", {
- { "save_ms", t_ms }
- }},
- };
- }
-
- return json {
- { "id_slot", id_slot },
- { "filename", filename },
- { "n_restored", n_tokens },
- { "n_read", n_bytes },
- { "timings", {
- { "restore_ms", t_ms }
- }},
- };
- }
-};
-
-struct server_task_result_slot_erase : server_task_result {
- size_t n_erased;
-
- virtual json to_json() override {
- return json {
- { "id_slot", id_slot },
- { "n_erased", n_erased },
- };
- }
-};
-
-struct server_task_result_apply_lora : server_task_result {
- virtual json to_json() override {
- return json {{ "success", true }};
- }
-};
-
-struct server_prompt_checkpoint {
- llama_pos pos_min;
- llama_pos pos_max;
-
- std::vector<uint8_t> data;
-
- size_t size() const {
- return data.size();
- }
+// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
+enum slot_state {
+ SLOT_STATE_IDLE,
+ SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
+ SLOT_STATE_PROCESSING_PROMPT,
+ SLOT_STATE_DONE_PROMPT,
+ SLOT_STATE_GENERATING,
};
-struct server_prompt {
- server_tokens tokens;
-
- std::vector<uint8_t> data;
-
- std::list<server_prompt_checkpoint> checkpoints;
-
- size_t size() const {
- size_t res = data.size();
-
- for (const auto & checkpoint : checkpoints) {
- res += checkpoint.size();
- }
-
- return res;
- }
-
- int n_tokens() const {
- return tokens.size();
- }
+enum server_state {
+ SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
+ SERVER_STATE_READY, // Server is ready and model is loaded
};
-struct server_prompt_cache {
- server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) {
- this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib);
- this->limit_tokens = limit_tokens;
- }
-
- std::list<server_prompt> states;
-
- // in bytes, 0 = no limit
- size_t limit_size = 0;
-
- // in tokens, 0 = no limit
- size_t limit_tokens = 0;
-
- size_t size() const {
- size_t res = 0;
-
- for (const auto & state : states) {
- res += state.size();
- }
-
- return res;
- }
-
- size_t n_tokens() const {
- size_t res = 0;
-
- for (const auto & state : states) {
- res += state.n_tokens();
- }
-
- return res;
- }
-
- server_prompt * alloc(const server_prompt & prompt, size_t state_size) {
- // first check if the current state is contained fully in the cache
- for (auto it = states.begin(); it != states.end(); ++it) {
- const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
-
- if (cur_lcp_len == (int) prompt.tokens.size()) {
- SRV_WRN("%s", " - prompt is already in the cache, skipping\n");
- return nullptr;
- }
- }
-
- // next, remove any cached prompts that are fully contained in the current prompt
- for (auto it = states.begin(); it != states.end();) {
- const int len = it->tokens.get_common_prefix(prompt.tokens);
-
- if (len == (int) it->tokens.size()) {
- SRV_WRN(" - removing obsolete cached prompt with length %d\n", len);
-
- it = states.erase(it);
- } else {
- ++it;
- }
- }
-
- std::vector<uint8_t> state_data;
-
- // check if we can allocate enough memory for the new state
- try {
- state_data.resize(state_size);
- } catch (const std::bad_alloc & e) {
- SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
-
- limit_size = std::max<size_t>(1, 0.4*size());
-
- SRV_WRN(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0));
-
- update();
-
- return nullptr;
- }
-
- // TODO: for some reason we can't copy server_tokens, so we have to do this workaround
- auto & cur = states.emplace_back();
- cur = {
- /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
- /*.data =*/ std::move(state_data),
- /*.checkpoints =*/ prompt.checkpoints,
- };
-
- return &cur;
- }
-
- bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
- const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
-
- float f_keep_best = float(lcp_best) / prompt.tokens.size();
- float sim_best = float(lcp_best) / tokens_new.size();
-
- SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
-
- auto it_best = states.end();
-
- // find the most similar cached prompt, that would also preserve the most context
- for (auto it = states.begin(); it != states.end(); ++it) {
- const int lcp_cur = it->tokens.get_common_prefix(tokens_new);
-
- const float f_keep_cur = float(lcp_cur) / it->tokens.size();
- const float sim_cur = float(lcp_cur) / tokens_new.size();
-
- // don't trash large prompts
- if (f_keep_cur < 0.25f) {
- continue;
- }
-
- if (f_keep_best < f_keep_cur && sim_best < sim_cur) {
- f_keep_best = f_keep_cur;
- sim_best = sim_cur;
-
- it_best = it;
- }
- }
-
- if (it_best != states.end()) {
- SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
-
- const size_t size = it_best->data.size();
- const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
- if (n != size) {
- SRV_WRN("failed to restore state with size %zu\n", size);
-
- return false;
- }
-
- it_best->data.clear();
- it_best->data.shrink_to_fit();
-
- prompt = std::move(*it_best);
-
- states.erase(it_best);
- }
-
- return true;
+static bool server_task_type_need_embd(server_task_type task_type) {
+ switch (task_type) {
+ case SERVER_TASK_TYPE_EMBEDDING:
+ case SERVER_TASK_TYPE_RERANK:
+ return true;
+ default:
+ return false;
}
+}
- void update() {
- if (limit_size > 0) {
- // always keep at least one state, regardless of the limits
- while (states.size() > 1 && size() > limit_size) {
- if (states.empty()) {
- break;
- }
-
- SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
-
- states.pop_front();
- }
- }
-
- // average size per token
- const float size_per_token = std::max<float>(1.0f, float(size()) / (std::max<size_t>(1, n_tokens())));
-
- // dynamically increase the token limit if it can fit in the memory limit
- const size_t limit_tokens_cur = limit_size > 0 ? std::max<size_t>(limit_tokens, limit_size/size_per_token) : limit_tokens;
-
- if (limit_tokens > 0) {
- while (states.size() > 1 && n_tokens() > limit_tokens_cur) {
- if (states.empty()) {
- break;
- }
-
- SRV_WRN(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n",
- limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0));
-
- states.pop_front();
- }
- }
-
- SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n",
- states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur);
-
- for (const auto & state : states) {
- SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n",
- (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
- }
+static bool server_task_type_need_logits(server_task_type task_type) {
+ switch (task_type) {
+ case SERVER_TASK_TYPE_COMPLETION:
+ case SERVER_TASK_TYPE_INFILL:
+ return true;
+ default:
+ return false;
}
-};
+}
struct server_slot {
int id;
}
};
-struct server_queue {
- int id = 0;
- bool running;
-
- // queues
- std::deque<server_task> queue_tasks;
- std::deque<server_task> queue_tasks_deferred;
-
- std::mutex mutex_tasks;
- std::condition_variable condition_tasks;
-
- // callback functions
- std::function<void(server_task &&)> callback_new_task;
- std::function<void(void)> callback_update_slots;
-
- // Add a new task to the end of the queue
- int post(server_task && task, bool front = false) {
- std::unique_lock<std::mutex> lock(mutex_tasks);
- GGML_ASSERT(task.id != -1);
- // if this is cancel task make sure to clean up pending tasks
- if (task.type == SERVER_TASK_TYPE_CANCEL) {
- cleanup_pending_task(task.id_target);
- }
- const int task_id = task.id;
- QUE_DBG("new task, id = %d, front = %d\n", task_id, front);
- if (front) {
- queue_tasks.push_front(std::move(task));
- } else {
- queue_tasks.push_back(std::move(task));
- }
- condition_tasks.notify_one();
- return task_id;
- }
-
- // multi-task version of post()
- int post(std::vector<server_task> && tasks, bool front = false) {
- std::unique_lock<std::mutex> lock(mutex_tasks);
- for (auto & task : tasks) {
- if (task.id == -1) {
- task.id = id++;
- }
- // if this is cancel task make sure to clean up pending tasks
- if (task.type == SERVER_TASK_TYPE_CANCEL) {
- cleanup_pending_task(task.id_target);
- }
- QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
- if (front) {
- queue_tasks.push_front(std::move(task));
- } else {
- queue_tasks.push_back(std::move(task));
- }
- }
- condition_tasks.notify_one();
- return 0;
- }
-
- // Add a new task, but defer until one slot is available
- void defer(server_task && task) {
- std::unique_lock<std::mutex> lock(mutex_tasks);
- QUE_DBG("defer task, id = %d\n", task.id);
- queue_tasks_deferred.push_back(std::move(task));
- condition_tasks.notify_one();
- }
-
- // Get the next id for creating a new task
- int get_new_id() {
- std::unique_lock<std::mutex> lock(mutex_tasks);
- int new_id = id++;
- return new_id;
- }
-
- // Register function to process a new task
- void on_new_task(std::function<void(server_task &&)> callback) {
- callback_new_task = std::move(callback);
- }
-
- // Register the function to be called when all slots data is ready to be processed
- void on_update_slots(std::function<void(void)> callback) {
- callback_update_slots = std::move(callback);
- }
-
- // Call when the state of one slot is changed, it will move one task from deferred to main queue
- void pop_deferred_task() {
- std::unique_lock<std::mutex> lock(mutex_tasks);
- if (!queue_tasks_deferred.empty()) {
- queue_tasks.emplace_front(std::move(queue_tasks_deferred.front()));
- queue_tasks_deferred.pop_front();
- }
- condition_tasks.notify_one();
- }
-
- // end the start_loop routine
- void terminate() {
- std::unique_lock<std::mutex> lock(mutex_tasks);
- running = false;
- condition_tasks.notify_all();
- }
-
- /**
- * Main loop consists of these steps:
- * - Wait until a new task arrives
- * - Process the task (i.e. maybe copy data into slot)
- * - Check if multitask is finished
- * - Update all slots
- */
- void start_loop() {
- running = true;
-
- while (true) {
- QUE_DBG("%s", "processing new tasks\n");
-
- while (true) {
- std::unique_lock<std::mutex> lock(mutex_tasks);
- if (!running) {
- QUE_DBG("%s", "terminate\n");
- return;
- }
- if (queue_tasks.empty()) {
- lock.unlock();
- break;
- }
- server_task task = std::move(queue_tasks.front());
- queue_tasks.pop_front();
- lock.unlock();
-
- QUE_DBG("processing task, id = %d\n", task.id);
- callback_new_task(std::move(task));
- }
-
- // all tasks in the current loop is processed, slots data is now ready
- QUE_DBG("%s", "update slots\n");
-
- callback_update_slots();
-
- QUE_DBG("%s", "waiting for new tasks\n");
- {
- std::unique_lock<std::mutex> lock(mutex_tasks);
- if (!running) {
- QUE_DBG("%s", "terminate\n");
- return;
- }
- if (queue_tasks.empty()) {
- condition_tasks.wait(lock, [&]{
- return (!queue_tasks.empty() || !running);
- });
- }
- }
- }
- }
-
-private:
- void cleanup_pending_task(int id_target) {
- // no need lock because this is called exclusively by post()
- auto rm_func = [id_target](const server_task & task) {
- return task.id == id_target;
- };
- queue_tasks.erase(
- std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
- queue_tasks.end());
- queue_tasks_deferred.erase(
- std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
- queue_tasks_deferred.end());
- }
-};
-
-struct server_response {
- bool running = true;
-
- // for keeping track of all tasks waiting for the result
- std::unordered_set<int> waiting_task_ids;
-
- // the main result queue (using ptr for polymorphism)
- std::vector<server_task_result_ptr> queue_results;
-
- std::mutex mutex_results;
- std::condition_variable condition_results;
-
- // add the id_task to the list of tasks waiting for response
- void add_waiting_task_id(int id_task) {
- SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
-
- std::unique_lock<std::mutex> lock(mutex_results);
- waiting_task_ids.insert(id_task);
- }
-
- void add_waiting_tasks(const std::vector<server_task> & tasks) {
- std::unique_lock<std::mutex> lock(mutex_results);
-
- for (const auto & task : tasks) {
- SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size());
- waiting_task_ids.insert(task.id);
- }
- }
-
- // when the request is finished, we can remove the task associated with it
- void remove_waiting_task_id(int id_task) {
- SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
-
- std::unique_lock<std::mutex> lock(mutex_results);
- waiting_task_ids.erase(id_task);
- // make sure to clean up all pending results
- queue_results.erase(
- std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) {
- return res->id == id_task;
- }),
- queue_results.end());
- }
-
- void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
- std::unique_lock<std::mutex> lock(mutex_results);
-
- for (const auto & id_task : id_tasks) {
- SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
- waiting_task_ids.erase(id_task);
- }
- }
-
- // This function blocks the thread until there is a response for one of the id_tasks
- server_task_result_ptr recv(const std::unordered_set<int> & id_tasks) {
- while (true) {
- std::unique_lock<std::mutex> lock(mutex_results);
- condition_results.wait(lock, [&]{
- if (!running) {
- SRV_DBG("%s : queue result stop\n", __func__);
- std::terminate(); // we cannot return here since the caller is HTTP code
- }
- return !queue_results.empty();
- });
-
- for (size_t i = 0; i < queue_results.size(); i++) {
- if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
- server_task_result_ptr res = std::move(queue_results[i]);
- queue_results.erase(queue_results.begin() + i);
- return res;
- }
- }
- }
-
- // should never reach here
- }
-
- // same as recv(), but with a timeout in seconds
- // if the timeout is reached, nullptr is returned
- server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
- while (true) {
- std::unique_lock<std::mutex> lock(mutex_results);
-
- for (int i = 0; i < (int) queue_results.size(); i++) {
- if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
- server_task_result_ptr res = std::move(queue_results[i]);
- queue_results.erase(queue_results.begin() + i);
- return res;
- }
- }
-
- std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
- if (!running) {
- SRV_DBG("%s : queue result stop\n", __func__);
- std::terminate(); // we cannot return here since the caller is HTTP code
- }
- if (cr_res == std::cv_status::timeout) {
- return nullptr;
- }
- }
-
- // should never reach here
- }
-
- // single-task version of recv()
- server_task_result_ptr recv(int id_task) {
- std::unordered_set<int> id_tasks = {id_task};
- return recv(id_tasks);
- }
-
- // Send a new result to a waiting id_task
- void send(server_task_result_ptr && result) {
- SRV_DBG("sending result for task id = %d\n", result->id);
-
- std::unique_lock<std::mutex> lock(mutex_results);
- for (const auto & id_task : waiting_task_ids) {
- if (result->id == id_task) {
- SRV_DBG("task id = %d pushed to result queue\n", result->id);
-
- queue_results.emplace_back(std::move(result));
- condition_results.notify_all();
- return;
- }
- }
- }
-
- // terminate the waiting loop
- void terminate() {
- running = false;
- condition_results.notify_all();
- }
-};
-
struct server_context {
common_params params_base;
res->slots_data = std::move(slots_data);
res->n_idle_slots = n_idle_slots;
res->n_processing_slots = n_processing_slots;
- res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
+ res->n_tasks_deferred = queue_tasks.queue_tasks_deferred_size();
res->t_start = metrics.t_start;
res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
json default_generation_settings_for_props;
{
- slot_params params;
+ task_params params;
params.sampling = ctx_server.params_base.sampling;
std::string prompt = json_value(data, "prompt", std::string());
std::vector<server_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true);
SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
- data["prompt"] = format_infill(
+ data["prompt"] = format_prompt_infill(
ctx_server.vocab,
data.at("input_prefix"),
data.at("input_suffix"),
}
}
- const json data = format_tokenizer_response(tokens_response);
- res->ok(data);
+ res->ok(json{{"tokens", std::move(tokens_response)}});
return res;
};
std::string content;
if (body.count("tokens") != 0) {
const llama_tokens tokens = body.at("tokens");
- content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
+ content = tokens_to_str(ctx_server.ctx, tokens);
}
- const json data = format_detokenized_response(content);
- res->ok(data);
+ res->ok(json{{"content", std::move(content)}});
return res;
};
std::vector<server_task> tasks;
tasks.reserve(documents.size());
for (size_t i = 0; i < documents.size(); i++) {
- auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
+ auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
}
};
-std::function<void(int)> shutdown_handler;
-std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
+static std::function<void(int)> shutdown_handler;
+static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
-inline void signal_handler(int signal) {
+static inline void signal_handler(int signal) {
if (is_terminating.test_and_set()) {
// in case it hangs, we can force terminate the server by hitting Ctrl+C twice
// this is for better developer experience, we can remove when the server is stable enough
ctx_server.queue_tasks.terminate();
};
+ // TODO: refactor in common/console
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = signal_handler;
+++ /dev/null
-#pragma once
-
-#include "common.h"
-#include "log.h"
-#include "llama.h"
-#include "arg.h" // common_remote_get_content
-#include "base64.hpp"
-#include "mtmd.h"
-#include "mtmd-helper.h"
-#include "chat.h"
-
-#define JSON_ASSERT GGML_ASSERT
-#include <nlohmann/json.hpp>
-
-#include <random>
-#include <sstream>
-#include <string>
-#include <vector>
-#include <memory>
-#include <cinttypes>
-
-#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo"
-
-using json = nlohmann::ordered_json;
-
-#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
-#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
-#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
-#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
-
-#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-
-#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-
-using raw_buffer = std::vector<uint8_t>;
-
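-// e.g. json_value(body, "n_predict", 128) returns body["n_predict"] when it is present and
-// non-null, and 128 otherwise (a type mismatch also falls back to 128, with a warning)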
-template <typename T>
-static T json_value(const json & body, const std::string & key, const T & default_value) {
- // fall back to the default value if the key is missing or null
- if (body.contains(key) && !body.at(key).is_null()) {
- try {
- return body.at(key);
- } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const & err) {
- LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value: %s\n", key.c_str(), json(default_value).type_name(), err.what());
- return default_value;
- }
- } else {
- return default_value;
- }
-}
-
-const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
-
-// thin wrapper around common_grammar_trigger with (de)serialization functions
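-// JSON form: {"type": <int>, "value": <string>}, plus a "token" field for token-type triggers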
-struct server_grammar_trigger {
- common_grammar_trigger value;
-
- server_grammar_trigger() = default;
- server_grammar_trigger(const common_grammar_trigger & value) : value(value) {}
- server_grammar_trigger(const json & in) {
- value.type = (common_grammar_trigger_type) in.at("type").get<int>();
- value.value = in.at("value").get<std::string>();
- if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
- value.token = (llama_token) in.at("token").get<int>();
- }
- }
-
- json to_json() const {
- json out {
- {"type", (int) value.type},
- {"value", value.value},
- };
- if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
- out["token"] = (int) value.token;
- }
- return out;
- }
-};
-
-//
-// tokenizer and input processing utils
-//
-
-static bool json_is_array_of_numbers(const json & data) {
- if (data.is_array()) {
- for (const auto & e : data) {
- if (!e.is_number_integer()) {
- return false;
- }
- }
- return true;
- }
- return false;
-}
-
-// does the array contain BOTH numbers & strings?
-static bool json_is_array_of_mixed_numbers_strings(const json & data) {
- bool seen_string = false;
- bool seen_number = false;
- if (data.is_array()) {
- for (const auto & e : data) {
- seen_string |= e.is_string();
- seen_number |= e.is_number_integer();
- if (seen_number && seen_string) {
- return true;
- }
- }
- }
- return false;
-}
-
-// does the array contain any individual integers/tokens?
-static bool json_is_array_and_contains_numbers(const json & data) {
- if (data.is_array()) {
- for (const auto & e : data) {
- if (e.is_number_integer()) {
- return true;
- }
- }
- return false;
- }
- return false;
-}
-
-// get values by path (e.g. "key1/key2")
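-// e.g. paths = {"a/b"} with js = {"a": {"b": 1}} -> {"a/b": 1}; paths that do not resolve are omitted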
-static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
- json result = json::object();
-
- for (const std::string & path : paths) {
- json current = js;
- const auto keys = string_split<std::string>(path, /*separator*/ '/');
- bool valid_path = true;
- for (const std::string & k : keys) {
- if (valid_path && current.is_object() && current.contains(k)) {
- current = current[k];
- } else {
- valid_path = false;
- }
- }
- if (valid_path) {
- result[path] = current;
- }
- }
- return result;
-}
-
-/**
- * this handles 2 cases:
- * - only string, example: "string"
- * - mixed string and tokens, example: [12, 34, "string", 56, 78]
- */
-static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
- // If `add_special` is true, we only add BOS when json_prompt is a string,
- // or when the first element of the json_prompt array is a string.
- llama_tokens prompt_tokens;
-
- if (json_prompt.is_array()) {
- bool first = true;
- for (const auto & p : json_prompt) {
- if (p.is_string()) {
- auto s = p.template get<std::string>();
-
- llama_tokens p;
- if (first) {
- p = common_tokenize(vocab, s, add_special, parse_special);
- first = false;
- } else {
- p = common_tokenize(vocab, s, false, parse_special);
- }
-
- prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
- } else {
- if (first) {
- first = false;
- }
-
- prompt_tokens.push_back(p.template get<llama_token>());
- }
- }
- } else {
- auto s = json_prompt.template get<std::string>();
- prompt_tokens = common_tokenize(vocab, s, add_special, parse_special);
- }
-
- return prompt_tokens;
-}
-
-// return the last index of character that can form a valid string
-// if the last character is potentially cut in half, return the index before the cut
-// if validate_utf8(text) == text.size(), then the whole text is valid utf8
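-// e.g. for "hi" it returns 2; for "hi" followed by the lone lead byte 0xE2 of a 3-byte
-// character it also returns 2, excluding the partial character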
-static size_t validate_utf8(const std::string& text) {
- size_t len = text.size();
- if (len == 0) return 0;
-
- // Check the last few bytes to see if a multi-byte character is cut off
- for (size_t i = 1; i <= 4 && i <= len; ++i) {
- unsigned char c = text[len - i];
- // Check for start of a multi-byte sequence from the end
- if ((c & 0xE0) == 0xC0) {
- // 2-byte character start: 110xxxxx
- // Needs at least 2 bytes
- if (i < 2) return len - i;
- } else if ((c & 0xF0) == 0xE0) {
- // 3-byte character start: 1110xxxx
- // Needs at least 3 bytes
- if (i < 3) return len - i;
- } else if ((c & 0xF8) == 0xF0) {
- // 4-byte character start: 11110xxx
- // Needs at least 4 bytes
- if (i < 4) return len - i;
- }
- }
-
- // If no cut-off multi-byte character is found, return full length
- return len;
-}
-
-//
-// template utils
-//
-
-// format infill task
-static llama_tokens format_infill(
- const llama_vocab * vocab,
- const json & input_prefix,
- const json & input_suffix,
- const json & input_extra,
- const int n_batch,
- const int n_predict,
- const int n_ctx,
- const bool spm_infill,
- const llama_tokens & tokens_prompt
- ) {
- // TODO: optimize this block by reducing memory allocations and movement
-
- // use FIM repo-level pattern:
- // ref: https://arxiv.org/pdf/2409.12186
- //
- // [FIM_REP]myproject
- // [FIM_SEP]filename0
- // extra chunk 0
- // [FIM_SEP]filename1
- // extra chunk 1
- // ...
- // [FIM_SEP]filename
- // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
- //
- llama_tokens extra_tokens;
- extra_tokens.reserve(n_ctx);
-
- auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false);
- auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false);
-
- if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) {
- // TODO: make project name an input
- static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false);
-
- extra_tokens.push_back(llama_vocab_fim_rep(vocab));
- extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
- }
- for (const auto & chunk : input_extra) {
- // { "text": string, "filename": string }
- const std::string text = json_value(chunk, "text", std::string());
- const std::string filename = json_value(chunk, "filename", std::string("tmp"));
-
- if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
- const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false);
-
- extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab));
- extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
- } else {
- // chunk separator in binary form to avoid confusing the AI
- static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
- static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false);
-
- extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
- }
-
- const auto chunk_tokens = common_tokenize(vocab, text, false, false);
- extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
- }
-
- if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
- // TODO: current filename
- static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false);
-
- extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab));
- extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
- }
-
- // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
- const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4));
- const int n_suffix_take = std::min<int>(tokens_suffix.size(), std::max<int>(0, (n_batch/4) - (2 + tokens_prompt.size())));
-
- SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take));
-
- // fill the rest of the context with extra chunks
- const int n_extra_take = std::min<int>(std::max<int>(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size());
-
- tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
- tokens_suffix.resize(n_suffix_take);
-
- tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab));
- tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
- tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab));
-
- auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
- auto embd_end = spm_infill ? tokens_prefix : tokens_suffix;
-
- if (llama_vocab_get_add_bos(vocab)) {
- embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab));
- }
-
- SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
-
- // put the extra context before the FIM prefix
- embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());
-
- embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
- embd_inp.push_back(llama_vocab_fim_mid(vocab));
-
- return embd_inp;
-}
-
-//
-// base64 utils (TODO: move to common in the future)
-//
-
-static const std::string base64_chars =
- "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
- "abcdefghijklmnopqrstuvwxyz"
- "0123456789+/";
-
-static inline bool is_base64(uint8_t c) {
- return (isalnum(c) || (c == '+') || (c == '/'));
-}
-
-static inline raw_buffer base64_decode(const std::string & encoded_string) {
- int i = 0;
- int j = 0;
- int in_ = 0;
-
- int in_len = encoded_string.size();
-
- uint8_t char_array_4[4];
- uint8_t char_array_3[3];
-
- raw_buffer ret;
-
- while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
- char_array_4[i++] = encoded_string[in_]; in_++;
- if (i == 4) {
- for (i = 0; i < 4; i++) {
- char_array_4[i] = base64_chars.find(char_array_4[i]);
- }
-
- char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
- char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
- char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
-
- for (i = 0; (i < 3); i++) {
- ret.push_back(char_array_3[i]);
- }
-
- i = 0;
- }
- }
-
- if (i) {
- for (j = i; j < 4; j++) {
- char_array_4[j] = 0;
- }
-
- for (j = 0; j < 4; j++) {
- char_array_4[j] = base64_chars.find(char_array_4[j]);
- }
-
- char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
- char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
- char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
-
- for (j = 0; j < i - 1; j++) {
- ret.push_back(char_array_3[j]);
- }
- }
-
- return ret;
-}
-
-//
-// random string / id
-//
-
-static std::string random_string() {
- static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
-
- std::random_device rd;
- std::mt19937 generator(rd());
-
- std::string result(32, ' ');
-
- for (int i = 0; i < 32; ++i) {
- result[i] = str[generator() % str.size()];
- }
-
- return result;
-}
-
-static std::string gen_chatcmplid() {
- return "chatcmpl-" + random_string();
-}
-
-static std::string gen_tool_call_id() {
- return random_string();
-}
-
-//
-// other common utils
-//
-
-static std::string safe_json_to_str(const json & data) {
- return data.dump(-1, ' ', false, json::error_handler_t::replace);
-}
-
-// TODO: reuse llama_detokenize
-template <class Iter>
-static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
- std::string ret;
- for (; begin != end; ++begin) {
- ret += common_token_to_piece(ctx, *begin);
- }
-
- return ret;
-}
-
-// format incomplete utf-8 multibyte character for output
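-// e.g. a piece consisting of the single byte 0xE6 is rendered as "byte: \xe6"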
-static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) {
- std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token);
-
- // if the piece is a single byte with the high bit set, it is a partial utf-8 character
- // (a piece longer than 1 byte is already a complete, known token string)
- if (out.size() == 1 && (out[0] & 0x80) == 0x80) {
- std::stringstream ss;
- ss << std::hex << (out[0] & 0xff);
- std::string res(ss.str());
- out = "byte: \\x" + res;
- }
-
- return out;
-}
-
-// format server-sent event (SSE), return the formatted string to send
-// note: if data is a json array, it will be sent as multiple events, one per item
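-// e.g. the payload {"a":1} becomes the string "data: {"a":1}" followed by a blank line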
-static std::string format_sse(const json & data) {
- std::ostringstream ss;
- auto send_single = [&ss](const json & data) {
- ss << "data: " <<
- safe_json_to_str(data) <<
- "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
- };
-
- if (data.is_array()) {
- for (const auto & item : data) {
- send_single(item);
- }
- } else {
- send_single(data);
- }
-
- return ss.str();
-}
-
-//
-// OAI utils
-//
-
-// used by /completions endpoint
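-// e.g. {"prompt": "Hello", "stop": "###", "temperature": 0.8}
-//      -> {"prompt": "Hello", "stop": ["###"], "temperature": 0.8} (a string "stop" is wrapped in an array)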
-static json oaicompat_completion_params_parse(const json & body) {
- json llama_params;
-
- if (!body.contains("prompt")) {
- throw std::runtime_error("\"prompt\" is required");
- }
-
- // Handle "stop" field
- if (body.contains("stop") && body.at("stop").is_string()) {
- llama_params["stop"] = json::array({body.at("stop").get<std::string>()});
- } else {
- llama_params["stop"] = json_value(body, "stop", json::array());
- }
-
- // Handle "n" field
- int n_choices = json_value(body, "n", 1);
- if (n_choices != 1) {
- throw std::runtime_error("Only one completion choice is allowed");
- }
-
- // Handle "echo" field
- if (json_value(body, "echo", false)) {
- throw std::runtime_error("Only no echo is supported");
- }
-
- // Params supported by OAI but unsupported by llama.cpp
- static const std::vector<std::string> unsupported_params { "best_of", "suffix" };
- for (const auto & param : unsupported_params) {
- if (body.contains(param)) {
- throw std::runtime_error("Unsupported param: " + param);
- }
- }
-
- // Copy remaining properties to llama_params
- for (const auto & item : body.items()) {
- // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
- if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
- llama_params[item.key()] = item.value();
- }
- }
-
- return llama_params;
-}
-
-struct oaicompat_parser_options {
- bool use_jinja;
- bool prefill_assistant;
- common_reasoning_format reasoning_format;
- std::map<std::string,std::string> chat_template_kwargs;
- common_chat_templates * tmpls;
- bool allow_image;
- bool allow_audio;
- bool enable_thinking = true;
-};
-
-// used by /chat/completions endpoint
-static json oaicompat_chat_params_parse(
- json & body, /* openai api json semantics */
- const oaicompat_parser_options & opt,
- std::vector<raw_buffer> & out_files)
-{
- json llama_params;
-
- auto tools = json_value(body, "tools", json());
- auto has_tools = tools.is_array() && !tools.empty();
- auto stream = json_value(body, "stream", false);
- auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
-
- if (!opt.use_jinja) {
- if (has_tools) {
- throw std::runtime_error("tools param requires --jinja flag");
- }
- if (tool_choice != "auto") {
- throw std::runtime_error("tool_choice param requires --jinja flag");
- }
- }
-
- // Handle "stop" field
- if (body.contains("stop") && body.at("stop").is_string()) {
- llama_params["stop"] = json::array({body.at("stop").get<std::string>()});
- } else {
- llama_params["stop"] = json_value(body, "stop", json::array());
- }
-
- auto json_schema = json_value(body, "json_schema", json());
- auto grammar = json_value(body, "grammar", std::string());
- if (!json_schema.is_null() && !grammar.empty()) {
- throw std::runtime_error("Cannot use both json_schema and grammar");
- }
-
- // Handle "response_format" field
- if (body.contains("response_format")) {
- json response_format = json_value(body, "response_format", json::object());
- std::string response_type = json_value(response_format, "type", std::string());
- if (response_type == "json_object") {
- json_schema = json_value(response_format, "schema", json::object());
- } else if (response_type == "json_schema") {
- auto schema_wrapper = json_value(response_format, "json_schema", json::object());
- json_schema = json_value(schema_wrapper, "schema", json::object());
- } else if (!response_type.empty() && response_type != "text") {
- throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
- }
- }
-
- // get input files
- if (!body.contains("messages")) {
- throw std::runtime_error("'messages' is required");
- }
- json & messages = body.at("messages");
- if (!messages.is_array()) {
- throw std::runtime_error("Expected 'messages' to be an array");
- }
- for (auto & msg : messages) {
- std::string role = json_value(msg, "role", std::string());
- if (role != "assistant" && !msg.contains("content")) {
- throw std::runtime_error("All non-assistant messages must contain 'content'");
- }
- if (role == "assistant") {
- if (!msg.contains("content") && !msg.contains("tool_calls")) {
- throw std::runtime_error("Assistant message must contain either 'content' or 'tool_calls'!");
- }
- if (!msg.contains("content")) {
- continue; // avoid errors with no content
- }
- }
- json & content = msg.at("content");
- if (content.is_string() || content.is_null()) {
- continue;
- }
-
- if (!content.is_array()) {
- throw std::runtime_error("Expected 'content' to be a string or an array");
- }
-
- for (auto & p : content) {
- std::string type = json_value(p, "type", std::string());
- if (type == "image_url") {
- if (!opt.allow_image) {
- throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
- }
-
- json image_url = json_value(p, "image_url", json::object());
- std::string url = json_value(image_url, "url", std::string());
- if (string_starts_with(url, "http")) {
- // download remote image
- // TODO @ngxson : maybe make these params configurable
- common_remote_params params;
- params.headers.push_back("User-Agent: llama.cpp/" + build_info);
- params.max_size = 1024 * 1024 * 10; // 10MB
- params.timeout = 10; // seconds
- SRV_INF("downloading image from '%s'\n", url.c_str());
- auto res = common_remote_get_content(url, params);
- if (200 <= res.first && res.first < 300) {
- SRV_INF("downloaded %ld bytes\n", res.second.size());
- raw_buffer data;
- data.insert(data.end(), res.second.begin(), res.second.end());
- out_files.push_back(data);
- } else {
- throw std::runtime_error("Failed to download image");
- }
-
- } else {
- // try to decode base64 image
- std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
- if (parts.size() != 2) {
- throw std::runtime_error("Invalid image_url.url value");
- } else if (!string_starts_with(parts[0], "data:image/")) {
- throw std::runtime_error("Invalid image_url.url format: " + parts[0]);
- } else if (!string_ends_with(parts[0], "base64")) {
- throw std::runtime_error("image_url.url must be base64 encoded");
- } else {
- auto base64_data = parts[1];
- auto decoded_data = base64_decode(base64_data);
- out_files.push_back(decoded_data);
- }
- }
-
- // replace this chunk with a marker
- p["type"] = "text";
- p["text"] = mtmd_default_marker();
- p.erase("image_url");
-
- } else if (type == "input_audio") {
- if (!opt.allow_audio) {
- throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
- }
-
- json input_audio = json_value(p, "input_audio", json::object());
- std::string data = json_value(input_audio, "data", std::string());
- std::string format = json_value(input_audio, "format", std::string());
- // while we also support flac, we don't allow it here so that we match the OAI spec
- if (format != "wav" && format != "mp3") {
- throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'");
- }
- auto decoded_data = base64_decode(data); // expected to be base64 encoded
- out_files.push_back(decoded_data);
-
- // replace this chunk with a marker
- p["type"] = "text";
- p["text"] = mtmd_default_marker();
- p.erase("input_audio");
-
- } else if (type != "text") {
- throw std::runtime_error("unsupported content[].type");
- }
- }
- }
-
- common_chat_templates_inputs inputs;
- inputs.messages = common_chat_msgs_parse_oaicompat(messages);
- inputs.tools = common_chat_tools_parse_oaicompat(tools);
- inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice);
- inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
- inputs.grammar = grammar;
- inputs.use_jinja = opt.use_jinja;
- inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
- inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
- inputs.reasoning_format = opt.reasoning_format;
- inputs.enable_thinking = opt.enable_thinking;
- if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
- if (body.contains("grammar")) {
- throw std::runtime_error("Cannot use custom grammar constraints with tools.");
- }
- llama_params["parse_tool_calls"] = true;
- }
-
- // merge the template args provided from command line with the args provided in the user request
- auto chat_template_kwargs_object = json_value(body, "chat_template_kwargs", json::object());
- inputs.chat_template_kwargs = opt.chat_template_kwargs;
- for (const auto & item : chat_template_kwargs_object.items()) {
- inputs.chat_template_kwargs[item.key()] = item.value().dump();
- }
-
- // parse the "enable_thinking" kwarg to override the default value
- auto enable_thinking_kwarg = json_value(inputs.chat_template_kwargs, "enable_thinking", std::string(""));
- if (enable_thinking_kwarg == "true") {
- inputs.enable_thinking = true;
- } else if (enable_thinking_kwarg == "false") {
- inputs.enable_thinking = false;
- } else if (!enable_thinking_kwarg.empty() && enable_thinking_kwarg[0] == '"') {
- throw std::runtime_error("invalid type for \"enable_thinking\" (expected boolean, got string)");
- }
-
- // if the assistant message appears at the end of the list, we do not add the end-of-turn token
- // for ex. this can be useful to modify the reasoning process in reasoning models
- bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && opt.prefill_assistant;
- common_chat_msg last_message;
- if (prefill_assistant_message) {
- last_message = inputs.messages.back();
- inputs.messages.pop_back();
-
- /* sanity check, max one assistant message at the end of the list */
- if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){
- throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list.");
- }
-
- /* TODO: test this properly */
- inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE;
-
- if ( inputs.enable_thinking ) {
- throw std::runtime_error("Assistant response prefill is incompatible with enable_thinking.");
- }
-
- inputs.add_generation_prompt = true;
- }
-
- // Apply chat template to the list of messages
- auto chat_params = common_chat_templates_apply(opt.tmpls, inputs);
-
- /* Append assistant prefilled message */
- if (prefill_assistant_message) {
- if (!last_message.content_parts.empty()) {
- for (auto & p : last_message.content_parts) {
- chat_params.prompt += p.text;
- }
- } else {
- chat_params.prompt += last_message.content;
- }
- }
-
- llama_params["chat_format"] = static_cast<int>(chat_params.format);
- llama_params["prompt"] = chat_params.prompt;
- if (!chat_params.grammar.empty()) {
- llama_params["grammar"] = chat_params.grammar;
- }
- llama_params["grammar_lazy"] = chat_params.grammar_lazy;
- auto grammar_triggers = json::array();
- for (const auto & trigger : chat_params.grammar_triggers) {
- server_grammar_trigger ct(trigger);
- grammar_triggers.push_back(ct.to_json());
- }
- llama_params["grammar_triggers"] = grammar_triggers;
- llama_params["preserved_tokens"] = chat_params.preserved_tokens;
- llama_params["thinking_forced_open"] = chat_params.thinking_forced_open;
- for (const auto & stop : chat_params.additional_stops) {
- llama_params["stop"].push_back(stop);
- }
-
- // Handle "n" field
- int n_choices = json_value(body, "n", 1);
- if (n_choices != 1) {
- throw std::runtime_error("Only one completion choice is allowed");
- }
-
- // Handle "logprobs" field
- // TODO: The response format of this option is not yet OAI-compatible, but it seems like no one is really using it; we may need to fix it in the future
- if (json_value(body, "logprobs", false)) {
- if (has_tools && stream) {
- throw std::runtime_error("logprobs is not supported with tools + stream");
- }
- llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
- } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
- throw std::runtime_error("top_logprobs requires logprobs to be set to true");
- }
-
- // Copy remaining properties to llama_params
- // This allows the user to use llama.cpp-specific params like "mirostat", ... via the OAI endpoint.
- // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
- for (const auto & item : body.items()) {
- // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
- if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
- llama_params[item.key()] = item.value();
- }
- }
-
- return llama_params;
-}
-
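-// builds an OAI-compatible /v1/embeddings response, e.g.:
-//   {"model": ..., "object": "list", "usage": {"prompt_tokens": N, "total_tokens": N},
-//    "data": [{"embedding": [...], "index": 0, "object": "embedding"}, ...]}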
-static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) {
- json data = json::array();
- int32_t n_tokens = 0;
- int i = 0;
- for (const auto & elem : embeddings) {
- json embedding_obj;
-
- if (use_base64) {
- const auto& vec = json_value(elem, "embedding", json::array()).get<std::vector<float>>();
- const char* data_ptr = reinterpret_cast<const char*>(vec.data());
- size_t data_size = vec.size() * sizeof(float);
- embedding_obj = {
- {"embedding", base64::encode(data_ptr, data_size)},
- {"index", i++},
- {"object", "embedding"},
- {"encoding_format", "base64"}
- };
- } else {
- embedding_obj = {
- {"embedding", json_value(elem, "embedding", json::array())},
- {"index", i++},
- {"object", "embedding"}
- };
- }
- data.push_back(embedding_obj);
-
- n_tokens += json_value(elem, "tokens_evaluated", 0);
- }
-
- json res = json {
- {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
- {"object", "list"},
- {"usage", json {
- {"prompt_tokens", n_tokens},
- {"total_tokens", n_tokens}
- }},
- {"data", data}
- };
-
- return res;
-}
-
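-// builds a rerank response, sorted by score and truncated to top_n
-// Jina-style by default: {"model": ..., "object": "list", "usage": {...},
-//                         "results": [{"index": i, "relevance_score": s}, ...]}
-// TEI-style (is_tei_format): a bare array of {"index": i, "score": s[, "text": ...]}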
-static json format_response_rerank(
- const json & request,
- const json & ranks,
- bool is_tei_format,
- std::vector<std::string> & texts,
- int top_n) {
- int32_t n_tokens = 0;
- bool return_text = is_tei_format && json_value(request, "return_text", false);
- std::vector<json> elements; // Temporary vector to hold unsorted elements
- std::string score_label = is_tei_format ? "score" : "relevance_score";
- for (const auto & rank : ranks) {
- int index = json_value(rank, "index", 0);
- json elem = json{
- {"index", index},
- {score_label, json_value(rank, "score", 0.0)},
- };
- n_tokens += json_value(rank, "tokens_evaluated", 0);
- if (return_text) {
- elem["text"] = std::move(texts[index]);
- }
- elements.push_back(elem);
- }
-
- std::sort(elements.begin(), elements.end(), [score_label](const json& a, const json& b) {
- return json_value(a, score_label, 0.0) > json_value(b, score_label, 0.0);
- });
-
- elements.resize(std::min(top_n, (int)elements.size()));
- json results = elements;
-
- if (is_tei_format) return results;
-
- json res = json{
- {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
- {"object", "list"},
- {"usage", json{
- {"prompt_tokens", n_tokens},
- {"total_tokens", n_tokens}
- }},
- {"results", results}
- };
-
- return res;
-}
-
-static bool is_valid_utf8(const std::string & str) {
- const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
- const unsigned char* end = bytes + str.length();
-
- while (bytes < end) {
- if (*bytes <= 0x7F) {
- // 1-byte sequence (0xxxxxxx)
- bytes++;
- } else if ((*bytes & 0xE0) == 0xC0) {
- // 2-byte sequence (110xxxxx 10xxxxxx)
- if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80)
- return false;
- bytes += 2;
- } else if ((*bytes & 0xF0) == 0xE0) {
- // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx)
- if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80)
- return false;
- bytes += 3;
- } else if ((*bytes & 0xF8) == 0xF0) {
- // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx)
- if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 ||
- (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80)
- return false;
- bytes += 4;
- } else {
- // Invalid UTF-8 lead byte
- return false;
- }
- }
-
- return true;
-}
-
-static json format_tokenizer_response(const json & tokens) {
- return json {
- {"tokens", tokens}
- };
-}
-
-static json format_detokenized_response(const std::string & content) {
- return json {
- {"content", content}
- };
-}
-
-static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) {
- json data = json::array();
- for (const auto & lb : logit_bias) {
- data.push_back(json{
- {"bias", lb.bias},
- {"token", lb.token},
- });
- }
- return data;
-}
-
-static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
- std::vector<llama_token_data> cur;
- const auto * logits = llama_get_logits_ith(ctx, idx);
-
- const llama_model * model = llama_get_model(ctx);
- const llama_vocab * vocab = llama_model_get_vocab(model);
-
- const int n_vocab = llama_vocab_n_tokens(vocab);
-
- cur.resize(n_vocab);
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
- }
-
- // sort tokens by logits
- std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
- return a.logit > b.logit;
- });
-
- // apply softmax
- float max_l = cur[0].logit;
- float cum_sum = 0.0f;
- for (size_t i = 0; i < cur.size(); ++i) {
- float p = expf(cur[i].logit - max_l);
- cur[i].p = p;
- cum_sum += p;
- }
- for (size_t i = 0; i < cur.size(); ++i) {
- cur[i].p /= cum_sum;
- }
-
- return cur;
-}
-
-static bool are_lora_equal(
- const std::vector<common_adapter_lora_info> & l1,
- const std::vector<common_adapter_lora_info> & l2) {
- if (l1.size() != l2.size()) {
- return false;
- }
- for (size_t i = 0; i < l1.size(); ++i) {
- // we don't check lora.path to reduce the time complexity
- if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) {
- return false;
- }
- }
- return true;
-}
-
-// get the ids of all enabled loras
-static std::vector<size_t> lora_get_enabled_ids(const std::vector<common_adapter_lora_info> & loras) {
- std::vector<size_t> enabled_ids;
- for (size_t i = 0; i < loras.size(); ++i) {
- if (loras[i].scale > 0) {
- enabled_ids.push_back(i);
- }
- }
- return enabled_ids;
-}
-
-// check whether the given lora set has only aloras activated (empty => false)
-static bool lora_all_alora(const std::vector<common_adapter_lora_info> & loras) {
- bool found_alora = false;
- for (const auto & lora : loras) {
- if (lora.scale != 0) {
- if (llama_adapter_get_alora_n_invocation_tokens(lora.ptr) == 0) {
- return false;
- }
- found_alora = true;
- }
- }
- return found_alora;
-}
-
-// if the two sets of loras are different, they require a cache clear unless the
-// change is only from aloras to aloras.
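-// (i.e. the cache can only be kept when the current set is empty or alora-only
-//  and the next set is alora-only)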
-static bool lora_should_clear_cache(
- const std::vector<common_adapter_lora_info> & current,
- const std::vector<common_adapter_lora_info> & next) {
-
- // This should always be called after determining that the two sets are
- // _not_ equal. This assert is therefore some slightly wasted work and
- // should be safe to remove as long as this method is called correctly.
- GGML_ASSERT(!are_lora_equal(current, next));
-
- return (
- !(lora_get_enabled_ids(current).empty() || lora_all_alora(current)) ||
- !lora_all_alora(next));
-}
-
-// parse lora config from JSON request, returned a copy of lora_base with updated scale
-static std::vector<common_adapter_lora_info> parse_lora_request(
- const std::vector<common_adapter_lora_info> & lora_base,
- const json & data) {
- std::vector<common_adapter_lora_info> lora(lora_base);
- int max_idx = lora.size();
-
- // clear existing value
- for (auto & entry : lora) {
- entry.scale = 0.0f;
- }
-
- // set value
- for (const auto & entry : data) {
- int id = json_value(entry, "id", -1);
- float scale = json_value(entry, "scale", 0.0f);
- if (0 <= id && id < max_idx) {
- lora[id].scale = scale;
- } else {
- throw std::runtime_error("invalid adapter id");
- }
- }
-
- return lora;
-}
-
-//
-// utils for interacting with libmtmd
-// (may need to refactor in near future)
-//
-
-/**
- * server_tokens is a helper to manage the input tokens and image for the server.
- * it is made this way to simplify the logic of KV cache management.
- */
-struct server_tokens {
- bool has_mtmd = false;
-
-private: // disallow accessing these members directly, risking out-of-sync
-
- // map a **start** index in tokens to the image chunk
- // note: the order need to be in-sync with tokens
- std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;
-
- // list of tokens
- // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by a media chunk
- // otherwise, it is a normal text token
- // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
- // note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos
- llama_tokens tokens;
-
- // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
- // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
- // idx 0 1 2 3 4 5 6 7 8 9 10
- // pos 0 1 2 3 4 5 5 5 7 7 7
- // map_idx_to_media will contain: {5, img0}, {8, img1}
-
-public:
- server_tokens() = default;
- ~server_tokens() = default;
-
- // Prevent copying
- // TODO: server_tokens should be copyable - remove this:
- server_tokens(const server_tokens&) = delete;
- server_tokens& operator=(const server_tokens&) = delete;
-
- // Allow moving (usually implicitly generated if members are movable)
- server_tokens(server_tokens&&) = default;
- server_tokens& operator=(server_tokens&&) = default;
-
- // Allow accessing elements using [] operator
- llama_token operator[](size_t index) { return tokens[index]; }
- const llama_token& operator[](size_t index) const { return tokens[index]; }
-
- server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) {
- for (size_t i = 0; i < mtmd_chunks.size(); ++i) {
- push_back(mtmd_chunks[i]);
- }
- }
-
- server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
- }
-
- llama_pos pos_next() const {
- if (!has_mtmd) {
- return tokens.size();
- }
-
- llama_pos res = tokens.size();
-
- for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
- const auto & chunk = it->second;
- res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
- }
-
- return res;
- }
-
- // for debugging
- std::string str() const {
- std::ostringstream oss;
- oss << "tokens: ";
- for (size_t idx = 0; idx < tokens.size(); ++idx) {
- llama_token t = tokens[idx];
- oss << "idx:" << idx << " ";
- if (t == LLAMA_TOKEN_NULL) {
- oss << "<embd> ";
- } else {
- oss << t << " ";
- }
- }
- oss << "\n";
- oss << "image idx: ";
- for (const auto & it : map_idx_to_media) {
- oss << it.first << ", ";
- }
- return oss.str();
- }
-
- const mtmd::input_chunk_ptr & find_chunk(size_t idx) const {
- auto it = map_idx_to_media.find(idx);
- if (it != map_idx_to_media.end()) {
- return it->second;
- }
- throw std::runtime_error("Chunk not found");
- }
-
- void push_back(llama_token tok) {
- if (tok == LLAMA_TOKEN_NULL) {
- throw std::runtime_error("Invalid token");
- }
- tokens.emplace_back(tok);
- }
-
- // will create a copy of the chunk if it contains non-text data
- void push_back(const mtmd_input_chunk * chunk) {
- auto type = mtmd_input_chunk_get_type(chunk);
- if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
- GGML_ASSERT(has_mtmd);
- const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
- size_t start_idx = tokens.size();
- for (size_t i = 0; i < n_tokens; ++i) {
- tokens.emplace_back(LLAMA_TOKEN_NULL);
- }
- mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
- map_idx_to_media[start_idx] = std::move(new_chunk);
- } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
- size_t n_tokens;
- const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
- for (size_t i = 0; i < n_tokens; ++i) {
- push_back(text_tokens[i]);
- }
- } else {
- GGML_ABORT("Invalid chunk type");
- }
- }
-
- // appends server tokens, updates the media map. copies media chunks.
- void push_back(server_tokens & tokens) {
- size_t start_idx = size();
- for (size_t i = 0; i < tokens.size(); i++) {
- push_back(tokens[i]);
- }
- if (tokens.has_mtmd) {
- // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd.
- // We could also just check, but this will prevent silently dropping MTMD data.
- GGML_ASSERT(has_mtmd);
- for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ++it) {
- auto * chunk = it->second.get();
- mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
- map_idx_to_media[start_idx + it->first] = std::move(new_chunk);
- }
- }
- }
-
- // for compatibility with context shift and prompt truncation
- void insert(const llama_tokens & inp_tokens) {
- GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
- tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end());
- }
-
- // for compatibility with speculative decoding, ctx shift, slot save/load
- const llama_tokens & get_text_tokens() const {
- GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
- return tokens;
- }
-
- // for compatibility with speculative decoding
- void set_token(llama_pos pos, llama_token id) {
- GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
- tokens[pos] = id;
- }
-
- size_t size() const {
- return tokens.size();
- }
-
- bool empty() const {
- return tokens.empty();
- }
-
- void clear() {
- map_idx_to_media.clear();
- tokens.clear();
- }
-
- void keep_first(size_t n) {
- GGML_ASSERT(n <= tokens.size());
- if (has_mtmd) {
- if (n == tokens.size()) {
- return; // nothing to do
- }
- // we throw an error if we try to remove a token in the middle of an image
- // for ex. with input of 5 text tokens and 2 images:
- // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
- // n 1 2 3 4 5 6 7 8 9 10
- // allowed to resize ^ ^
- // disallowed to resize ^ ^ ^
- if (n > 0) {
- // make sure we never remove tokens in the middle of an image
- // note that the case where we keep a full image at the end is allowed:
- // tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] != LLAMA_TOKEN_NULL
- if (tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] == LLAMA_TOKEN_NULL) {
- find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk
- }
- }
- // remove all image chunks that are not used anymore
- for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) {
- size_t idx = it->first;
- if (idx >= n) {
- it = map_idx_to_media.erase(it);
- } else {
- ++it;
- }
- }
- }
- tokens.resize(n);
- }
-
- std::string detokenize(const llama_context * ctx, bool special) const {
- llama_tokens text_tokens;
- text_tokens.reserve(tokens.size());
- for (const auto & t : tokens) {
- if (t != LLAMA_TOKEN_NULL) {
- text_tokens.push_back(t);
- }
- }
- return common_detokenize(ctx, text_tokens, special);
- }
-
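- // length of the longest common prefix between this and b; at media positions, the
- // chunks must have the same id and token count to be considered equal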
- size_t get_common_prefix(const server_tokens & b) const {
- const size_t max_idx = std::min(tokens.size(), b.tokens.size());
-
- if (!has_mtmd) {
- for (size_t i = 0; i < max_idx; ++i) {
- if (tokens[i] == b.tokens[i]) {
- continue;
- }
-
- return i;
- }
-
- return max_idx;
- }
-
- for (size_t i = 0; i < max_idx; ++i) {
- const llama_token ai = tokens[i];
- const llama_token bi = b.tokens[i];
-
- if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
- const auto & a_chunk = find_chunk(i);
- const auto & b_chunk = b.find_chunk(i);
-
- GGML_ASSERT(a_chunk && b_chunk);
-
- const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
- const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());
-
- const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get());
- const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get());
-
- if (id_ai == id_bi && n_tok_a == n_tok_b) {
- GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen
- i += n_tok_a - 1; // will be +1 by the for loop
- continue;
- }
-
- return i;
- }
-
- if (ai == bi) {
- continue;
- }
-
- return i;
- }
-
- return max_idx; // all tokens are equal
- }
-
- // make sure all text tokens are within the vocab range
- bool validate(const struct llama_context * ctx) const {
- const llama_model * model = llama_get_model(ctx);
- const llama_vocab * vocab = llama_model_get_vocab(model);
- const int32_t n_vocab = llama_vocab_n_tokens(vocab);
-
- for (size_t i = 0; i < tokens.size(); ++i) {
- const auto & t = tokens[i];
- if (t == LLAMA_TOKEN_NULL) {
- try {
- const auto & chunk = find_chunk(i);
- size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get());
- i += n_tokens - 1; // will be +1 by the for loop
- } catch (const std::exception & e) {
- return false;
- }
- } else if (t < 0 || t >= n_vocab) {
- return false;
- }
- }
- return true;
- }
-
- // encode and decode the image chunk
- int32_t process_chunk(
- llama_context * ctx,
- mtmd_context * mctx,
- size_t idx,
- llama_pos pos,
- int32_t seq_id,
- size_t & n_tokens_out) const {
- const auto & chunk = find_chunk(idx);
- const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
- ? "image" : "audio";
- SRV_INF("processing %s...\n", name);
- int32_t n_batch = llama_n_batch(ctx);
- int64_t t0 = ggml_time_ms();
- llama_pos new_n_past; // unused for now
- int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
- chunk.get(),
- pos,
- seq_id,
- n_batch,
- true, // logits last
- &new_n_past);
- SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
- if (result != 0) {
- LOG_ERR("mtmd_helper_eval failed with status %d", result);
- n_tokens_out = 0;
- return result;
- }
- n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get());
- return 0;
- }
-};
-
-// Computes FNV-1a hash of the data
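-// (64-bit FNV-1a, returned as a decimal string; used to give media bitmaps a stable id for KV cache reuse)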
-static std::string fnv_hash(const uint8_t * data, size_t len) {
- const uint64_t fnv_prime = 0x100000001b3ULL;
- uint64_t hash = 0xcbf29ce484222325ULL;
-
- for (size_t i = 0; i < len; ++i) {
- hash ^= data[i];
- hash *= fnv_prime;
- }
- return std::to_string(hash);
-}
-
-static server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files) {
- mtmd::bitmaps bitmaps;
- for (auto & file : files) {
- mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size()));
- if (!bmp.ptr) {
- throw std::runtime_error("Failed to load image or audio file");
- }
- // calculate bitmap hash (for KV caching)
- std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
- bmp.set_id(hash.c_str());
- bitmaps.entries.push_back(std::move(bmp));
- }
- // process prompt
- std::vector<server_tokens> inputs;
- // multimodal
- mtmd_input_text inp_txt = {
- prompt.c_str(),
- /* add_special */ true,
- /* parse_special */ true,
- };
- mtmd::input_chunks chunks(mtmd_input_chunks_init());
- auto bitmaps_c_ptr = bitmaps.c_ptr();
- int32_t tokenized = mtmd_tokenize(mctx,
- chunks.ptr.get(),
- &inp_txt,
- bitmaps_c_ptr.data(),
- bitmaps_c_ptr.size());
- if (tokenized != 0) {
- throw std::runtime_error("Failed to tokenize prompt");
- }
- auto result = server_tokens(chunks, true);
- return result;
-}
-
-/**
- * break the input "prompt" object into multiple prompt if needed, then tokenize them
- * use tokenize_input_prompts() if the input could be an array.
- * this supports these cases:
- * - "prompt": "string"
- * - "prompt": [12, 34, 56]
- * - "prompt": [12, 34, "string", 56, 78]
- * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
- */
-static server_tokens tokenize_input_subprompt(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) {
- constexpr char JSON_STRING_PROMPT_KEY[] = "prompt_string";
- constexpr char JSON_MTMD_DATA_KEY[] = "multimodal_data";
- const bool has_mtmd = mctx != nullptr;
- if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
- // string or mixed
- llama_tokens tmp = tokenize_mixed(vocab, json_prompt, add_special, parse_special);
- return server_tokens(tmp, false);
- } else if (json_is_array_of_numbers(json_prompt)) {
- // array of tokens
- llama_tokens tmp = json_prompt.get<llama_tokens>();
- return server_tokens(tmp, false);
- } else if (json_prompt.contains(JSON_STRING_PROMPT_KEY)) {
- // JSON object with prompt key.
- if (json_prompt.contains(JSON_MTMD_DATA_KEY)) {
- if (!has_mtmd) {
- throw std::runtime_error("Multimodal data provided, but model does not support multimodal requests.");
- }
-
- // JSON object with prompt and multimodal key.
- std::vector<raw_buffer> files;
- for (const auto & entry : json_prompt.at(JSON_MTMD_DATA_KEY)) {
- files.push_back(base64_decode(entry));
- }
- return process_mtmd_prompt(mctx, json_prompt.at(JSON_STRING_PROMPT_KEY), files);
- } else {
- // Not multimodal, but contains a subobject.
- llama_tokens tmp = tokenize_mixed(vocab, json_prompt.at(JSON_STRING_PROMPT_KEY), add_special, parse_special);
- return server_tokens(tmp, false);
- }
- } else {
- throw std::runtime_error("\"prompt\" elements must be a string, a list of tokens, a JSON object containing a prompt string, or a list of mixed strings & tokens.");
- }
-}
-
-/**
- * break the input "prompt" object into multiple prompt if needed, then tokenize them
- * this supports these cases:
- * - "prompt": "string"
- * - "prompt": [12, 34, 56]
- * - "prompt": [12, 34, "string", 56, 78]
- * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
- * and multiple prompts (multi-tasks):
- * - "prompt": ["string1", "string2"]
- * - "prompt": ["string1", [12, 34, 56]]
- * - "prompt": [[12, 34, 56], [78, 90, 12]]
- * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}]
- */
-static std::vector<server_tokens> tokenize_input_prompts(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) {
- std::vector<server_tokens> result;
- if (json_prompt.is_array() && !json_is_array_and_contains_numbers(json_prompt)) {
- result.reserve(json_prompt.size());
- for (const auto & p : json_prompt) {
- result.push_back(tokenize_input_subprompt(vocab, mctx, p, add_special, parse_special));
- }
- } else {
- result.push_back(tokenize_input_subprompt(vocab, mctx, json_prompt, add_special, parse_special));
- }
- if (result.empty()) {
- throw std::runtime_error("\"prompt\" must not be empty");
- }
- return result;
-}
-
-// format rerank task: [BOS]query[EOS][SEP]doc[EOS].
-static server_tokens format_rerank(const struct llama_model * model, const struct llama_vocab * vocab, mtmd_context * mctx, const std::string & query, const std::string & doc) {
- server_tokens result = {};
-
- const char * rerank_prompt = llama_model_chat_template(model, "rerank");
-
- if (rerank_prompt != nullptr) {
- std::string prompt = rerank_prompt;
- string_replace_all(prompt, "{query}" , query);
- string_replace_all(prompt, "{document}", doc );
- server_tokens tokens = tokenize_input_subprompt(vocab, mctx, prompt, false, true);
- result.push_back(tokens);
- } else {
- // Get EOS token - use SEP token as fallback if EOS is not available
- server_tokens query_tokens = tokenize_input_subprompt(vocab, mctx, query, false, false);
- server_tokens doc_tokens = tokenize_input_subprompt(vocab, mctx, doc, false, false);
- llama_token eos_token = llama_vocab_eos(vocab);
- if (eos_token == LLAMA_TOKEN_NULL) {
- eos_token = llama_vocab_sep(vocab);
- }
-
- if (llama_vocab_get_add_bos(vocab)) {
- result.push_back(llama_vocab_bos(vocab));
- }
- result.push_back(query_tokens);
- if (llama_vocab_get_add_eos(vocab)) {
- result.push_back(eos_token);
- }
- if (llama_vocab_get_add_sep(vocab)) {
- result.push_back(llama_vocab_sep(vocab));
- }
- result.push_back(doc_tokens);
- if (llama_vocab_get_add_eos(vocab)) {
- result.push_back(eos_token);
- }
- }
-
- return result;
-}