]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
common: llama_load_model_from_url using --model-url (#6098)
authorPierrick Hymbert <redacted>
Sun, 17 Mar 2024 18:12:37 +0000 (19:12 +0100)
committerGitHub <redacted>
Sun, 17 Mar 2024 18:12:37 +0000 (19:12 +0100)
* common: llama_load_model_from_url with libcurl dependency

Co-authored-by: Georgi Gerganov <redacted>
16 files changed:
.github/workflows/build.yml
.github/workflows/server.yml
CMakeLists.txt
Makefile
common/CMakeLists.txt
common/common.cpp
common/common.h
examples/main/README.md
examples/server/README.md
examples/server/server.cpp
examples/server/tests/README.md
examples/server/tests/features/embeddings.feature
examples/server/tests/features/environment.py
examples/server/tests/features/server.feature
examples/server/tests/features/steps/steps.py
examples/server/tests/requirements.txt

index 0da01d5ba6eadd1359ab9a0ec10c292ed1f89fe5..945df42f886a60946f3983ade7c9a8c57832bbaa 100644 (file)
@@ -48,6 +48,28 @@ jobs:
           CC=gcc-8 make tests -j $(nproc)
           make test -j $(nproc)
 
+  ubuntu-focal-make-curl:
+    runs-on: ubuntu-20.04
+
+    steps:
+      - name: Clone
+        id: checkout
+        uses: actions/checkout@v3
+
+      - name: Dependencies
+        id: depends
+        run: |
+          sudo apt-get update
+          sudo apt-get install build-essential gcc-8 libcurl4-openssl-dev
+
+      - name: Build
+        id: make_build
+        env:
+          LLAMA_FATAL_WARNINGS: 1
+          LLAMA_CURL: 1
+        run: |
+          CC=gcc-8 make -j $(nproc)
+
   ubuntu-latest-cmake:
     runs-on: ubuntu-latest
 
index 5e38b3547c659197c977ff7e59b4f677f7dee625..4ea09115a3c441ccd8dae044b70c7a549552af5a 100644 (file)
@@ -57,7 +57,8 @@ jobs:
             cmake \
             python3-pip \
             wget \
-            language-pack-en
+            language-pack-en \
+            libcurl4-openssl-dev
 
       - name: Build
         id: cmake_build
@@ -67,6 +68,7 @@ jobs:
           cmake .. \
               -DLLAMA_NATIVE=OFF \
               -DLLAMA_BUILD_SERVER=ON \
+              -DLLAMA_CURL=ON \
               -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
               -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ;
           cmake --build . --config ${{ matrix.build_type }} -j $(nproc) --target server
@@ -101,12 +103,21 @@ jobs:
         with:
           fetch-depth: 0
 
+      - name: libCURL
+        id: get_libcurl
+        env:
+          CURL_VERSION: 8.6.0_6
+        run: |
+          curl.exe -o $env:RUNNER_TEMP/curl.zip -L "https://curl.se/windows/dl-${env:CURL_VERSION}/curl-${env:CURL_VERSION}-win64-mingw.zip"
+          mkdir $env:RUNNER_TEMP/libcurl
+          tar.exe -xvf $env:RUNNER_TEMP/curl.zip --strip-components=1 -C $env:RUNNER_TEMP/libcurl
+
       - name: Build
         id: cmake_build
         run: |
           mkdir build
           cd build
-          cmake ..  -DLLAMA_BUILD_SERVER=ON -DCMAKE_BUILD_TYPE=Release ;
+          cmake .. -DLLAMA_CURL=ON -DCURL_LIBRARY="$env:RUNNER_TEMP/libcurl/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:RUNNER_TEMP/libcurl/include"
           cmake --build . --config Release -j ${env:NUMBER_OF_PROCESSORS} --target server
 
       - name: Python setup
@@ -120,6 +131,11 @@ jobs:
         run: |
           pip install -r examples/server/tests/requirements.txt
 
+      - name: Copy Libcurl
+        id: prepare_libcurl
+        run: |
+          cp $env:RUNNER_TEMP/libcurl/bin/libcurl-x64.dll ./build/bin/Release/libcurl-x64.dll
+
       - name: Tests
         id: server_integration_tests
         if: ${{ !matrix.disabled_on_pr || !github.event.pull_request }}
index 3ac2804a6881adac476770f9bfdc632c70bd77c7..fc4cff28f44ac047b29a9b088ee741e6d75875c9 100644 (file)
@@ -99,6 +99,7 @@ option(LLAMA_CUDA_F16                        "llama: use 16 bit floats for some
 set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
 set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
                                              "llama: max. batch size for using peer access")
+option(LLAMA_CURL                            "llama: use libcurl to download model from an URL" OFF)
 option(LLAMA_HIPBLAS                         "llama: use hipBLAS"                               OFF)
 option(LLAMA_HIP_UMA                         "llama: use HIP unified memory architecture"       OFF)
 option(LLAMA_CLBLAST                         "llama: use CLBlast"                               OFF)
index c0f1250366a64e09b52a3074a84f70a995b96d66..838daf5c02acd89d8878eb6320e9de9d826b6ec5 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -595,6 +595,11 @@ include scripts/get-flags.mk
 CUDA_CXXFLAGS := $(BASE_CXXFLAGS) $(GF_CXXFLAGS) -Wno-pedantic
 endif
 
+ifdef LLAMA_CURL
+override CXXFLAGS := $(CXXFLAGS) -DLLAMA_USE_CURL
+override LDFLAGS  := $(LDFLAGS) -lcurl
+endif
+
 #
 # Print build information
 #
index 350bbdf7f7b1bd7164ed907cbe48652b1c1b21f4..af2629a460b938cfaba5e6cc3cd492f9c31b0627 100644 (file)
@@ -68,6 +68,17 @@ if (BUILD_SHARED_LIBS)
     set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
 endif()
 
+set(LLAMA_COMMON_EXTRA_LIBS build_info)
+
+# Use curl to download model url
+if (LLAMA_CURL)
+    find_package(CURL REQUIRED)
+    add_definitions(-DLLAMA_USE_CURL)
+    include_directories(${CURL_INCLUDE_DIRS})
+    find_library(CURL_LIBRARY curl REQUIRED)
+    set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY})
+endif ()
+
 target_include_directories(${TARGET} PUBLIC .)
 target_compile_features(${TARGET} PUBLIC cxx_std_11)
