]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : add lora hotswap endpoint (WIP) (#8857)
authorXuan Son Nguyen <redacted>
Tue, 6 Aug 2024 15:33:39 +0000 (17:33 +0200)
committerGitHub <redacted>
Tue, 6 Aug 2024 15:33:39 +0000 (17:33 +0200)
* server : add lora hotswap endpoint

* handle lora_no_apply

* fix build

* updae docs

* clean up struct def

* fix build

* add LoRA test

* fix style

common/common.cpp
common/common.h
examples/export-lora/export-lora.cpp
examples/server/README.md
examples/server/server.cpp
examples/server/tests/features/lora.feature [new file with mode: 0644]
examples/server/tests/features/steps/steps.py
examples/server/tests/requirements.txt
gguf-py/gguf/metadata.py

index ee7fbcba3c7973996583a2bad899adcda06a17d4..2e8374d50cafad99cb9f59eccb4ac5c22ed6b32b 100644 (file)
@@ -684,14 +684,24 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     }
     if (arg == "--lora") {
         CHECK_ARG
-        params.lora_adapter.emplace_back(argv[i], 1.0f);
+        params.lora_adapters.push_back({
+            std::string(argv[i]),
+            1.0,
+        });
         return true;
     }
     if (arg == "--lora-scaled") {
         CHECK_ARG
-        const char* lora_adapter = argv[i];
+        std::string lora_adapter = argv[i];
         CHECK_ARG
-        params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
+        params.lora_adapters.push_back({
+            lora_adapter,
+            std::stof(argv[i]),
+        });
+        return true;
+    }
+    if (arg == "--lora-init-without-apply") {
+        params.lora_init_without_apply = true;
         return true;
     }
     if (arg == "--control-vector") {
@@ -1654,6 +1664,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
                                                                         "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
     options.push_back({ "server",      "-sps,  --slot-prompt-similarity SIMILARITY",
                                                                         "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity });
+    options.push_back({ "server",      "       --lora-init-without-apply",     "load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"});
 
 #ifndef LOG_DISABLE_LOGS
     options.push_back({ "logging" });
@@ -2091,17 +2102,22 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         }
     }
 
-    for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
-        const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
-        float lora_scale = std::get<1>(params.lora_adapter[i]);
-        auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
-        if (adapter == nullptr) {
-            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+    // load and optionally apply lora adapters
+    for (auto & la : params.lora_adapters) {
+        llama_lora_adapter_container loaded_la;
+        loaded_la.path = la.path;
+        loaded_la.scale = la.scale;
+        loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
+        if (loaded_la.adapter == nullptr) {
+            fprintf(stderr, "%s: error: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
             llama_free(lctx);
             llama_free_model(model);
             return iparams;
         }
-        llama_lora_adapter_set(lctx, adapter, lora_scale);
+        iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters
+    }
+    if (!params.lora_init_without_apply) {
+        llama_lora_adapters_apply(lctx, iparams.lora_adapters);
     }
 
     if (params.ignore_eos) {
@@ -2140,6 +2156,15 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
     return iparams;
 }
 
+void llama_lora_adapters_apply(struct llama_context * ctx, std::vector<llama_lora_adapter_container> & lora_adapters) {
+    llama_lora_adapter_clear(ctx);
+    for (auto & la : lora_adapters) {
+        if (la.scale != 0.0f) {
+            llama_lora_adapter_set(ctx, la.adapter, la.scale);
+        }
+    }
+}
+
 struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
     auto mparams = llama_model_default_params();
 
@@ -3162,19 +3187,18 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     }
 
     fprintf(stream, "lora:\n");
-    for (std::tuple<std::string, float> la : params.lora_adapter) {
-        if (std::get<1>(la) != 1.0f) {
-            continue;
+    for (auto & la : params.lora_adapters) {
+        if (la.scale == 1.0f) {
+            fprintf(stream, "  - %s\n", la.path.c_str());
         }
-        fprintf(stream, "  - %s\n", std::get<0>(la).c_str());
     }
     fprintf(stream, "lora_scaled:\n");
-    for (std::tuple<std::string, float> la : params.lora_adapter) {
-        if (std::get<1>(la) == 1.0f) {
-            continue;
+    for (auto & la : params.lora_adapters) {
+        if (la.scale != 1.0f) {
+            fprintf(stream, "  - %s: %f\n", la.path.c_str(), la.scale);
         }
-        fprintf(stream, "  - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la));
     }
+    fprintf(stream, "lora_init_without_apply: %s # default: false\n", params.lora_init_without_apply ? "true" : "false");
     fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
     fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep);
     fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
index 51dcc0d3993f70ed8e96a2568ff7c5dec7c6d7ad..d88966ece20aa73da262d047e6c0137e7a0e972f 100644 (file)
 
 #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
 
+struct llama_lora_adapter_info {
+    std::string path;
+    float scale;
+};
+
+struct llama_lora_adapter_container : llama_lora_adapter_info {
+    struct llama_lora_adapter * adapter;
+};
+
 // build info
 extern int LLAMA_BUILD_NUMBER;
 extern char const * LLAMA_COMMIT;
@@ -126,8 +135,8 @@ struct gpt_params {
     std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
     std::vector<llama_model_kv_override> kv_overrides;
 
-    // TODO: avoid tuple, use struct
-    std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
+    bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
+    std::vector<llama_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale
 
     std::vector<llama_control_vector_load_info> control_vectors; // control vector with user defined scale
 
@@ -309,8 +318,9 @@ std::string fs_get_cache_file(const std::string & filename);
 //
 
 struct llama_init_result {
-    struct llama_model * model = nullptr;
+    struct llama_model   * model   = nullptr;
     struct llama_context * context = nullptr;
+    std::vector<llama_lora_adapter_container> lora_adapters;
 };
 
 struct llama_init_result    llama_init_from_gpt_params(gpt_params & params);
@@ -321,6 +331,9 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
 struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params);
 struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params);
 
+// clear LoRA adapters from context, then apply new list of adapters
+void llama_lora_adapters_apply(struct llama_context * ctx, std::vector<llama_lora_adapter_container> & lora_adapters);
+
 // Batch utils
 
 void llama_batch_clear(struct llama_batch & batch);
index 150f7e8d58f2068ff89b4267d911755e54dec3cb..d228ae66eeeecdb109cce6f4dfec024008f143d7 100644 (file)
@@ -135,7 +135,7 @@ struct lora_merge_ctx {
 
     lora_merge_ctx(
             std::string & base_fname,
-            std::vector<std::tuple<std::string, float>> & lora_files,
+            std::vector<llama_lora_adapter_info> & lora_files,
             std::string & outfile,
             int n_threads) : base_model(base_fname, 0), n_threads(n_threads), fout(outfile, std::ios::binary) {
         fout.exceptions(std::ofstream::failbit); // fail fast on write errors
@@ -144,9 +144,9 @@ struct lora_merge_ctx {
             throw std::runtime_error("split model is not yet supported");
         }
 
-        for (auto lora_inp : lora_files) {
-            auto fname = std::get<0>(lora_inp);
-            auto scale = std::get<1>(lora_inp);
+        for (auto lora_inp : lora_files) {
+            auto fname = lora_inp.path;
+            auto scale = lora_inp.scale;
             std::unique_ptr<file_input> adapter(new file_input(fname, scale));
             check_metadata_lora(adapter.get());
             adapters.push_back(std::move(adapter));
@@ -407,7 +407,7 @@ int main(int argc, char ** argv) {
 
     g_verbose = (params.verbosity == 1);
     try {
-        lora_merge_ctx ctx(params.model, params.lora_adapter, params.lora_outfile, params.n_threads);
+        lora_merge_ctx ctx(params.model, params.lora_adapters, params.lora_outfile, params.n_threads);
         ctx.run_merge();
     } catch (const std::exception & err) {
         fprintf(stderr, "%s\n", err.what());
index de83ee7d05e7b6c0b7aeda93f9ad9ba72902740b..e17595fe87f2549eede6114ee0a9623b25fdc627 100644 (file)
@@ -207,41 +207,6 @@ model:
   -hff,  --hf-file FILE           Hugging Face model file (default: unused)
   -hft,  --hf-token TOKEN         Hugging Face access token (default: value from HF_TOKEN environment variable)
 
-retrieval:
-
-         --context-file FNAME     file to load context from (repeat to specify multiple files)
-         --chunk-size N           minimum length of embedded text chunks (default: 64)
-         --chunk-separator STRING
-                                  separator between chunks (default: '
-                                  ')
-
-passkey:
-
-         --junk N                 number of times to repeat the junk text (default: 250)
-         --pos N                  position of the passkey in the junk text (default: -1)
-
-imatrix:
-
-  -o,    --output FNAME           output file (default: 'imatrix.dat')
-         --output-frequency N     output the imatrix every N iterations (default: 10)
-         --save-frequency N       save an imatrix copy every N iterations (default: 0)
-         --process-output         collect data for the output tensor (default: false)
-         --no-ppl                 do not compute perplexity (default: true)
-         --chunk N                start processing the input from chunk N (default: 0)
-
-bench:
-
-  -pps                            is the prompt shared across parallel sequences (default: false)
-  -npp n0,n1,...                  number of prompt tokens
-  -ntg n0,n1,...                  number of text generation tokens
-  -npl n0,n1,...                  number of parallel prompts
-
-embedding:
-
-         --embd-normalize         normalisation for embendings (default: 2) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
-         --embd-output-format     empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
-         --embd-separator         separator of embendings (default \n) for example "<#sep#>"
-
 server:
 
          --host HOST              ip address to listen (default: 127.0.0.1)
@@ -267,7 +232,8 @@ server:
                                   https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
   -sps,  --slot-prompt-similarity SIMILARITY
                                   how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)
-
+         --lora-init-without-apply
+                                  load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled)
 
 logging:
 
@@ -279,15 +245,6 @@ logging:
          --log-file FNAME         Specify a log filename (without extension)
          --log-new                Create a separate new log file on start. Each log file will have unique name: "<name>.<ID>.log"
          --log-append             Don't truncate the old log file.
-
-cvector:
-
-  -o,    --output FNAME           output file (default: 'control_vector.gguf')
-         --positive-file FNAME    positive prompts file, one prompt per line (default: 'examples/cvector-generator/positive.txt')
-         --negative-file FNAME    negative prompts file, one prompt per line (default: 'examples/cvector-generator/negative.txt')
-         --pca-batch N            batch size used for PCA. Larger batch runs faster, but uses more memory (default: 100)
-         --pca-iter N             number of iterations used for PCA (default: 1000)
-         --method {pca,mean}      dimensionality reduction method to be used (default: pca)
 ```
 
 
@@ -411,7 +368,8 @@ node index.js
 
 ## API Endpoints
 
-- **GET** `/health`: Returns the current state of the server:
+### GET `/health`: Returns the current state of the server
+
   - 503 -> `{"status": "loading model"}` if the model is still being loaded.
   - 500 -> `{"status": "error"}` if the model failed to load.
   - 200 -> `{"status": "ok", "slots_idle": 1, "slots_processing": 2 }` if the model is successfully loaded and the server is ready for further requests mentioned below.
@@ -420,7 +378,7 @@ node index.js
 
   If the query parameter `include_slots` is passed, `slots` field will contain internal slots data except if `--slots-endpoint-disable` is set.
 
-- **POST** `/completion`: Given a `prompt`, it returns the predicted completion.
+### POST `/completion`: Given a `prompt`, it returns the predicted completion.
 
     *Options:*
 
@@ -498,7 +456,7 @@ node index.js
 
     `samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["top_k", "tfs_z", "typical_p", "top_p", "min_p", "temperature"]` - these are all the available values.
 
-### Result JSON
+**Response format**
 
 - Note: When using streaming mode (`stream`), only `content` and `stop` will be returned until end of completion.
 
@@ -537,7 +495,7 @@ Notice that each `probs` is an array of length `n_probs`.
 - `tokens_evaluated`: Number of tokens evaluated in total from the prompt
 - `truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens predicted`) exceeded the context size (`n_ctx`)
 
-- **POST** `/tokenize`: Tokenize a given text.
+### POST `/tokenize`: Tokenize a given text
 
     *Options:*
 
@@ -545,13 +503,15 @@ Notice that each `probs` is an array of length `n_probs`.
 
     `add_special`: Boolean indicating if special tokens, i.e. `BOS`, should be inserted.  Default: `false`
 
-- **POST** `/detokenize`: Convert tokens to text.
+### POST `/detokenize`: Convert tokens to text
 
     *Options:*
 
     `tokens`: Set the tokens to detokenize.
 
-- **POST** `/embedding`: Generate embedding of a given text just as [the embedding example](../embedding) does.
+### POST `/embedding`: Generate embedding of a given text
+
+The same as [the embedding example](../embedding) does.
 
     *Options:*
 
@@ -559,7 +519,9 @@ Notice that each `probs` is an array of length `n_probs`.
 
     `image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
 
-- **POST** `/infill`: For code infilling. Takes a prefix and a suffix and returns the predicted completion as stream.
+### POST `/infill`: For code infilling.
+
+Takes a prefix and a suffix and returns the predicted completion as stream.
 
     *Options:*
 
@@ -571,7 +533,7 @@ Notice that each `probs` is an array of length `n_probs`.
 
 - **GET** `/props`: Return current server settings.
 
-### Result JSON
+**Response format**
 
 ```json
 {
@@ -589,7 +551,9 @@ Notice that each `probs` is an array of length `n_probs`.
 - `total_slots` - the total number of slots for process requests (defined by `--parallel` option)
 - `chat_template` - the model's original Jinja2 prompt template
 
-- **POST** `/v1/chat/completions`: OpenAI-compatible Chat Completions API. Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used.
+### POST `/v1/chat/completions`: OpenAI-compatible Chat Completions API
+
+Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used.
 
     *Options:*
 
@@ -641,7 +605,7 @@ Notice that each `probs` is an array of length `n_probs`.
     }'
     ```
 
-- **POST** `/v1/embeddings`: OpenAI-compatible embeddings API.
+### POST `/v1/embeddings`: OpenAI-compatible embeddings API
 
     *Options:*
 
@@ -675,9 +639,9 @@ Notice that each `probs` is an array of length `n_probs`.
     }'
     ```
 
-- **GET** `/slots`: Returns the current slots processing state. Can be disabled with `--slots-endpoint-disable`.
+### GET `/slots`: Returns the current slots processing state. Can be disabled with `--slots-endpoint-disable`.
 
-### Result JSON
+**Response format**
 
 ```json
 [
@@ -738,7 +702,7 @@ Notice that each `probs` is an array of length `n_probs`.
 ]
 ```
 
-- **GET** `/metrics`: [Prometheus](https://prometheus.io/) compatible metrics exporter endpoint if `--metrics` is enabled:
+### GET `/metrics`: Prometheus compatible metrics exporter endpoint if `--metrics` is enabled:
 
 Available metrics:
 - `llamacpp:prompt_tokens_total`: Number of prompt tokens processed.
@@ -750,13 +714,13 @@ Available metrics:
 - `llamacpp:requests_processing`: Number of requests processing.
 - `llamacpp:requests_deferred`: Number of requests deferred.
 
-- **POST** `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file.
+### POST `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file.
 
     *Options:*
 
     `filename`: Name of the file to save the slot's prompt cache. The file will be saved in the directory specified by the `--slot-save-path` server parameter.
 
-### Result JSON
+**Response format**
 
 ```json
 {
@@ -770,13 +734,13 @@ Available metrics:
 }
 ```
 
-- **POST** `/slots/{id_slot}?action=restore`: Restore the prompt cache of the specified slot from a file.
+### POST `/slots/{id_slot}?action=restore`: Restore the prompt cache of the specified slot from a file.
 
     *Options:*
 
     `filename`: Name of the file to restore the slot's prompt cache from. The file should be located in the directory specified by the `--slot-save-path` server parameter.
 
-### Result JSON
+**Response format**
 
 ```json
 {
@@ -790,9 +754,9 @@ Available metrics:
 }
 ```
 
-- **POST** `/slots/{id_slot}?action=erase`: Erase the prompt cache of the specified slot.
+### POST `/slots/{id_slot}?action=erase`: Erase the prompt cache of the specified slot.
 
-### Result JSON
+**Response format**
 
 ```json
 {
@@ -801,6 +765,42 @@ Available metrics:
 }
 ```
 
+### GET `/lora-adapters`: Get list of all LoRA adapters
+
+If an adapter is disabled, the scale will be set to 0.
+
+**Response format**
+
+```json
+[
+    {
+        "id": 0,
+        "path": "my_adapter_1.gguf",
+        "scale": 0.0
+    },
+    {
+        "id": 1,
+        "path": "my_adapter_2.gguf",
+        "scale": 0.0
+    }
+]
+```
+
+### POST `/lora-adapters`: Set list of LoRA adapters
+
+To disable an adapter, either remove it from the list below, or set scale to 0.
+
+**Request format**
+
+To know the `id` of the adapter, use GET `/lora-adapters`
+
+```json
+[
+  {"id": 0, "scale": 0.2},
+  {"id": 1, "scale": 0.8}
+]
+```
+
 ## More examples
 
 ### Change system prompt on runtime
index d178ca0f79b831141668bb144d5e08ac76228422..898c83ea3522bb6d368535fb449fb727f5caeac3 100644 (file)
@@ -78,6 +78,7 @@ enum server_task_type {
     SERVER_TASK_TYPE_SLOT_SAVE,
     SERVER_TASK_TYPE_SLOT_RESTORE,
     SERVER_TASK_TYPE_SLOT_ERASE,
+    SERVER_TASK_TYPE_SET_LORA,
 };
 
 struct server_task {
@@ -622,6 +623,7 @@ struct server_response {
 struct server_context {
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
+    std::vector<llama_lora_adapter_container> lora_adapters;
 
     gpt_params params;
 
@@ -681,6 +683,7 @@ struct server_context {
 
         model = llama_init.model;
         ctx = llama_init.context;
+        lora_adapters = llama_init.lora_adapters;
         params.n_parallel -= 1; // but be sneaky about it
         if (model == nullptr) {
             LOG_ERROR("unable to load model", {{"model", params.model}});
@@ -1850,6 +1853,14 @@ struct server_context {
                     };
                     queue_results.send(result);
                 } break;
+            case SERVER_TASK_TYPE_SET_LORA:
+                {
+                    llama_lora_adapters_apply(ctx, lora_adapters);
+                    server_task_result result;
+                    result.id = task.id;
+                    result.data = json{{ "success", true }};
+                    queue_results.send(result);
+                } break;
         }
     }
 
@@ -3328,6 +3339,55 @@ int main(int argc, char ** argv) {
         return res.set_content(root.dump(), "application/json; charset=utf-8");
     };
 
+    const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+        json result = json::array();
+        for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) {
+            auto & la = ctx_server.lora_adapters[i];
+            result.push_back({
+                {"id", i},
+                {"path", la.path},
+                {"scale", la.scale},
+            });
+        }
+        res.set_content(result.dump(), "application/json");
+        res.status = 200; // HTTP OK
+    };
+
+    const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+
+        const std::vector<json> body = json::parse(req.body);
+        int max_idx = ctx_server.lora_adapters.size();
+
+        // clear existing value
+        for (auto & la : ctx_server.lora_adapters) {
+            la.scale = 0.0f;
+        }
+
+        // set value
+        for (auto entry : body) {
+            int id      = entry.at("id");
+            float scale = entry.at("scale");
+            if (0 <= id && id < max_idx) {
+                ctx_server.lora_adapters[id].scale = scale;
+            } else {
+                throw std::runtime_error("invalid adapter id");
+            }
+        }
+
+        server_task task;
+        task.type = SERVER_TASK_TYPE_SET_LORA;
+        const int id_task = ctx_server.queue_tasks.post(task);
+        ctx_server.queue_results.add_waiting_task_id(id_task);
+
+        server_task_result result = ctx_server.queue_results.recv(id_task);
+        ctx_server.queue_results.remove_waiting_task_id(id_task);
+
+        res.set_content(result.data.dump(), "application/json");
+        res.status = 200; // HTTP OK
+    };
+
     auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
         return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
             res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
@@ -3366,7 +3426,6 @@ int main(int argc, char ** argv) {
 
     // register API routes
     svr->Get ("/health",              handle_health);
-    svr->Get ("/slots",               handle_slots);
     svr->Get ("/metrics",             handle_metrics);
     svr->Get ("/props",               handle_props);
     svr->Get ("/v1/models",           handle_models);
@@ -3381,6 +3440,11 @@ int main(int argc, char ** argv) {
     svr->Post("/v1/embeddings",       handle_embeddings);
     svr->Post("/tokenize",            handle_tokenize);
     svr->Post("/detokenize",          handle_detokenize);
+    // LoRA adapters hotswap
+    svr->Get ("/lora-adapters",       handle_lora_adapters_list);
+    svr->Post("/lora-adapters",       handle_lora_adapters_apply);
+    // Save & load slots
+    svr->Get ("/slots",               handle_slots);
     if (!params.slot_save_path.empty()) {
         // only enable slot endpoints if slot_save_path is set
         svr->Post("/slots/:id_slot",  handle_slots_action);
diff --git a/examples/server/tests/features/lora.feature b/examples/server/tests/features/lora.feature
new file mode 100644 (file)
index 0000000..7b85988
--- /dev/null
@@ -0,0 +1,36 @@
+@llama.cpp
+@lora
+Feature: llama.cpp server
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And   a model url https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/stories15M_MOE-F16.gguf
+    And   a model file stories15M_MOE-F16.gguf
+    And   a model alias stories15M_MOE
+    And   a lora adapter file from https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf
+    And   42 as server seed
+    And   1024 as batch size
+    And   1024 as ubatch size
+    And   2048 KV cache size
+    And   64 max tokens to predict
+    And   0.0 temperature
+    Then  the server is starting
+    Then  the server is healthy
+
+  Scenario: Completion LoRA disabled
+    Given switch off lora adapter 0
+    Given a prompt:
+    """
+    Look in thy glass
+    """
+    And   a completion request with no api error
+    Then  64 tokens are predicted matching little|girl|three|years|old
+
+  Scenario: Completion LoRA enabled
+    Given switch on lora adapter 0
+    Given a prompt:
+    """
+    Look in thy glass
+    """
+    And   a completion request with no api error
+    Then  64 tokens are predicted matching eye|love|glass|sun
index df0814cc99bd136d943a13a6a5b4f5ea81ccd7f6..6705a34fc469650b64f3a4b911af9adc355db644 100644 (file)
@@ -7,6 +7,7 @@ import subprocess
 import sys
 import threading
 import time
+import requests
 from collections.abc import Sequence
 from contextlib import closing
 from re import RegexFlag
@@ -70,6 +71,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.user_api_key = None
     context.response_format = None
     context.temperature = None
+    context.lora_file = None
 
     context.tasks_result = []
     context.concurrent_tasks = []
@@ -82,6 +84,12 @@ def step_download_hf_model(context, hf_file: str, hf_repo: str):
     context.model_hf_file = hf_file
     context.model_file = os.path.basename(hf_file)
 
+@step('a lora adapter file from {lora_file_url}')
+def step_download_lora_file(context, lora_file_url: str):
+    file_name = lora_file_url.split('/').pop()
+    context.lora_file = f'../../../{file_name}'
+    with open(context.lora_file, 'wb') as f:
+        f.write(requests.get(lora_file_url).content)
 
 @step('a model file {model_file}')
 def step_model_file(context, model_file: str):
@@ -849,6 +857,17 @@ async def step_erase_slot(context, slot_id):
             context.response = response
 
 
+@step('switch {on_or_off} lora adapter {lora_id:d}')
+@async_run_until_complete
+async def toggle_lora_adapter(context, on_or_off: str, lora_id: int):
+    async with aiohttp.ClientSession() as session:
+        async with session.post(f'{context.base_url}/lora-adapters',
+                                json=[{'id': lora_id, 'scale': 1 if on_or_off == 'on' else 0}],
+                                headers={"Content-Type": "application/json"}) as response:
+            context.response = response
+            print([{'id': lora_id, 'scale': 1 if on_or_off == 'on' else 0}])
+
+
 @step('the server responds with status code {status_code:d}')
 def step_server_responds_with_status_code(context, status_code):
     assert context.response.status == status_code
@@ -1326,6 +1345,8 @@ def start_server_background(context):
         server_args.extend(['--grp-attn-w', context.n_ga_w])
     if context.debug:
         server_args.append('--verbose')
+    if context.lora_file:
+        server_args.extend(['--lora', context.lora_file])
     if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
         server_args.extend(['--log-format', "text"])
 
index 2c741ea1081c489eead201ecdb569742bf6bab65..f2d7e5c5731be189acf3b43af329d06b81aba171 100644 (file)
@@ -4,3 +4,4 @@ huggingface_hub~=0.20.3
 numpy~=1.26.4
 openai~=1.30.3
 prometheus-client~=0.20.0
+requests~=2.32.3
index ea4d0270562c3d5e2a371201b2ecaf131c88b61c..db318542a279b606e95ff51c82fd77615fce30b8 100644 (file)
@@ -174,7 +174,7 @@ class Metadata:
             org_component, model_full_name_component = None, model_id
 
         # Check if we erroneously matched against './' or '../' etc...
-        if org_component is not None and org_component[0] == '.':
+        if org_component is not None and len(org_component) > 0 and org_component[0] == '.':
             org_component = None
 
         name_parts: list[str] = model_full_name_component.split('-')