#endif // LLAMA_USE_CURL
+//
+// Docker registry functions
+//
+
+static std::string common_docker_get_token(const std::string & repo) {
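+    // request an anonymous pull token from Docker Hub's auth service; public repositories need no credentials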
+    std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
+
+    common_remote_params params;
+    auto res = common_remote_get_content(url, params);
+
+    if (res.first != 200) {
+        throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
+    }
+
+    std::string response_str(res.second.begin(), res.second.end());
+    nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
+
+    if (!response.contains("token")) {
+        throw std::runtime_error("Docker registry token response missing 'token' field");
+    }
+
+    return response["token"].get<std::string>();
+}
+
+static std::string common_docker_resolve_model(const std::string & docker) {
+    // Parse <repo>[:tag], e.g. ai/smollm2:135M-Q4_K_M
+    size_t colon_pos = docker.find(':');
+    std::string repo, tag;
+    if (colon_pos != std::string::npos) {
+        repo = docker.substr(0, colon_pos);
+        tag = docker.substr(colon_pos + 1);
+    } else {
+        repo = docker;
+        tag = "latest";
+    }
+
+    // ai/ is the default namespace when none is given
+    // (search the repo part only, so a '/' in the tag cannot suppress the default)
+    size_t slash_pos = repo.find('/');
+    if (slash_pos == std::string::npos) {
+        repo.insert(0, "ai/");
+    }
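+    // e.g. "gemma3" is now repo "ai/gemma3" with tag "latest"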
+
+ LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str());
+ try {
+ // --- helper: digest validation ---
+ auto validate_oci_digest = [](const std::string & digest) -> std::string {
+ // Expected: algo:hex ; start with sha256 (64 hex chars)
+ // You can extend this map if supporting other algorithms in future.
+ static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
+ std::smatch m;
+ if (!std::regex_match(digest, m, re)) {
+ throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
+ }
+ // normalize hex to lowercase
+ std::string normalized = digest;
+ std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){
+ return std::tolower(c);
+ });
+ return normalized;
+ };
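+        // validating here also keeps untrusted manifest data from being spliced into the blob URL built below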
+
+        std::string token = common_docker_get_token(repo); // Get authentication token
+
+        // Get manifest
+        const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
+        std::string manifest_url = url_prefix + "/manifests/" + tag;
+        common_remote_params manifest_params;
+        manifest_params.headers.push_back("Authorization: Bearer " + token);
+        manifest_params.headers.push_back(
+            "Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
+        auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
+        if (manifest_res.first != 200) {
+            throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
+        }
+
+        std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
+        nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
+
+        // Find the GGUF layer
+        std::string gguf_digest;
+        if (manifest.contains("layers")) {
+            for (const auto & layer : manifest["layers"]) {
+                if (layer.contains("mediaType")) {
+                    std::string media_type = layer["mediaType"].get<std::string>();
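+                    // Docker model artifacts mark GGUF layers with a dedicated media type;
+                    // fall back to any media type containing "gguf" to stay permissive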
+                    if (media_type == "application/vnd.docker.ai.gguf.v3" ||
+                        media_type.find("gguf") != std::string::npos) {
+                        gguf_digest = layer["digest"].get<std::string>();
+                        break;
+                    }
+                }
+            }
+        }
+
+        if (gguf_digest.empty()) {
+            throw std::runtime_error("No GGUF layer found in Docker manifest");
+        }
+
+        // Validate & normalize digest
+        gguf_digest = validate_oci_digest(gguf_digest);
+        LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str());
+
+        // Prepare local filename
+        std::string model_filename = repo;
+        std::replace(model_filename.begin(), model_filename.end(), '/', '_');
+        model_filename += "_" + tag + ".gguf";
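+        // e.g. "ai/smollm2" with tag "135M-Q4_K_M" -> "ai_smollm2_135M-Q4_K_M.gguf"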
+        std::string local_path = fs_get_cache_file(model_filename);
+
+        const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
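+        // reuse the shared single-file download helper, passing the registry token as a bearer credential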
+        if (!common_download_file_single(blob_url, local_path, token, false)) {
+            throw std::runtime_error("Failed to download Docker Model");
+        }
+
+        LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
+        return local_path;
+    } catch (const std::exception & e) {
+        LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
+        throw;
+    }
+}
+
//
// utils
//
    handle_model_result result;
    // handle pre-fill default model path and url based on hf_repo and hf_file
    {
-        if (!model.hf_repo.empty()) {
+        // handle Docker model references by resolving them to local paths
+        if (!model.docker_repo.empty()) {
+            model.path = common_docker_resolve_model(model.docker_repo);
+        } else if (!model.hf_repo.empty()) {
            // short-hand to avoid specifying --hf-file -> default it to --model
            if (model.hf_file.empty()) {
                if (model.path.empty()) {
            params.model.url = value;
        }
    ).set_env("LLAMA_ARG_MODEL_URL"));
+    add_opt(common_arg(
+        { "-dr", "--docker-repo" }, "[<repo>/]<model>[:quant]",
+        "Docker Hub model repository. repo is optional, defaults to ai/; quant is optional, defaults to :latest.\n"
+        "example: gemma3\n"
+        "(default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.model.docker_repo = value;
+        }
+    ).set_env("LLAMA_ARG_DOCKER_REPO"));
    add_opt(common_arg(
        {"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
        "Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"