git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Implement s3:// protocol (#11511)
author Eric Curtin <redacted>
Sat, 1 Feb 2025 10:30:54 +0000 (11:30 +0100)
committer GitHub <redacted>
Sat, 1 Feb 2025 10:30:54 +0000 (10:30 +0000)
For those that want to pull from s3

Signed-off-by: Eric Curtin <redacted>
examples/run/run.cpp

index 9cecae48c2d5b2ec2e36b5fb4a9ff576ec0c2235..cf61f4add3528a3467a63b1af449997bdbf6583f 100644 (file)
@@ -65,6 +65,13 @@ static int printe(const char * fmt, ...) {
     return ret;
 }
 
+// Renders the broken-down time `tm` as text using the strftime-style pattern `fmt`
+// (e.g. "%Y%m%d") and returns the result as a std::string.
+static std::string strftime_fmt(const char * fmt, const std::tm & tm) {
+    std::ostringstream formatted;
+
+    // std::put_time applies `fmt` to `tm` when inserted into the stream
+    formatted << std::put_time(&tm, fmt);
+    return formatted.str();
+}
+
 class Opt {
   public:
     int init(int argc, const char ** argv) {
@@ -698,6 +705,39 @@ class LlamaData {
         return download(url, bn, true);
     }
 
+    int s3_dl(const std::string & model, const std::string & bn) {
+        const size_t slash_pos = model.find('/');
+        if (slash_pos == std::string::npos) {
+            return 1;
+        }
+
+        const std::string bucket     = model.substr(0, slash_pos);
+        const std::string key        = model.substr(slash_pos + 1);
+        const char * access_key = std::getenv("AWS_ACCESS_KEY_ID");
+        const char * secret_key = std::getenv("AWS_SECRET_ACCESS_KEY");
+        if (!access_key || !secret_key) {
+            printe("AWS credentials not found in environment\n");
+            return 1;
+        }
+
+        // Generate AWS Signature Version 4 headers
+        // (Implementation requires HMAC-SHA256 and date handling)
+        // Get current timestamp
+        const time_t                   now     = time(nullptr);
+        const tm                       tm      = *gmtime(&now);
+        const std::string              date     = strftime_fmt("%Y%m%d", tm);
+        const std::string              datetime = strftime_fmt("%Y%m%dT%H%M%SZ", tm);
+        const std::vector<std::string> headers  = {
+            "Authorization: AWS4-HMAC-SHA256 Credential=" + std::string(access_key) + "/" + date +
+                "/us-east-1/s3/aws4_request",
+            "x-amz-content-sha256: UNSIGNED-PAYLOAD", "x-amz-date: " + datetime
+        };
+
+        const std::string url = "https://" + bucket + ".s3.amazonaws.com/" + key;
+
+        return download(url, bn, true, headers);
+    }
+
     std::string basename(const std::string & path) {
         const size_t pos = path.find_last_of("/\\");
         if (pos == std::string::npos) {
@@ -738,6 +778,9 @@ class LlamaData {
             rm_until_substring(model_, "github:");
             rm_until_substring(model_, "://");
             ret = github_dl(model_, bn);
+        } else if (string_starts_with(model_, "s3://")) {
+            rm_until_substring(model_, "://");
+            ret = s3_dl(model_, bn);
         } else {  // ollama:// or nothing
             rm_until_substring(model_, "ollama.com/library/");
             rm_until_substring(model_, "://");