]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
common : add common_remote_get_content (#13123)
authorXuan-Son Nguyen <redacted>
Sat, 26 Apr 2025 20:58:12 +0000 (22:58 +0200)
committerGitHub <redacted>
Sat, 26 Apr 2025 20:58:12 +0000 (22:58 +0200)
* common : add common_remote_get_content

* support max size and timeout

* add tests

common/arg.cpp
common/arg.h
tests/test-arg-parser.cpp

index 0657553e4e9cf645ef9d3a4932d374613208e08f..de173159f4a76b9a30676658687794e083dc5abf 100644 (file)
@@ -162,6 +162,10 @@ struct common_hf_file_res {
 
 #ifdef LLAMA_USE_CURL
 
+bool common_has_curl() {
+    return true;
+}
+
 #ifdef __linux__
 #include <linux/limits.h>
 #elif defined(_WIN32)
@@ -527,64 +531,89 @@ static bool common_download_model(
     return true;
 }
 
-/**
- * Allow getting the HF file from the HF repo with tag (like ollama), for example:
- * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
- * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
- * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
- * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
- *
- * Return pair of <repo, file> (with "repo" already having tag removed)
- *
- * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
- */
-static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token) {
-    auto parts = string_split<std::string>(hf_repo_with_tag, ':');
-    std::string tag = parts.size() > 1 ? parts.back() : "latest";
-    std::string hf_repo = parts[0];
-    if (string_split<std::string>(hf_repo, '/').size() != 2) {
-        throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
-    }
-
-    // fetch model info from Hugging Face Hub API
+std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params) {
     curl_ptr       curl(curl_easy_init(), &curl_easy_cleanup);
     curl_slist_ptr http_headers;
-    std::string res_str;
+    std::vector<char> res_buffer;
 
-    std::string model_endpoint = get_model_endpoint();
-
-    std::string url = model_endpoint + "v2/" + hf_repo + "/manifests/" + tag;
     curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
     curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
+    curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
     typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
     auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
-        static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb);
+        auto data_vec = static_cast<std::vector<char> *>(data);
+        data_vec->insert(data_vec->end(), (char *)ptr, (char *)ptr + size * nmemb);
         return size * nmemb;
     };
     curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
-    curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
+    curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_buffer);
 #if defined(_WIN32)
     curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
 #endif
-    if (!bearer_token.empty()) {
-        std::string auth_header = "Authorization: Bearer " + bearer_token;
-        http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
+    if (params.timeout > 0) {
+        curl_easy_setopt(curl.get(), CURLOPT_TIMEOUT, params.timeout);
+    }
+    if (params.max_size > 0) {
+        curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size);
     }
-    // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
     http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
-    http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json");
+    for (const auto & header : params.headers) {
+        http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str());
+    }
     curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
 
     CURLcode res = curl_easy_perform(curl.get());
 
     if (res != CURLE_OK) {
-        throw std::runtime_error("error: cannot make GET request to HF API");
+        std::string error_msg = curl_easy_strerror(res);
+        throw std::runtime_error("error: cannot make GET request: " + error_msg);
     }
 
     long res_code;
-    std::string ggufFile   = "";
-    std::string mmprojFile = "";
     curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
