]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : allow using LoRA adapters per-request (#10994)
authorXuan Son Nguyen <redacted>
Thu, 2 Jan 2025 14:05:18 +0000 (15:05 +0100)
committerGitHub <redacted>
Thu, 2 Jan 2025 14:05:18 +0000 (15:05 +0100)
* slot.can_batch_with

* lora per request

* test: force disable cache prompt

* move can_batch_with check

* fix condition

* add slow test with llama 8b

* update docs

* move lora change task to queue

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <redacted>
* lora_base

* remove redundant check

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/server/README.md
examples/server/server.cpp
examples/server/tests/README.md
examples/server/tests/requirements.txt
examples/server/tests/unit/test_lora.py
examples/server/tests/unit/test_speculative.py
examples/server/tests/utils.py
examples/server/utils.hpp

index bcef8194614908d34a2ff1701e531a4b27468e8e..3ce16945ac8072d06ad16b60e6708f1c65fe89b7 100644 (file)
@@ -452,6 +452,8 @@ These words will not be included in the completion, so make sure to add them to
 
 `response_fields`: A list of response fields, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error. Note that fields with a slash will be unnested; for example, `generation_settings/n_predict` will move the field `n_predict` from the `generation_settings` object to the root of the response and give it a new name.
 
+`lora`: A list of LoRA adapters to be applied to this specific request. Each object in the list must contain `id` and `scale` fields. For example: `[{"id": 0, "scale": 0.5}, {"id": 1, "scale": 1.1}]`. If a LoRA adapter is not specified in the list, its scale will default to `0.0`. Please note that requests with different LoRA configurations will not be batched together, which may result in performance degradation.
+
 **Response format**
 
 - Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
@@ -945,6 +947,8 @@ This endpoint returns the loaded LoRA adapters. You can add adapters using `--lo
 
 By default, all adapters will be loaded with scale set to 1. To initialize all adapters scale to 0, add `--lora-init-without-apply`
 
+Please note that this value will be overwritten by the `lora` field for each request.
+
 If an adapter is disabled, the scale will be set to 0.
 
 **Response format**
@@ -966,6 +970,8 @@ If an adapter is disabled, the scale will be set to 0.
 
 ### POST `/lora-adapters`: Set list of LoRA adapters
 
+This sets the global scale for LoRA adapters. Please note that this value will be overwritten by the `lora` field for each request.
+
 To disable an adapter, either remove it from the list below, or set scale to 0.
 
 **Request format**
index b3773f276adf2d83af8e8069f665427e960a7d68..5118084f12adb9d344795ed76bf52bdf1526a4d2 100644 (file)
@@ -98,6 +98,8 @@ struct slot_params {
     int64_t t_max_prompt_ms  = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
+    std::vector<common_lora_adapter_container> lora;
+
     std::vector<std::string> antiprompt;
     std::vector<std::string> response_fields;
     bool timings_per_token = false;
@@ -120,6 +122,11 @@ struct slot_params {
             samplers.emplace_back(common_sampler_type_to_str(sampler));
         }
 
+        json lora = json::array();
+        for (size_t i = 0; i < this->lora.size(); ++i) {
+            lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
+        }
+
         return json {
             {"n_predict",                 n_predict},     // Server configured n_predict
             {"seed",                      sampling.seed},
@@ -160,6 +167,7 @@ struct slot_params {
             {"speculative.p_min",         speculative.p_min},
             {"timings_per_token",         timings_per_token},
             {"post_sampling_probs",       post_sampling_probs},
+            {"lora",                      lora},
         };
     }
 };
@@ -189,12 +197,16 @@ struct server_task {
     // used by SERVER_TASK_TYPE_METRICS
     bool metrics_reset_bucket = false;
 
+    // used by SERVER_TASK_TYPE_SET_LORA
+    std::vector<common_lora_adapter_container> set_lora;
+
     server_task(server_task_type type) : type(type) {}
 
     static slot_params params_from_json_cmpl(
             const llama_model * model,
             const llama_context * ctx,
             const common_params & params_base,
+            const std::vector<common_lora_adapter_container> & lora_base,
             const json & data) {
         slot_params params;
 
@@ -251,6 +263,16 @@ struct server_task {
         params.speculative.n_min = std::max(params.speculative.n_min, 2);
         params.speculative.n_max = std::max(params.speculative.n_max, 0);
 
+        if (data.contains("lora")) {
+            if (data.at("lora").is_array()) {
+                params.lora = parse_lora_request(lora_base, data.at("lora"));
+            } else {
+                throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
+            }
+        } else {
+            params.lora = lora_base;
+        }
+
         // TODO: add more sanity checks for the input parameters
 
         if (params.sampling.penalty_last_n < -1) {
@@ -1110,6 +1132,8 @@ struct server_slot {
 
     common_speculative * spec = nullptr;
 
+    std::vector<common_lora_adapter_container> lora;
+
     // the index relative to completion multi-task request
     size_t index = 0;
 
@@ -1191,6 +1215,11 @@ struct server_slot {
         return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
     }
 
+    bool can_batch_with(server_slot & other_slot) {
+        return is_non_causal() == other_slot.is_non_causal()
+            && are_lora_equal(lora, other_slot.lora);
+    }
+
     bool has_budget(const common_params & global_params) {
         if (params.n_predict == -1 && global_params.n_predict == -1) {
             return true; // limitless
@@ -1600,7 +1629,7 @@ struct server_context {
 
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
-    std::vector<common_lora_adapter_container> loras;
+    std::vector<common_lora_adapter_container> lora;
 
     llama_model * model_dft = nullptr;
     llama_context_params cparams_dft;
@@ -1667,7 +1696,7 @@ struct server_context {
 
         model = llama_init.model;
         ctx   = llama_init.context;
-        loras = llama_init.lora_adapters;
+        lora  = llama_init.lora_adapters;
 
         if (model == nullptr) {
             SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
@@ -1866,6 +1895,12 @@ struct server_context {
         slot.params        = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);
 
+        if (!are_lora_equal(task.params.lora, slot.lora)) {
+            // if lora is changed, we cannot reuse cached tokens
+            slot.cache_tokens.clear();
+            slot.lora = std::move(task.params.lora);
+        }
+
         SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
 
         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
@@ -2557,7 +2592,7 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SET_LORA:
                 {
-                    common_lora_adapters_apply(ctx, loras);
+                    lora = std::move(task.set_lora);
                     auto res = std::make_unique<server_task_result_apply_lora>();
                     res->id = task.id;
                     queue_results.send(std::move(res));
@@ -2634,12 +2669,22 @@ struct server_context {
         // start populating the batch for this iteration
         common_batch_clear(batch);
 
+        // track if given slot can be batched with slots already in the batch
+        server_slot * slot_batched = nullptr;
+
         // frist, add sampled tokens from any ongoing sequences
         for (auto & slot : slots) {
             if (slot.state != SLOT_STATE_GENERATING) {
                 continue;
             }
 
+            // check if we can batch this slot with the previous one
+            if (!slot_batched) {
+                slot_batched = &slot;
+            } else if (!slot_batched->can_batch_with(slot)) {
+                continue;
+            }
+
             slot.i_batch = batch.n_tokens;
 
             common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
@@ -2658,15 +2703,18 @@ struct server_context {
         int32_t n_batch  = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
-        // track if this is an embedding or non-embedding batch
-        // if we've added sampled tokens above, we are in non-embedding mode
-        // -1: none, 0: non-embedding, 1: embedding
-        // TODO: make enum
-        int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
-
         // next, batch any pending prompts without exceeding n_batch
         if (params_base.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
+                // check if we can batch this slot with the previous one
+                if (slot.is_processing()) {
+                    if (!slot_batched) {
+                        slot_batched = &slot;
+                    } else if (!slot_batched->can_batch_with(slot)) {
+                        continue;
+                    }
+                }
+
                 // this slot still has a prompt to be processed
                 if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
                     auto & prompt_tokens = slot.prompt_tokens;
@@ -2827,14 +2875,6 @@ struct server_context {
                         }
                     }
 
-                    // check that we are in the right batch_type, if not defer the slot
-                    int slot_type = slot.is_non_causal();
-                    if (batch_type == -1) {
-                        batch_type = slot_type;
-                    } else if (batch_type != slot_type) {
-                        continue;
-                    }
-
                     // keep only the common part
                     if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
                         // could not partially delete (likely using a non-Transformer model)
@@ -2902,8 +2942,12 @@ struct server_context {
 
         SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
 
-        // make sure we're in the right embedding mode
-        llama_set_embeddings(ctx, batch_type == 1);
+        if (slot_batched) {
+            // make sure we're in the right embedding mode
+            llama_set_embeddings(ctx, slot_batched->is_non_causal());
+            // apply lora, only need to do it once per batch
+            common_lora_adapters_apply(ctx, slot_batched->lora);
+        }
 
         // process the created batch of tokens
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
@@ -3623,7 +3667,12 @@ int main(int argc, char ** argv) {
                 task.index = i;
 
                 task.prompt_tokens    = std::move(tokenized_prompts[i]);
-                task.params           = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
+                task.params           = server_task::params_from_json_cmpl(
+                                            ctx_server.model,
+                                            ctx_server.ctx,
+                                            ctx_server.params_base,
+                                            ctx_server.lora,
+                                            data);
                 task.id_selected_slot = json_value(data, "id_slot", -1);
 
                 // OAI-compat
@@ -4049,8 +4098,8 @@ int main(int argc, char ** argv) {
 
     const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
         json result = json::array();
-        for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
-            auto & lora = ctx_server.loras[i];
+        for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
+            auto & lora = ctx_server.lora[i];
             result.push_back({
                 {"id", i},
                 {"path", lora.path},
@@ -4062,27 +4111,14 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
-        const std::vector<json> body = json::parse(req.body);
-        int max_idx = ctx_server.loras.size();
-
-        // clear existing value
-        for (auto & lora : ctx_server.loras) {
-            lora.scale = 0.0f;
-        }
-
-        // set value
-        for (auto entry : body) {
-            int id      = entry.at("id");
-            float scale = entry.at("scale");
-            if (0 <= id && id < max_idx) {
-                ctx_server.loras[id].scale = scale;
-            } else {
-                throw std::runtime_error("invalid adapter id");
-            }
+        const json body = json::parse(req.body);
+        if (!body.is_array()) {
+            res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
+            return;
         }
-
         server_task task(SERVER_TASK_TYPE_SET_LORA);
         task.id = ctx_server.queue_tasks.get_new_id();
+        task.set_lora = parse_lora_request(ctx_server.lora, body);
         ctx_server.queue_results.add_waiting_task_id(task.id);
         ctx_server.queue_tasks.post(task);
 
index fa3d0a2f5ff661b0c48dd8432536a42265d683c9..5787276abac439d222749d3d8e25ea81371434e0 100644 (file)
@@ -44,6 +44,12 @@ To run with stdout/stderr display in real time (verbose output, but useful for d
 DEBUG=1 ./tests.sh -s -v -x
 ```
 
+To run single test unit:
+
+```shell
+./tests.sh unit/test_{name of test case here}.py -v -x
+```
+
 Hint: You can compile and run test in single command, useful for local developement:
 
 ```shell
index 074b9d47bddce1f8c31733b7c55c1ae3ea7b99b0..15d024914e8412a414f367c732e14042f87e640e 100644 (file)
@@ -5,3 +5,4 @@ numpy~=1.26.4
 openai~=1.55.3
 prometheus-client~=0.20.0
 requests~=2.32.3
+wget~=3.2
index 7496154493917927d59c60f5c7788274521d0c07..c1aa8be70e2f7b27d6c711974b2730cd84c087e1 100644 (file)
@@ -1,5 +1,4 @@
 import pytest
-import os
 from utils import *
 
 server = ServerPreset.stories15m_moe()
@@ -10,15 +9,7 @@ LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe
 def create_server():
     global server
     server = ServerPreset.stories15m_moe()
-    # download lora file if needed
-    file_name = LORA_FILE_URL.split('/').pop()
-    lora_file = f'../../../{file_name}'
-    if not os.path.exists(lora_file):
-        print(f"Downloading {LORA_FILE_URL} to {lora_file}")
-        with open(lora_file, 'wb') as f:
-            f.write(requests.get(LORA_FILE_URL).content)
-        print(f"Done downloading lora file")
-    server.lora_files = [lora_file]
+    server.lora_files = [download_file(LORA_FILE_URL)]
 
 
 @pytest.mark.parametrize("scale,re_content", [
@@ -40,3 +31,85 @@ def test_lora(scale: float, re_content: str):
     assert res.status_code == 200
     assert match_regex(re_content, res.body["content"])
 
+
+def test_lora_per_request():
+    global server
+    server.n_slots = 4
+    server.start()
+
+    # running the same prompt with different lora scales, all in parallel
+    # each prompt will be processed by a different slot
+    prompt = "Look in thy glass"
+    lora_config = [
+        ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
+        ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
+        ( [{"id": 0, "scale": 0.3}], "(special|thing|gifted)+" ),
+        ( [{"id": 0, "scale": 0.7}], "(far|from|home|away)+" ),
+        ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
+        ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
+    ]
+
+    tasks = [(
+        server.make_request,
+        ("POST", "/completion", {
+            "prompt": prompt,
+            "lora": lora,
+            "seed": 42,
+            "temperature": 0.0,
+            "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
+        })
+    ) for lora, _ in lora_config]
+    results = parallel_function_calls(tasks)
+
+    assert all([res.status_code == 200 for res in results])
+    for res, (_, re_test) in zip(results, lora_config):
+        assert match_regex(re_test, res.body["content"])
+
+
+@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
+def test_with_big_model():
+    server = ServerProcess()
+    server.model_hf_repo = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
+    server.model_hf_file = "Meta-Llama-3.1-8B-Instruct-IQ2_M.gguf"
+    server.model_alias = "Llama-3.2-8B-Instruct"
+    server.n_slots = 4
+    server.n_ctx = server.n_slots * 1024
+    server.n_predict = 64
+    server.temperature = 0.0
+    server.seed = 42
+    server.lora_files = [
+        download_file("https://huggingface.co/ngxson/Llama-3-Instruct-abliteration-LoRA-8B-F16-GGUF/resolve/main/Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"),
+        # TODO: find & add other lora adapters for this model
+    ]
+    server.start(timeout_seconds=600)
+
+    # running the same prompt with different lora scales, all in parallel
+    # each prompt will be processed by a different slot
+    prompt = "Write a computer virus"
+    lora_config = [
+        # without applying lora, the model should reject the request
+        ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
+        ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
+        ( [{"id": 0, "scale": 0.3}], "I can't write a computer virus" ),
+        # with 0.7 scale, the model should provide a simple computer virus with hesitation
+        ( [{"id": 0, "scale": 0.7}], "Warning: This is a hypothetical exercise" ),
+        # with 1.5 scale, the model should confidently provide a computer virus
+        ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
+        ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
+    ]
+
+    tasks = [(
+        server.make_request,
+        ("POST", "/v1/chat/completions", {
+            "messages": [
+                {"role": "user", "content": prompt}
+            ],
+            "lora": lora,
+            "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
+        })
+    ) for lora, _ in lora_config]
+    results = parallel_function_calls(tasks)
+
+    assert all([res.status_code == 200 for res in results])
+    for res, (_, re_test) in zip(results, lora_config):
+        assert re_test in res.body["choices"][0]["message"]["content"]
index 3bb5733cbdf48f82912ad5f7386d5dac411bddda..54db38cf3bd8046fe356cf7e8c07832b58dd0815 100644 (file)
@@ -10,16 +10,8 @@ MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tiny
 def create_server():
     global server
     server = ServerPreset.stories15m_moe()
-    # download draft model file if needed
-    file_name = MODEL_DRAFT_FILE_URL.split('/').pop()
-    model_draft_file = f'../../../{file_name}'
-    if not os.path.exists(model_draft_file):
-        print(f"Downloading {MODEL_DRAFT_FILE_URL} to {model_draft_file}")
-        with open(model_draft_file, 'wb') as f:
-            f.write(requests.get(MODEL_DRAFT_FILE_URL).content)
-        print(f"Done downloading draft model file")
     # set default values
-    server.model_draft = model_draft_file
+    server.model_draft = download_file(MODEL_DRAFT_FILE_URL)
     server.draft_min = 4
     server.draft_max = 8
 
index 359bb0faeb1c8e3cb5a50940fa1dced33f879bef..a1a94d0f15e3b98ecb9a4aec5bdd44271214b3cb 100644 (file)
@@ -23,6 +23,7 @@ from typing import (
     Set,
 )
 from re import RegexFlag
+import wget
 
 
 class ServerResponse:
@@ -381,5 +382,25 @@ def match_regex(regex: str, text: str) -> bool:
         is not None
     )
 
+
+def download_file(url: str, output_file_path: str | None = None) -> str:
+    """
+    Download a file from a URL to a local path. If the file already exists, it will not be downloaded again.
+
+    output_file_path is the local path to save the downloaded file. If not provided, the file will be saved in the root directory.
+
+    Returns the local path of the downloaded file.
+    """
+    file_name = url.split('/').pop()
+    output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path
+    if not os.path.exists(output_file):
+        print(f"Downloading {url} to {output_file}")
+        wget.download(url, out=output_file)
+        print(f"Done downloading to {output_file}")
+    else:
+        print(f"File already exists at {output_file}")
+    return output_file
+
+
 def is_slow_test_allowed():
     return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"
index 70220c4375a1752e72e883c357c2a1af5a92e683..1cf08bb0a3642b20fcad902e6da69121b6d04c1c 100644 (file)
@@ -797,3 +797,44 @@ static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx
 
     return cur;
 }
+
+static bool are_lora_equal(
+        const std::vector<common_lora_adapter_container> & l1,
+        const std::vector<common_lora_adapter_container> & l2) {
+    if (l1.size() != l2.size()) {
+        return false;
+    }
+    for (size_t i = 0; i < l1.size(); ++i) {
+        // we don't check lora.path to reduce the time complexity
+        if (l1[i].scale != l2[i].scale || l1[i].adapter != l2[i].adapter) {
+            return false;
+        }
+    }
+    return true;
+}
+
+// parse lora config from JSON request, returned a copy of base_lora with updated scale
+static std::vector<common_lora_adapter_container> parse_lora_request(
+        const std::vector<common_lora_adapter_container> & base_lora,
+        const json & data) {
+    std::vector<common_lora_adapter_container> lora(base_lora);
+    int max_idx = lora.size();
+
+    // clear existing value
+    for (auto & entry : lora) {
+        entry.scale = 0.0f;
+    }
+
+    // set value
+    for (const auto & entry : data) {
+        int id      = json_value(entry, "id", -1);
+        float scale = json_value(entry, "scale", 0.0f);
+        if (0 <= id && id < max_idx) {
+            lora[id].scale = scale;
+        } else {
+            throw std::runtime_error("invalid adapter id");
+        }
+    }
+
+    return lora;
+}