"path to save slot kv cache (default: disabled)",
[](common_params & params, const std::string & value) {
params.slot_save_path = value;
+ if (!fs_is_directory(params.slot_save_path)) {
+ throw std::invalid_argument("not a directory: " + value);
+ }
// if doesn't end with DIRECTORY_SEPARATOR, add it
if (!params.slot_save_path.empty() && params.slot_save_path[params.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) {
params.slot_save_path += DIRECTORY_SEPARATOR;
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
+ add_opt(common_arg(
+     {"--media-path"}, "PATH",
+     "directory for loading local media files; files can be accessed via file:// URLs using relative paths (default: disabled)",
+     [](common_params & params, const std::string & value) {
+         // store the raw value first, then validate and normalize it in place
+         std::string & dir = params.media_path;
+         dir = value;
+         // only an existing directory is accepted
+         if (!fs_is_directory(dir)) {
+             throw std::invalid_argument("not a directory: " + value);
+         }
+         // make sure the stored path is terminated by DIRECTORY_SEPARATOR
+         if (!dir.empty() && dir.back() != DIRECTORY_SEPARATOR) {
+             dir += DIRECTORY_SEPARATOR;
+         }
+     }
+ ).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--models-dir"}, "PATH",
"directory containing models for the router server (default: disabled)",
// Validate if a filename is safe to use
// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
-bool fs_validate_filename(const std::string & filename) {
+bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
if (!filename.length()) {
// Empty filename invalid
return false;
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|| c == 0xFFFD // Replacement Character (UTF-8)
|| c == 0xFEFF // Byte Order Mark (BOM)
- || c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters
+ || c == ':' || c == '*' // Illegal characters
|| c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
return false;
}
+ if (!allow_subdirs && (c == '/' || c == '\\')) {
+ // Subdirectories not allowed, reject path separators
+ return false;
+ }
}
// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
#endif // _WIN32
}
+// Check whether `path` refers to an existing directory.
+// Returns false when the path does not exist, is not a directory, or its status cannot be
+// queried - the non-throwing std::filesystem overload is used, so no filesystem_error escapes.
+bool fs_is_directory(const std::string & path) {
+    std::error_code ec;
+    // is_directory() already yields false for a missing path, so no separate exists() check is needed
+    return std::filesystem::is_directory(std::filesystem::path(path), ec);
+}
+
std::string fs_get_cache_directory() {
std::string cache_directory = "";
auto ensure_trailing_slash = [](std::string p) {
bool log_json = false;
std::string slot_save_path;
+ std::string media_path; // path to directory for loading media files
float slot_prompt_similarity = 0.1f;
// Filesystem utils
//
-bool fs_validate_filename(const std::string & filename);
+bool fs_validate_filename(const std::string & filename, bool allow_subdirs = false);
bool fs_create_directory_with_parents(const std::string & path);
+bool fs_is_directory(const std::string & path);
std::string fs_get_cache_directory();
std::string fs_get_cache_file(const std::string & filename);
#include <random>
#include <sstream>
+#include <fstream>
json format_error_response(const std::string & message, const enum error_type type) {
std::string type_str;
return llama_params;
}
+// Load the media referenced by media_obj["url"] and append its raw bytes to out_files.
+// The source is selected by the URL prefix:
+//   - "http..."  : downloaded with common_remote_get_content (10MB size limit, 10 seconds timeout)
+//   - "file://"  : read from media_path + <path after "file://">; refused when media_path is empty
+//   - otherwise  : must have the form "data:image/...base64,<payload>", which is base64-decoded
+// media_path is expected to end with the directory separator (the --media-path handler in arg.cpp appends it)
+// Throws std::invalid_argument for malformed or disallowed requests (reported as an invalid request / 400
+// by the server's exception wrapper) and std::runtime_error when the remote download fails.
+static void handle_media(
+        std::vector<raw_buffer> & out_files,
+        json & media_obj,
+        const std::string & media_path) {
+    std::string url = json_value(media_obj, "url", std::string());
+    if (string_starts_with(url, "http")) {
+        // download remote image
+        // TODO @ngxson : maybe make these params configurable
+        common_remote_params params;
+        params.headers.push_back("User-Agent: llama.cpp/" + build_info);
+        params.max_size = 1024 * 1024 * 10; // 10MB
+        params.timeout  = 10; // seconds
+        SRV_INF("downloading image from '%s'\n", url.c_str());
+        auto res = common_remote_get_content(url, params);
+        if (200 <= res.first && res.first < 300) {
+            // container sizes are size_t: print with %zu (%ld is wrong where long is narrower than size_t)
+            SRV_INF("downloaded %zu bytes\n", res.second.size());
+            raw_buffer data;
+            data.insert(data.end(), res.second.begin(), res.second.end());
+            out_files.push_back(data);
+        } else {
+            throw std::runtime_error("Failed to download image");
+        }
+
+    } else if (string_starts_with(url, "file://")) {
+        if (media_path.empty()) {
+            throw std::invalid_argument("file:// URLs are not allowed unless --media-path is specified");
+        }
+        // load local image file
+        const std::string file_path = url.substr(7); // remove "file://"
+        if (!fs_validate_filename(file_path, true)) {
+            throw std::invalid_argument("file path is not allowed: " + file_path);
+        }
+        // media_path already carries its trailing separator, so plain concatenation is enough
+        const std::string full_path = media_path + file_path;
+        SRV_INF("loading image from local file '%s'\n", full_path.c_str());
+        std::ifstream file(full_path, std::ios::binary);
+        if (!file) {
+            throw std::invalid_argument("file does not exist or cannot be opened: " + file_path);
+        }
+        raw_buffer data;
+        data.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
+        out_files.push_back(data);
+
+    } else {
+        // try to decode base64 image
+        // a malformed data URL is a client mistake: throw invalid_argument so that it is reported
+        // as an invalid request instead of a server error
+        std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
+        if (parts.size() != 2) {
+            throw std::invalid_argument("Invalid url value");
+        } else if (!string_starts_with(parts[0], "data:image/")) {
+            throw std::invalid_argument("Invalid url format: " + parts[0]);
+        } else if (!string_ends_with(parts[0], "base64")) {
+            throw std::invalid_argument("url must be base64 encoded");
+        } else {
+            auto base64_data = parts[1];
+            auto decoded_data = base64_decode(base64_data);
+            out_files.push_back(decoded_data);
+        }
+    }
+}
+
// used by /chat/completions endpoint
json oaicompat_chat_params_parse(
json & body, /* openai api json semantics */
throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
}
- json image_url = json_value(p, "image_url", json::object());
- std::string url = json_value(image_url, "url", std::string());
- if (string_starts_with(url, "http")) {
- // download remote image
- // TODO @ngxson : maybe make these params configurable
- common_remote_params params;
- params.headers.push_back("User-Agent: llama.cpp/" + build_info);
- params.max_size = 1024 * 1024 * 10; // 10MB
- params.timeout = 10; // seconds
- SRV_INF("downloading image from '%s'\n", url.c_str());
- auto res = common_remote_get_content(url, params);
- if (200 <= res.first && res.first < 300) {
- SRV_INF("downloaded %ld bytes\n", res.second.size());
- raw_buffer data;
- data.insert(data.end(), res.second.begin(), res.second.end());
- out_files.push_back(data);
- } else {
- throw std::runtime_error("Failed to download image");
- }
-
- } else {
- // try to decode base64 image
- std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
- if (parts.size() != 2) {
- throw std::invalid_argument("Invalid image_url.url value");
- } else if (!string_starts_with(parts[0], "data:image/")) {
- throw std::invalid_argument("Invalid image_url.url format: " + parts[0]);
- } else if (!string_ends_with(parts[0], "base64")) {
- throw std::invalid_argument("image_url.url must be base64 encoded");
- } else {
- auto base64_data = parts[1];
- auto decoded_data = base64_decode(base64_data);
- out_files.push_back(decoded_data);
- }
- }
+ json image_url = json_value(p, "image_url", json::object());
+ handle_media(out_files, image_url, opt.media_path);
// replace this chunk with a marker
p["type"] = "text";
auto decoded_data = base64_decode(data); // expected to be base64 encoded
out_files.push_back(decoded_data);
+ // TODO: add audio_url support by reusing handle_media()
+
// replace this chunk with a marker
p["type"] = "text";
p["text"] = mtmd_default_marker();
bool allow_image;
bool allow_audio;
bool enable_thinking = true;
+ std::string media_path;
};
// used by /chat/completions endpoint
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
/* enable_thinking */ enable_thinking,
+ /* media_path */ params_base.media_path,
};
// print sample chat example to make it clear which template is used
try {
return func(req);
} catch (const std::invalid_argument & e) {
+ // treat invalid_argument as invalid request (400)
error = ERROR_TYPE_INVALID_REQUEST;
message = e.what();
} catch (const std::exception & e) {
+ // treat other exceptions as server error (500)
error = ERROR_TYPE_SERVER;
message = e.what();
} catch (...) {
assert res.status_code == 200
assert cors_header in res.headers
assert res.headers[cors_header] == cors_header_value
+
+
+@pytest.mark.parametrize(
+    "media_path, image_url, success",
+    [
+        (None, "file://mtmd/test-1.jpeg", False), # disabled media path, should fail
+        ("../../../tools", "file://mtmd/test-1.jpeg", True),
+        ("../../../tools", "file:////mtmd//test-1.jpeg", True), # should be the same file as above
+        ("../../../tools", "file://mtmd/notfound.jpeg", False), # non-existent file
+        ("../../../tools", "file://../mtmd/test-1.jpeg", False), # no directory traversal
+    ]
+)
+def test_local_media_file(media_path, image_url, success):
+    """Send an image_url chat request and expect 200 for the success cases, 400 otherwise."""
+    server = ServerPreset.tinygemma3()
+    server.media_path = media_path
+    server.start()
+    # single user message: a text part followed by the image reference under test
+    user_content = [
+        {"type": "text", "text": "test"},
+        {"type": "image_url", "image_url": {"url": image_url}},
+    ]
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": 1,
+        "messages": [{"role": "user", "content": user_content}],
+    })
+    expected_status = 200 if success else 400
+    assert res.status_code == expected_status
chat_template_file: str | None = None
server_path: str | None = None
mmproj_url: str | None = None
+ media_path: str | None = None
# session variables
process: subprocess.Popen | None = None
server_args.extend(["--chat-template-file", self.chat_template_file])
if self.mmproj_url:
server_args.extend(["--mmproj-url", self.mmproj_url])
+ if self.media_path:
+ server_args.extend(["--media-path", self.media_path])
args = [str(arg) for arg in [server_path, *server_args]]
print(f"tests: starting server with: {' '.join(args)}")