]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add docker protocol support for llama-server model loading (#15790)
authorEric Curtin <redacted>
Fri, 12 Sep 2025 15:31:50 +0000 (16:31 +0100)
committerGitHub <redacted>
Fri, 12 Sep 2025 15:31:50 +0000 (16:31 +0100)
To pull and run models via: llama-server -dr gemma3
Add some validators and sanitizers for Docker Model urls and metadata

Signed-off-by: Eric Curtin <redacted>
common/arg.cpp
common/common.h

index 406fbc2f06fe4d53b5357e9673360eb5db1a173a..6c293699a2760e47d2e34f021610fdaf7fb7c774 100644 (file)
@@ -745,6 +745,124 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
 
 #endif // LLAMA_USE_CURL
 
+//
+// Docker registry functions
+//
+
+static std::string common_docker_get_token(const std::string & repo) {
+    std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
+
+    common_remote_params params;
+    auto                 res = common_remote_get_content(url, params);
+
+    if (res.first != 200) {
+        throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
+    }
+
+    std::string            response_str(res.second.begin(), res.second.end());
+    nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
+
+    if (!response.contains("token")) {
+        throw std::runtime_error("Docker registry token response missing 'token' field");
+    }
+
+    return response["token"].get<std::string>();
+}
+
+static std::string common_docker_resolve_model(const std::string & docker) {
+    // Parse ai/smollm2:135M-Q4_K_M
+    size_t      colon_pos = docker.find(':');
+    std::string repo, tag;
+    if (colon_pos != std::string::npos) {
+        repo = docker.substr(0, colon_pos);
+        tag  = docker.substr(colon_pos + 1);
+    } else {
+        repo = docker;
+        tag  = "latest";
+    }
+
+    // ai/ is the default
+    size_t      slash_pos = docker.find('/');
+    if (slash_pos == std::string::npos) {
+        repo.insert(0, "ai/");
+    }
+
+    LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str());
+    try {
+        // --- helper: digest validation ---
+        auto validate_oci_digest = [](const std::string & digest) -> std::string {
+            // Expected: algo:hex ; start with sha256 (64 hex chars)
+            // You can extend this map if supporting other algorithms in future.
+            static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
+            std::smatch m;
+            if (!std::regex_match(digest, m, re)) {
+                throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
+            }
+            // normalize hex to lowercase
+            std::string normalized = digest;
+            std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){
+                return std::tolower(c);
+            });
+            return normalized;
+        };
+
+        std::string token = common_docker_get_token(repo);  // Get authentication token
+
+        // Get manifest
+        const std::string    url_prefix = "https://registry-1.docker.io/v2/" + repo;
+        std::string          manifest_url = url_prefix + "/manifests/" + tag;
+        common_remote_params manifest_params;
+        manifest_params.headers.push_back("Authorization: Bearer " + token);
+        manifest_params.headers.push_back(
+            "Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
+        auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
+        if (manifest_res.first != 200) {
+            throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
+        }
+
+        std::string            manifest_str(manifest_res.second.begin(), manifest_res.second.end());
+        nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
+        std::string            gguf_digest;  // Find the GGUF layer
+        if (manifest.contains("layers")) {
+            for (const auto & layer : manifest["layers"]) {
+                if (layer.contains("mediaType")) {
+                    std::string media_type = layer["mediaType"].get<std::string>();
+                    if (media_type == "application/vnd.docker.ai.gguf.v3" ||
+                        media_type.find("gguf") != std::string::npos) {
+                        gguf_digest = layer["digest"].get<std::string>();
+                        break;
+                    }
+                }
+            }
+        }
+
+        if (gguf_digest.empty()) {
+            throw std::runtime_error("No GGUF layer found in Docker manifest");
+        }
+
+        // Validate & normalize digest
+        gguf_digest = validate_oci_digest(gguf_digest);
+        LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str());
+
+        // Prepare local filename
+        std::string model_filename = repo;
+        std::replace(model_filename.begin(), model_filename.end(), '/', '_');
+        model_filename += "_" + tag + ".gguf";
+        std::string local_path = fs_get_cache_file(model_filename);
+
+        const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
+        if (!common_download_file_single(blob_url, local_path, token, false)) {
+            throw std::runtime_error("Failed to download Docker Model");
+        }
+
+        LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
+        return local_path;
+    } catch (const std::exception & e) {
+        LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
+        throw;
+    }
+}
+
 //
 // utils
 //
@@ -795,7 +913,9 @@ static handle_model_result common_params_handle_model(
     handle_model_result result;
     // handle pre-fill default model path and url based on hf_repo and hf_file
     {
-        if (!model.hf_repo.empty()) {
+        if (!model.docker_repo.empty()) {  // Handle Docker URLs by resolving them to local paths
+            model.path = common_docker_resolve_model(model.docker_repo);
+        } else if (!model.hf_repo.empty()) {
             // short-hand to avoid specifying --hf-file -> default it to --model
             if (model.hf_file.empty()) {
                 if (model.path.empty()) {
@@ -2636,6 +2756,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.model.url = value;
         }
     ).set_env("LLAMA_ARG_MODEL_URL"));
+    add_opt(common_arg(
+        { "-dr", "--docker-repo" }, "[<repo>/]<model>[:quant]",
+        "Docker Hub model repository. repo is optional, default to ai/. quant is optional, default to :latest.\n"
+        "example: gemma3\n"
+        "(default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.model.docker_repo = value;
+        }
+    ).set_env("LLAMA_ARG_DOCKER_REPO"));
     add_opt(common_arg(
         {"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
         "Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"
index 740168743aec94027741a44759a53895714fcb0f..cf57d48415bd1ea9d51b8dfa0b218ec0224438d0 100644 (file)
@@ -193,10 +193,11 @@ struct common_params_sampling {
 };
 
 struct common_params_model {
-    std::string path    = ""; // model local path                                           // NOLINT
-    std::string url     = ""; // model url to download                                      // NOLINT
-    std::string hf_repo = ""; // HF repo                                                    // NOLINT
-    std::string hf_file = ""; // HF file                                                    // NOLINT
+    std::string path        = ""; // model local path                                       // NOLINT
+    std::string url         = ""; // model url to download                                  // NOLINT
+    std::string hf_repo     = ""; // HF repo                                                // NOLINT
+    std::string hf_file     = ""; // HF file                                                // NOLINT
+    std::string docker_repo = ""; // Docker repo                                            // NOLINT
 };
 
 struct common_params_speculative {