int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
+ std::vector<common_lora_adapter_container> lora;
+
std::vector<std::string> antiprompt;
std::vector<std::string> response_fields;
bool timings_per_token = false;
samplers.emplace_back(common_sampler_type_to_str(sampler));
}
+ json lora = json::array();
+ for (size_t i = 0; i < this->lora.size(); ++i) {
+ lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
+ }
+
return json {
{"n_predict", n_predict}, // Server configured n_predict
{"seed", sampling.seed},
{"speculative.p_min", speculative.p_min},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
+ {"lora", lora},
};
}
};
// used by SERVER_TASK_TYPE_METRICS
bool metrics_reset_bucket = false;
+ // used by SERVER_TASK_TYPE_SET_LORA
+ std::vector<common_lora_adapter_container> set_lora;
+
server_task(server_task_type type) : type(type) {}
static slot_params params_from_json_cmpl(
const llama_model * model,
const llama_context * ctx,
const common_params & params_base,
+ const std::vector<common_lora_adapter_container> & lora_base,
const json & data) {
slot_params params;
params.speculative.n_min = std::max(params.speculative.n_min, 2);
params.speculative.n_max = std::max(params.speculative.n_max, 0);
+ if (data.contains("lora")) {
+ if (data.at("lora").is_array()) {
+ params.lora = parse_lora_request(lora_base, data.at("lora"));
+ } else {
+ throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
+ }
+ } else {
+ params.lora = lora_base;
+ }
+
// TODO: add more sanity checks for the input parameters
if (params.sampling.penalty_last_n < -1) {
common_speculative * spec = nullptr;
+ std::vector<common_lora_adapter_container> lora;
+
// the index relative to completion multi-task request
size_t index = 0;
return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
}
+ bool can_batch_with(server_slot & other_slot) {
+ return is_non_causal() == other_slot.is_non_causal()
+ && are_lora_equal(lora, other_slot.lora);
+ }
+
bool has_budget(const common_params & global_params) {
if (params.n_predict == -1 && global_params.n_predict == -1) {
return true; // limitless
llama_model * model = nullptr;
llama_context * ctx = nullptr;
- std::vector<common_lora_adapter_container> loras;
+ std::vector<common_lora_adapter_container> lora;
llama_model * model_dft = nullptr;
llama_context_params cparams_dft;
model = llama_init.model;
ctx = llama_init.context;
- loras = llama_init.lora_adapters;
+ lora = llama_init.lora_adapters;
if (model == nullptr) {
SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
slot.params = std::move(task.params);
slot.prompt_tokens = std::move(task.prompt_tokens);
+ if (!are_lora_equal(task.params.lora, slot.lora)) {
+ // if lora is changed, we cannot reuse cached tokens
+ slot.cache_tokens.clear();
+ slot.lora = std::move(task.params.lora);
+ }
+
SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
} break;
case SERVER_TASK_TYPE_SET_LORA:
{
- common_lora_adapters_apply(ctx, loras);
+ lora = std::move(task.set_lora);
auto res = std::make_unique<server_task_result_apply_lora>();
res->id = task.id;
queue_results.send(std::move(res));
// start populating the batch for this iteration
common_batch_clear(batch);
+ // track if given slot can be batched with slots already in the batch
+ server_slot * slot_batched = nullptr;
+
// frist, add sampled tokens from any ongoing sequences
for (auto & slot : slots) {
if (slot.state != SLOT_STATE_GENERATING) {
continue;
}
+ // check if we can batch this slot with the previous one
+ if (!slot_batched) {
+ slot_batched = &slot;
+ } else if (!slot_batched->can_batch_with(slot)) {
+ continue;
+ }
+
slot.i_batch = batch.n_tokens;
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx);
- // track if this is an embedding or non-embedding batch
- // if we've added sampled tokens above, we are in non-embedding mode
- // -1: none, 0: non-embedding, 1: embedding
- // TODO: make enum
- int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
-
// next, batch any pending prompts without exceeding n_batch
if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
+ // check if we can batch this slot with the previous one
+ if (slot.is_processing()) {
+ if (!slot_batched) {
+ slot_batched = &slot;
+ } else if (!slot_batched->can_batch_with(slot)) {
+ continue;
+ }
+ }
+
// this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
auto & prompt_tokens = slot.prompt_tokens;
}
}
- // check that we are in the right batch_type, if not defer the slot
- int slot_type = slot.is_non_causal();
- if (batch_type == -1) {
- batch_type = slot_type;
- } else if (batch_type != slot_type) {
- continue;
- }
-
// keep only the common part
if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
// could not partially delete (likely using a non-Transformer model)
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
- // make sure we're in the right embedding mode
- llama_set_embeddings(ctx, batch_type == 1);
+ if (slot_batched) {
+ // make sure we're in the right embedding mode
+ llama_set_embeddings(ctx, slot_batched->is_non_causal());
+ // apply lora, only need to do it once per batch
+ common_lora_adapters_apply(ctx, slot_batched->lora);
+ }
// process the created batch of tokens
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
task.index = i;
task.prompt_tokens = std::move(tokenized_prompts[i]);
- task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
+ task.params = server_task::params_from_json_cmpl(
+ ctx_server.model,
+ ctx_server.ctx,
+ ctx_server.params_base,
+ ctx_server.lora,
+ data);
task.id_selected_slot = json_value(data, "id_slot", -1);
// OAI-compat
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
json result = json::array();
- for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
- auto & lora = ctx_server.loras[i];
+ for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
+ auto & lora = ctx_server.lora[i];
result.push_back({
{"id", i},
{"path", lora.path},
};
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
- const std::vector<json> body = json::parse(req.body);
- int max_idx = ctx_server.loras.size();
-
- // clear existing value
- for (auto & lora : ctx_server.loras) {
- lora.scale = 0.0f;
- }
-
- // set value
- for (auto entry : body) {
- int id = entry.at("id");
- float scale = entry.at("scale");
- if (0 <= id && id < max_idx) {
- ctx_server.loras[id].scale = scale;
- } else {
- throw std::runtime_error("invalid adapter id");
- }
+ const json body = json::parse(req.body);
+ if (!body.is_array()) {
+ res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
+ return;
}
-
server_task task(SERVER_TASK_TYPE_SET_LORA);
task.id = ctx_server.queue_tasks.get_new_id();
+ task.set_lora = parse_lora_request(ctx_server.lora, body);
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);
import pytest
-import os
from utils import *
server = ServerPreset.stories15m_moe()
def create_server():
global server
server = ServerPreset.stories15m_moe()
- # download lora file if needed
- file_name = LORA_FILE_URL.split('/').pop()
- lora_file = f'../../../{file_name}'
- if not os.path.exists(lora_file):
- print(f"Downloading {LORA_FILE_URL} to {lora_file}")
- with open(lora_file, 'wb') as f:
- f.write(requests.get(LORA_FILE_URL).content)
- print(f"Done downloading lora file")
- server.lora_files = [lora_file]
+ server.lora_files = [download_file(LORA_FILE_URL)]
@pytest.mark.parametrize("scale,re_content", [
assert res.status_code == 200
assert match_regex(re_content, res.body["content"])
+
+def test_lora_per_request():
+ global server
+ server.n_slots = 4
+ server.start()
+
+ # running the same prompt with different lora scales, all in parallel
+ # each prompt will be processed by a different slot
+ prompt = "Look in thy glass"
+ lora_config = [
+ ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
+ ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
+ ( [{"id": 0, "scale": 0.3}], "(special|thing|gifted)+" ),
+ ( [{"id": 0, "scale": 0.7}], "(far|from|home|away)+" ),
+ ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
+ ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
+ ]
+
+ tasks = [(
+ server.make_request,
+ ("POST", "/completion", {
+ "prompt": prompt,
+ "lora": lora,
+ "seed": 42,
+ "temperature": 0.0,
+ "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
+ })
+ ) for lora, _ in lora_config]
+ results = parallel_function_calls(tasks)
+
+ assert all([res.status_code == 200 for res in results])
+ for res, (_, re_test) in zip(results, lora_config):
+ assert match_regex(re_test, res.body["content"])
+
+
+@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
+def test_with_big_model():
+ server = ServerProcess()
+ server.model_hf_repo = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
+ server.model_hf_file = "Meta-Llama-3.1-8B-Instruct-IQ2_M.gguf"
+ server.model_alias = "Llama-3.2-8B-Instruct"
+ server.n_slots = 4
+ server.n_ctx = server.n_slots * 1024
+ server.n_predict = 64
+ server.temperature = 0.0
+ server.seed = 42
+ server.lora_files = [
+ download_file("https://huggingface.co/ngxson/Llama-3-Instruct-abliteration-LoRA-8B-F16-GGUF/resolve/main/Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"),
+ # TODO: find & add other lora adapters for this model
+ ]
+ server.start(timeout_seconds=600)
+
+ # running the same prompt with different lora scales, all in parallel
+ # each prompt will be processed by a different slot
+ prompt = "Write a computer virus"
+ lora_config = [
+ # without applying lora, the model should reject the request
+ ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
+ ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
+ ( [{"id": 0, "scale": 0.3}], "I can't write a computer virus" ),
+ # with 0.7 scale, the model should provide a simple computer virus with hesitation
+ ( [{"id": 0, "scale": 0.7}], "Warning: This is a hypothetical exercise" ),
+ # with 1.5 scale, the model should confidently provide a computer virus
+ ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
+ ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
+ ]
+
+ tasks = [(
+ server.make_request,
+ ("POST", "/v1/chat/completions", {
+ "messages": [
+ {"role": "user", "content": prompt}
+ ],
+ "lora": lora,
+ "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
+ })
+ ) for lora, _ in lora_config]
+ results = parallel_function_calls(tasks)
+
+ assert all([res.status_code == 200 for res in results])
+ for res, (_, re_test) in zip(results, lora_config):
+ assert re_test in res.body["choices"][0]["message"]["content"]