+
+    return { res_code, std::move(res_buffer) };
+}
+
+/**
+ * Allow getting the HF file from the HF repo with tag (like ollama), for example:
+ * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
+ * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
+ * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
+ * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
+ *
+ * Return pair of <repo, file> (with "repo" already having tag removed)
+ *
+ * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
+ */
+static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token) {
+    auto parts = string_split<std::string>(hf_repo_with_tag, ':');
+    std::string tag = parts.size() > 1 ? parts.back() : "latest";
+    std::string hf_repo = parts[0];
+    if (string_split<std::string>(hf_repo, '/').size() != 2) {
+        throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
+    }
+
+    std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;
+
+    // headers
+    std::vector<std::string> headers;
+    headers.push_back("Accept: application/json");
+    if (!bearer_token.empty()) {
+        headers.push_back("Authorization: Bearer " + bearer_token);
+    }
+    // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
+    // User-Agent header is already set in common_remote_get_content, no need to set it here
+
+    // make the request
+    common_remote_params params;
+    params.headers = headers;
+    auto res = common_remote_get_content(url, params);
+    long res_code = res.first;
+    std::string res_str(res.second.data(), res.second.size());
+    std::string ggufFile;
+    std::string mmprojFile;
+
     if (res_code == 200) {
         // extract ggufFile.rfilename in json, using regex
         {
@@ -618,6 +647,10 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_
 
 #else
 
+bool common_has_curl() {
+    return false;
+}
+
 static bool common_download_file_single(const std::string &, const std::string &, const std::string &) {
     LOG_ERR("error: built without CURL, cannot download model from internet\n");
     return false;
@@ -640,6 +673,10 @@ static struct common_hf_file_res common_get_hf_file(const std::string &, const s
     return {};
 }
 
+std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params) {
+    throw std::runtime_error("error: built without CURL, cannot download model from the internet");
+}
+
 #endif // LLAMA_USE_CURL
 
 //
index 49ab8667b10527c05d1c6aed3593ad78170518c0..70bea100fd4f268dbb31e84e3c1dae664102f8c2 100644 (file)
@@ -78,3 +78,12 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
 
 // function to be used by test-arg-parser
 common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
+bool common_has_curl();
+
+struct common_remote_params {
+    std::vector<std::string> headers;
+    long timeout = 0; // CURLOPT_TIMEOUT, in seconds ; 0 means no timeout
+    long max_size = 0; // max size of the response ; unlimited if 0 ; max is 2GB
+};
+// get remote file content, returns <http_code, raw_response_body>
+std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params);
index 537fc63a4c97521c14c3cb76b9af5571f7c41779..21dbd5404222f6fe373dadfed464aea3664a9c48 100644 (file)
@@ -126,6 +126,53 @@ int main(void) {
     assert(params.cpuparams.n_threads == 1010);
 #endif // _WIN32
 
+    if (common_has_curl()) {
+        printf("test-arg-parser: test curl-related functions\n\n");
+        const char * GOOD_URL = "https://raw.githubusercontent.com/ggml-org/llama.cpp/refs/heads/master/README.md";
+        const char * BAD_URL  = "https://www.google.com/404";
+        const char * BIG_FILE = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v1.bin";
+
+        {
+            printf("test-arg-parser: test good URL\n\n");
+            auto res = common_remote_get_content(GOOD_URL, {});
+            assert(res.first == 200);
+            assert(res.second.size() > 0);
+            std::string str(res.second.data(), res.second.size());
+            assert(str.find("llama.cpp") != std::string::npos);
+        }
+
+        {
+            printf("test-arg-parser: test bad URL\n\n");
+            auto res = common_remote_get_content(BAD_URL, {});
+            assert(res.first == 404);
+        }
+
+        {
+            printf("test-arg-parser: test max size error\n");
+            common_remote_params params;
+            params.max_size = 1;
+            try {
+                common_remote_get_content(GOOD_URL, params);
+                assert(false && "it should throw an error");
+            } catch (std::exception & e) {
+                printf("  expected error: %s\n\n", e.what());
+            }
+        }
+
+        {
+            printf("test-arg-parser: test timeout error\n");
+            common_remote_params params;
+            params.timeout = 1;
+            try {
+                common_remote_get_content(BIG_FILE, params);
+                assert(false && "it should throw an error");
+            } catch (std::exception & e) {
+                printf("  expected error: %s\n\n", e.what());
+            }
+        }
+    } else {
+        printf("test-arg-parser: no curl, skipping curl-related functions\n");
+    }
 
     printf("test-arg-parser: all tests OK\n\n");
 }