]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
feat: Implements retrying logic for downloading models using --model-url flag (#9255)
authorFarbod Bijary <redacted>
Wed, 11 Sep 2024 09:22:37 +0000 (12:52 +0330)
committerGitHub <redacted>
Wed, 11 Sep 2024 09:22:37 +0000 (11:22 +0200)
* feat: Implements retrying logic for downloading models using --model-url flag

* Update common/common.cpp

Co-authored-by: Xuan Son Nguyen <redacted>
* Update common/common.cpp

Co-authored-by: Xuan Son Nguyen <redacted>
* apply comments

* implements a retry function to avoid duplication

* fix editorconfig

* change function name

---------

Co-authored-by: farbod <redacted>
Co-authored-by: Xuan Son Nguyen <redacted>
Co-authored-by: slaren <redacted>
Co-authored-by: Xuan Son Nguyen <redacted>
common/common.cpp
lora-tests [new submodule]

index d572d2408270324ba84a602e9c57418b1d6973e0..30c6e84c795f71d7b2e8bcd11c14c2773bffd018 100644 (file)
@@ -941,11 +941,37 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
 
 #ifdef LLAMA_USE_CURL
 
+#define CURL_MAX_RETRY 3
+#define CURL_RETRY_DELAY_SECONDS 2
+
+
 static bool starts_with(const std::string & str, const std::string & prefix) {
     // While we wait for C++20's std::string::starts_with...
     return str.rfind(prefix, 0) == 0;
 }
 
+static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_attempts, int retry_delay_seconds) {
+    int remaining_attempts = max_attempts;
+
+    while (remaining_attempts > 0) {
+        fprintf(stderr, "%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);
+
+        CURLcode res = curl_easy_perform(curl);
+        if (res == CURLE_OK) {
+            return true;
+        }
+
+        int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000;
+        fprintf(stderr, "%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay);
+
+        remaining_attempts--;
+        std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
+    }
+
+    fprintf(stderr, "%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
+    return false;
+}
+
 static bool llama_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
 
     // Initialize libcurl
@@ -1049,9 +1075,8 @@ static bool llama_download_file(const std::string & url, const std::string & pat
         curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
         curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
 
-        CURLcode res = curl_easy_perform(curl.get());
-        if (res != CURLE_OK) {
-            fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
+        bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
+        if (!was_perform_successful) {
             return false;
         }
 
@@ -1126,11 +1151,10 @@ static bool llama_download_file(const std::string & url, const std::string & pat
         };
 
         // start the download
-        fprintf(stderr, "%s: downloading from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
-                llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
-        auto res = curl_easy_perform(curl.get());
-        if (res != CURLE_OK) {
-            fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
+        fprintf(stderr, "%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
+            llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
+        bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
+        if (!was_perform_successful) {
             return false;
         }
 
diff --git a/lora-tests b/lora-tests
new file mode 160000 (submodule)
index 0000000..c26d5fb
--- /dev/null
@@ -0,0 +1 @@
+Subproject commit c26d5fb85b4070a9e9c4e65d132c783b98086890