-target_link_libraries(${TARGET} PRIVATE build_info PUBLIC llama)
+target_link_libraries(${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama)
index 1b0ba849398debb3e37117b84eaf572dd50865e3..2f5d965d6511c26cd12092239585a676c3c67574 100644 (file)
@@ -37,6 +37,9 @@
 #include <sys/stat.h>
 #include <unistd.h>
 #endif
+#if defined(LLAMA_USE_CURL)
+#include <curl/curl.h>
+#endif
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #define GGML_USE_CUBLAS_SYCL_VULKAN
 #endif
 
+#if defined(LLAMA_USE_CURL)
+#ifdef __linux__
+#include <linux/limits.h>
+#elif defined(_WIN32)
+#define PATH_MAX MAX_PATH
+#else
+#include <sys/syslimits.h>
+#endif
+#define LLAMA_CURL_MAX_PATH_LENGTH PATH_MAX
+#define LLAMA_CURL_MAX_HEADER_LENGTH 256
+#endif // LLAMA_USE_CURL
+
 int32_t get_num_physical_cores() {
 #ifdef __linux__
     // enumerate the set of thread siblings, num entries is num cores
@@ -644,6 +659,13 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             }
             params.model = argv[i];
         }
+        if (arg == "-mu" || arg == "--model-url") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.model_url = argv[i];
+        }
         if (arg == "-md" || arg == "--model-draft") {
             arg_found = true;
             if (++i >= argc) {
@@ -1368,6 +1390,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        layer range to apply the control vector(s) to, start and end inclusive\n");
     printf("  -m FNAME, --model FNAME\n");
     printf("                        model path (default: %s)\n", params.model.c_str());
+    printf("  -mu MODEL_URL, --model-url MODEL_URL\n");
+    printf("                        model download url (default: %s)\n", params.model_url.c_str());
     printf("  -md FNAME, --model-draft FNAME\n");
     printf("                        draft model for speculative decoding\n");
     printf("  -ld LOGDIR, --logdir LOGDIR\n");
@@ -1613,10 +1637,222 @@ void llama_batch_add(
     batch.n_tokens++;
 }
 
+#ifdef LLAMA_USE_CURL
+
+struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model,
+                                              struct llama_model_params params) {
+    // Basic validation of the model_url
+    if (!model_url || strlen(model_url) == 0) {
+        fprintf(stderr, "%s: invalid model_url\n", __func__);
+        return NULL;
+    }
+
+    // Initialize libcurl globally
+    auto curl = curl_easy_init();
+
+    if (!curl) {
+        fprintf(stderr, "%s: error initializing libcurl\n", __func__);
+        return NULL;
+    }
+
+    // Set the URL, allow to follow http redirection
+    curl_easy_setopt(curl, CURLOPT_URL, model_url);
+    curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
+#if defined(_WIN32)
+    // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
+    //   operating system. Currently implemented under MS-Windows.
+    curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
+#endif
+
+    // Check if the file already exists locally
+    struct stat model_file_info;
+    auto file_exists = (stat(path_model, &model_file_info) == 0);
+
+    // If the file exists, check for ${path_model}.etag or ${path_model}.lastModified files
+    char etag[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
+    char etag_path[LLAMA_CURL_MAX_PATH_LENGTH] = {0};
+    snprintf(etag_path, sizeof(etag_path), "%s.etag", path_model);
+
+    char last_modified[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
+    char last_modified_path[LLAMA_CURL_MAX_PATH_LENGTH] = {0};
+    snprintf(last_modified_path, sizeof(last_modified_path), "%s.lastModified", path_model);
+
+    if (file_exists) {
+        auto * f_etag = fopen(etag_path, "r");
+        if (f_etag) {
+            if (!fgets(etag, sizeof(etag), f_etag)) {
+                fprintf(stderr, "%s: unable to read file %s\n", __func__, etag_path);
+            } else {
+                fprintf(stderr, "%s: previous model file found %s: %s\n", __func__, etag_path, etag);
+            }
+            fclose(f_etag);
+        }
+
+        auto * f_last_modified = fopen(last_modified_path, "r");
+        if (f_last_modified) {
+            if (!fgets(last_modified, sizeof(last_modified), f_last_modified)) {
+                fprintf(stderr, "%s: unable to read file %s\n", __func__, last_modified_path);
+            } else {
+                fprintf(stderr, "%s: previous model file found %s: %s\n", __func__, last_modified_path,
+                        last_modified);
+            }
+            fclose(f_last_modified);
+        }
+    }
+
+    // Send a HEAD request to retrieve the etag and last-modified headers
+    struct llama_load_model_from_url_headers {
+        char etag[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
+        char last_modified[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
+    };
+    llama_load_model_from_url_headers headers;
+    {
+        typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
+        auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
+            llama_load_model_from_url_headers *headers = (llama_load_model_from_url_headers *) userdata;
+
+            const char * etag_prefix = "etag: ";
+            if (strncmp(buffer, etag_prefix, strlen(etag_prefix)) == 0) {
+                strncpy(headers->etag, buffer + strlen(etag_prefix), n_items - strlen(etag_prefix) - 2); // Remove CRLF
+            }
+
+            const char * last_modified_prefix = "last-modified: ";
+            if (strncmp(buffer, last_modified_prefix, strlen(last_modified_prefix)) == 0) {
+                strncpy(headers->last_modified, buffer + strlen(last_modified_prefix),
+                        n_items - strlen(last_modified_prefix) - 2); // Remove CRLF
+            }
+            return n_items;
+        };
+
+        curl_easy_setopt(curl, CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
+        curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L); // hide head request progress
+        curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
+        curl_easy_setopt(curl, CURLOPT_HEADERDATA, &headers);
+
+        CURLcode res = curl_easy_perform(curl);
+        if (res != CURLE_OK) {
+            curl_easy_cleanup(curl);
+            fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
+            return NULL;
+        }
+
+        long http_code = 0;
+        curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code);
+        if (http_code != 200) {
+            // HEAD not supported, we don't know if the file has changed
+            // force trigger downloading
+            file_exists = false;
+            fprintf(stderr, "%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
+        }
+    }
+
+    // If the ETag or the Last-Modified headers are different: trigger a new download
+    if (!file_exists || strcmp(etag, headers.etag) != 0 || strcmp(last_modified, headers.last_modified) != 0) {
+        char path_model_temporary[LLAMA_CURL_MAX_PATH_LENGTH] = {0};
+        snprintf(path_model_temporary, sizeof(path_model_temporary), "%s.downloadInProgress", path_model);
+        if (file_exists) {
+            fprintf(stderr, "%s: deleting previous downloaded model file: %s\n", __func__, path_model);
+            if (remove(path_model) != 0) {
+                curl_easy_cleanup(curl);
+                fprintf(stderr, "%s: unable to delete file: %s\n", __func__, path_model);
+                return NULL;
+            }
+        }
+
+        // Set the output file
+        auto * outfile = fopen(path_model_temporary, "wb");
+        if (!outfile) {
+            curl_easy_cleanup(curl);
+            fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path_model);
+            return NULL;
+        }
+
+        typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
+        auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
+            return fwrite(data, size, nmemb, (FILE *)fd);
+        };
+        curl_easy_setopt(curl, CURLOPT_NOBODY, 0L);
+        curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
+        curl_easy_setopt(curl, CURLOPT_WRITEDATA, outfile);
+
+        //  display download progress
+        curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
+
+        // start the download
+        fprintf(stderr, "%s: downloading model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
+                model_url, path_model, headers.etag, headers.last_modified);
+        auto res = curl_easy_perform(curl);
+        if (res != CURLE_OK) {
+            fclose(outfile);
+            curl_easy_cleanup(curl);
+            fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
+            return NULL;
+        }
+
+        long http_code = 0;
+        curl_easy_getinfo (curl, CURLINFO_RESPONSE_CODE, &http_code);
+        if (http_code < 200 || http_code >= 400) {
+            fclose(outfile);
+            curl_easy_cleanup(curl);
+            fprintf(stderr, "%s: invalid http status code received: %ld\n", __func__, http_code);
+            return NULL;
+        }
+
+        // Clean up
+        fclose(outfile);
+
+        // Write the new ETag to the .etag file
+        if (strlen(headers.etag) > 0) {
+            auto * etag_file = fopen(etag_path, "w");
+            if (etag_file) {
+                fputs(headers.etag, etag_file);
+                fclose(etag_file);
+                fprintf(stderr, "%s: model etag saved %s: %s\n", __func__, etag_path, headers.etag);
+            }
+        }
+
+        // Write the new lastModified to the .etag file
+        if (strlen(headers.last_modified) > 0) {
+            auto * last_modified_file = fopen(last_modified_path, "w");
+            if (last_modified_file) {
+                fputs(headers.last_modified, last_modified_file);
+                fclose(last_modified_file);
+                fprintf(stderr, "%s: model last modified saved %s: %s\n", __func__, last_modified_path,
+                        headers.last_modified);
+            }
+        }
+
+        if (rename(path_model_temporary, path_model) != 0) {
+            curl_easy_cleanup(curl);
+            fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_model_temporary, path_model);
+            return NULL;
+        }
+    }
+
+    curl_easy_cleanup(curl);
+
+    return llama_load_model_from_file(path_model, params);
+}
+
+#else
+
+struct llama_model * llama_load_model_from_url(const char * /*model_url*/, const char * /*path_model*/,
+                                              struct llama_model_params /*params*/) {
+    fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
+    return nullptr;
+}
+
+#endif // LLAMA_USE_CURL
+
 std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
     auto mparams = llama_model_params_from_gpt_params(params);
 
-    llama_model * model  = llama_load_model_from_file(params.model.c_str(), mparams);
+    llama_model * model = nullptr;
+    if (!params.model_url.empty()) {
+        model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
+    } else {
+        model = llama_load_model_from_file(params.model.c_str(), mparams);
+    }
     if (model == NULL) {
         fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
         return std::make_tuple(nullptr, nullptr);
index 687f3425e8544c804895512e83ee7701fcd266db..8dd8a3edc9c9401fcea5d28563297640ff5039e1 100644 (file)
@@ -89,6 +89,7 @@ struct gpt_params {
     struct llama_sampling_params sparams;
 
     std::string model             = "models/7B/ggml-model-f16.gguf"; // model path
+    std::string model_url         = ""; // model url to download
     std::string model_draft       = "";                              // draft model for speculative decoding
     std::string model_alias       = "unknown"; // model alias
     std::string prompt            = "";
@@ -191,6 +192,9 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 struct llama_model_params   llama_model_params_from_gpt_params  (const gpt_params & params);
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
 
+struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model,
+                                                         struct llama_model_params     params);
+
 // Batch utils
 
 void llama_batch_clear(struct llama_batch & batch);
index 7f84e4262327414709ef1bc57adbe95c311edaa7..6a8d1e1c50cbb52f69d099c0026c046a1ca98ee8 100644 (file)
@@ -67,6 +67,7 @@ main.exe -m models\7B\ggml-model.bin --ignore-eos -n -1 --random-prompt
 In this section, we cover the most commonly used options for running the `main` program with the LLaMA models:
 
 -   `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`).
+-   `-mu MODEL_URL --model-url MODEL_URL`: Specify a remote http url to download the file (e.g https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf).
 -   `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses.
 -   `-ins, --instruct`: Run the program in instruction mode, which is particularly useful when working with Alpaca models.
 -   `-n N, --n-predict N`: Set the number of tokens to predict when generating text. Adjusting this value can influence the length of the generated text.
index 8f8454affaecd766cf5946b12f3946b97bd78adb..755e1d5384f55eb1e0e87726a1b10dbef117b83a 100644 (file)
@@ -20,6 +20,7 @@ The project is under active development, and we are [looking for feedback and co
 - `-tb N, --threads-batch N`: Set the number of threads to use during batch and prompt processing. If not specified, the number of threads will be set to the number of threads used for generation.
 - `--threads-http N`: number of threads in the http server pool to process requests (default: `max(std::thread::hardware_concurrency() - 1, --parallel N + 2)`)
 - `-m FNAME`, `--model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.gguf`).
+- `-mu MODEL_URL --model-url MODEL_URL`: Specify a remote http url to download the file (e.g https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf).
 - `-a ALIAS`, `--alias ALIAS`: Set an alias for the model. The alias will be returned in API responses.
 - `-c N`, `--ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference. The size may differ in other models, for example, baichuan models were build with a context of 4096.
 - `-ngl N`, `--n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
index 895d608fdcc06b4cc7fad28bc8f0c41da5520aa3..d2a8e541d3305c9cd53f4314db53c90c06650ab5 100644 (file)
@@ -2195,6 +2195,8 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
     }
     printf("  -m FNAME, --model FNAME\n");
     printf("                            model path (default: %s)\n", params.model.c_str());
+    printf("  -mu MODEL_URL, --model-url MODEL_URL\n");
+    printf("                            model download url (default: %s)\n", params.model_url.c_str());
     printf("  -a ALIAS, --alias ALIAS\n");
     printf("                            set an alias for the model, will be added as `model` field in completion response\n");
     printf("  --lora FNAME              apply LoRA adapter (implies --no-mmap)\n");
@@ -2317,6 +2319,12 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
                 break;
             }
             params.model = argv[i];
+        } else if (arg == "-mu" || arg == "--model-url") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.model_url = argv[i];
         } else if (arg == "-a" || arg == "--alias") {
             if (++i >= argc) {
                 invalid_param = true;
index 95a0353b6a9c5ba1a188f82373a13521b717d39b..feb2b1d6cf5de8618e7b1e26f5a439b5f97da617 100644 (file)
@@ -57,7 +57,7 @@ Feature or Scenario must be annotated with `@llama.cpp` to be included in the de
 To run a scenario annotated with `@bug`, start:
 
 ```shell
-DEBUG=ON ./tests.sh --no-skipped --tags bug
+DEBUG=ON ./tests.sh --no-skipped --tags bug --stop
 ```
 
 After changing logic in `steps.py`, ensure that `@bug` and `@wrong_usage` scenario are updated.
index 57359b267a668b3c6089dd3d62c2a8dc4c61e847..dcf1434f97124121ba8751f6bda97be600ae28ba 100644 (file)
@@ -4,7 +4,8 @@ Feature: llama.cpp server
 
   Background: Server startup
     Given a server listening on localhost:8080
-    And   a model file bert-bge-small/ggml-model-f16.gguf from HF repo ggml-org/models
+    And   a model url https://huggingface.co/ggml-org/models/resolve/main/bert-bge-small/ggml-model-f16.gguf
+    And   a model file ggml-model-f16.gguf
     And   a model alias bert-bge-small
     And   42 as server seed
     And   2 slots
index 8ad987e1bb6183a5f913289862a989a054dc4258..82104e9202e5e05e227e6f111025b0bf49077979 100644 (file)
@@ -1,10 +1,12 @@
-import errno
 import os
+import signal
 import socket
-import subprocess
+import sys
 import time
+import traceback
 from contextlib import closing
-import signal
+
+import psutil
 
 
 def before_scenario(context, scenario):
@@ -20,33 +22,40 @@ def before_scenario(context, scenario):
 
 
 def after_scenario(context, scenario):
-    if context.server_process is None:
-        return
-    if scenario.status == "failed":
-        if 'GITHUB_ACTIONS' in os.environ:
-            print(f"\x1b[33;101mSCENARIO FAILED: {scenario.name} server logs:\x1b[0m\n\n")
-            if os.path.isfile('llama.log'):
-                with closing(open('llama.log', 'r')) as f:
-                    for line in f:
-                        print(line)
-        if not is_server_listening(context.server_fqdn, context.server_port):
-            print("\x1b[33;101mERROR: Server stopped listening\x1b[0m\n")
-
-    if not pid_exists(context.server_process.pid):
-        assert False, f"Server not running pid={context.server_process.pid} ..."
-
-    server_graceful_shutdown(context)
-
-    # Wait few for socket to free up
-    time.sleep(0.05)
-
-    attempts = 0
-    while pid_exists(context.server_process.pid) or is_server_listening(context.server_fqdn, context.server_port):
-        server_kill(context)
-        time.sleep(0.1)
-        attempts += 1
-        if attempts > 5:
-            server_kill_hard(context)
+    try:
+        if 'server_process' not in context or context.server_process is None:
+            return
+        if scenario.status == "failed":
+            if 'GITHUB_ACTIONS' in os.environ:
+                print(f"\x1b[33;101mSCENARIO FAILED: {scenario.name} server logs:\x1b[0m\n\n")
+                if os.path.isfile('llama.log'):
+                    with closing(open('llama.log', 'r')) as f:
+                        for line in f:
+                            print(line)
+            if not is_server_listening(context.server_fqdn, context.server_port):
+                print("\x1b[33;101mERROR: Server stopped listening\x1b[0m\n")
+
+        if not pid_exists(context.server_process.pid):
+            assert False, f"Server not running pid={context.server_process.pid} ..."
+
+        server_graceful_shutdown(context)
+
+        # Wait few for socket to free up
+        time.sleep(0.05)
+
+        attempts = 0
+        while pid_exists(context.server_process.pid) or is_server_listening(context.server_fqdn, context.server_port):
+            server_kill(context)
+            time.sleep(0.1)
+            attempts += 1
+            if attempts > 5:
+                server_kill_hard(context)
+    except:
+        exc = sys.exception()
+        print("error in after scenario: \n")
+        print(exc)
+        print("*** print_tb: \n")
+        traceback.print_tb(exc.__traceback__, file=sys.stdout)
 
 
 def server_graceful_shutdown(context):
@@ -67,11 +76,11 @@ def server_kill_hard(context):
     path = context.server_path
 
     print(f"Server dangling exits, hard killing force {pid}={path}...\n")
-    if os.name == 'nt':
-        process = subprocess.check_output(['taskkill', '/F', '/pid', str(pid)]).decode()
-        print(process)
-    else:
-        os.kill(-pid, signal.SIGKILL)
+    try:
+        psutil.Process(pid).kill()
+    except psutil.NoSuchProcess:
+        return False
+    return True
 
 
 def is_server_listening(server_fqdn, server_port):
@@ -84,17 +93,9 @@ def is_server_listening(server_fqdn, server_port):
 
 
 def pid_exists(pid):
-    """Check whether pid exists in the current process table."""
-    if pid < 0:
+    try:
+        psutil.Process(pid)
+    except psutil.NoSuchProcess:
         return False
-    if os.name == 'nt':
-        output = subprocess.check_output(['TASKLIST', '/FI', f'pid eq {pid}']).decode()
-        print(output)
-        return "No tasks are running" not in output
-    else:
-        try:
-            os.kill(pid, 0)
-        except OSError as e:
-            return e.errno == errno.EPERM
-        else:
-            return True
+    return True
+
index 5014f326dc050c1aa2987ebb3411f215b2e10e7b..7448986e75a496110147e6fce045210d4cdd15b1 100644 (file)
@@ -4,7 +4,8 @@ Feature: llama.cpp server
 
   Background: Server startup
     Given a server listening on localhost:8080
-    And   a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
+    And   a model url https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories260K.gguf
+    And   a model file stories260K.gguf
     And   a model alias tinyllama-2
     And   42 as server seed
       # KV Cache corresponds to the total amount of tokens
index a59a52d21748a3b03dc7c7641afbbcae774599e2..9e348d5fc4c37f294ea1d50d0ca776dd688a78cb 100644 (file)
@@ -5,6 +5,8 @@ import os
 import re
 import socket
 import subprocess
+import sys
+import threading
 import time
 from contextlib import closing
 from re import RegexFlag
@@ -32,6 +34,8 @@ def step_server_config(context, server_fqdn, server_port):
     context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
 
     context.model_alias = None
+    context.model_file = None
+    context.model_url = None
     context.n_batch = None
     context.n_ubatch = None
     context.n_ctx = None
@@ -65,6 +69,16 @@ def step_download_hf_model(context, hf_file, hf_repo):
         print(f"model file: {context.model_file}\n")
 
 
+@step('a model file {model_file}')
+def step_model_file(context, model_file):
+    context.model_file = model_file
+
+
+@step('a model url {model_url}')
+def step_model_url(context, model_url):
+    context.model_url = model_url
+
+
 @step('a model alias {model_alias}')
 def step_model_alias(context, model_alias):
     context.model_alias = model_alias
@@ -141,7 +155,8 @@ def step_start_server(context):
 async def step_wait_for_the_server_to_be_started(context, expecting_status):
     match expecting_status:
         case 'healthy':
-            await wait_for_health_status(context, context.base_url, 200, 'ok')
+            await wait_for_health_status(context, context.base_url, 200, 'ok',
+                                         timeout=30)
 
         case 'ready' | 'idle':
             await wait_for_health_status(context, context.base_url, 200, 'ok',
@@ -1038,8 +1053,11 @@ def start_server_background(context):
     server_args = [
         '--host', server_listen_addr,
         '--port', context.server_port,
-        '--model', context.model_file
     ]
+    if context.model_file:
+        server_args.extend(['--model', context.model_file])
+    if context.model_url:
+        server_args.extend(['--model-url', context.model_url])
     if context.n_batch:
         server_args.extend(['--batch-size', context.n_batch])
     if context.n_ubatch:
@@ -1079,8 +1097,23 @@ def start_server_background(context):
 
     pkwargs = {
         'creationflags': flags,
+        'stdout': subprocess.PIPE,
+        'stderr': subprocess.PIPE
     }
     context.server_process = subprocess.Popen(
         [str(arg) for arg in [context.server_path, *server_args]],
         **pkwargs)
+
+    def log_stdout(process):
+        for line in iter(process.stdout.readline, b''):
+            print(line.decode('utf-8'), end='')
+    thread_stdout = threading.Thread(target=log_stdout, args=(context.server_process,))
+    thread_stdout.start()
+
+    def log_stderr(process):
+        for line in iter(process.stderr.readline, b''):
+            print(line.decode('utf-8'), end='', file=sys.stderr)
+    thread_stderr = threading.Thread(target=log_stderr, args=(context.server_process,))
+    thread_stderr.start()
+
     print(f"server pid={context.server_process.pid}, behave pid={os.getpid()}")
index 2e4f42ad28c233fd2e3cec00f2a33aecb3925f5a..c2c960102b52346fe704a382ae2c8428b3966c7f 100644 (file)
@@ -3,4 +3,5 @@ behave~=1.2.6
 huggingface_hub~=0.20.3
 numpy~=1.24.4
 openai~=0.25.0
+psutil~=5.9.8
 prometheus-client~=0.20.0