from math import prod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
-from transformers import AutoConfig
+from transformers import AutoConfig, AutoTokenizer
import torch
# reuse model definitions from convert_hf_to_gguf.py
from convert_hf_to_gguf import LazyTorchTensor, ModelBase
+from gguf.constants import GGUFValueType
+
logger = logging.getLogger("lora-to-gguf")
self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
def set_gguf_parameters(self):
+ logger.debug("GGUF KV: %s = %d", gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
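+ # aLoRA (activated LoRA) adapters only take effect once an invocation token
+ # sequence appears in the prompt; record that sequence in the GGUF metadata
+ # so the runtime can detect it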
+ alora_invocation_tokens = lparams.get("alora_invocation_tokens")
+ invocation_string = lparams.get("invocation_string")
+ if invocation_string and not alora_invocation_tokens:
+ logger.debug("Tokenizing invocation_string -> alora_invocation_tokens")
+ base_model_path_or_id = hparams.get("_name_or_path")
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(base_model_path_or_id)
+ except ValueError:
+ logger.error("Unable to load tokenizer from %s", base_model_path_or_id)
+ raise
+ # NOTE: There's an off-by-one with the older aLoRAs where
+ # the invocation string includes the "<|start_of_turn|>"
+ # token, but the adapters themselves were trained to
+ # activate _after_ that first token, so we drop it here.
+ alora_invocation_tokens = tokenizer(invocation_string)["input_ids"][1:]
+ if alora_invocation_tokens:
+ logger.debug("GGUF KV: %s = %s", gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS, alora_invocation_tokens)
+ self.gguf_writer.add_key_value(
+ gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS,
+ alora_invocation_tokens,
+ GGUFValueType.ARRAY,
+ GGUFValueType.UINT32,
+ )
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
# Never add extra tensors (e.g. rope_freqs) for LoRA adapters
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
- int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters
+ int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
common_speculative * spec = nullptr;
std::vector<common_adapter_lora_info> lora;
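+ // index into the prompt where the alora invocation sequence starts (-1 when no alora is active)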
+ int32_t alora_invocation_start = -1;
// the index relative to completion multi-task request
size_t index = 0;
// clear speculative decoding stats
n_draft_total = 0;
n_draft_accepted = 0;
+
+ // clear alora start
+ alora_invocation_start = -1;
}
bool need_embd() const {
slot.prompt_tokens = std::move(task.prompt_tokens);
if (!are_lora_equal(slot.params.lora, slot.lora)) {
- // if lora is changed, we cannot reuse cached tokens
- slot.cache_tokens.clear();
+ // if lora has changed, check to see if the cache should be cleared
+ if (lora_should_clear_cache(slot.lora, slot.params.lora)) {
+ SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), slot.params.lora.size());
+ slot.cache_tokens.clear();
+ } else {
+ SLT_INF(slot, "keeping cache for alora. %zu target loras\n", slot.params.lora.size());
+ }
slot.lora = slot.params.lora;
}
+ // if using alora, make sure only a single one is requested and active
+ size_t alora_invocation_start = slot.prompt_tokens.size(); // past-the-end sentinel until the invocation sequence is found
+ if (lora_all_alora(slot.lora)) {
+
+ const auto & enabled_ids = lora_get_enabled_ids(slot.lora);
+ // TODO: This will error out if a user requests two aloras, but only
+ // provides the activation string for one. We could instead search
+ // for all requested alora activation strings and then either keep
+ // only the last one or reject if multiple are found.
+ if (enabled_ids.size() != 1) {
+ send_error(task, "Cannot run multiple aLoRAs in a single request", ERROR_TYPE_INVALID_REQUEST);
+ return false;
+ }
+ const auto & lora = slot.lora[enabled_ids[0]].ptr;
+
+ // get the pointer and count for the invocation tokens
+ const uint64_t n_invocation_tokens = llama_adapter_get_alora_n_invocation_tokens(lora);
+ const llama_token * invocation_tokens = llama_adapter_get_alora_invocation_tokens (lora);
+
+ // scan backwards through the prompt tokens to find the last
+ // occurrence of the invocation sequence
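+ // e.g. for invocation tokens [A, B, C] and a prompt ending in [..., A, B, C, x, y]:
+ // the scan walks y and x (resetting match_idx), then matches C (match_idx 2 -> 1),
+ // B (1 -> 0), and finally A, so alora_invocation_start is set to A's index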
+ int match_idx = static_cast<int>(n_invocation_tokens) - 1;
+ for (int i = slot.prompt_tokens.size() - 1; i >= 0; --i) {
+ // the token in this position matches the next token to find in
+ // the invocation sequence
+ if (slot.prompt_tokens[i] == invocation_tokens[match_idx]) {
+ // if it's a full match, we've found the start
+ if (match_idx == 0) {
+ alora_invocation_start = i;
+ break;
+ }
+ // otherwise, check the next token in the sequence
+ --match_idx;
+ } else {
+ // no match at this position, so restart the match from the end of the sequence
+ match_idx = static_cast<int>(n_invocation_tokens) - 1;
+ }
+ }
+
+ // if the activation string is not found, disable the alora
+ if (alora_invocation_start == slot.prompt_tokens.size()) {
+ SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]);
+ slot.lora[enabled_ids[0]].scale = 0.0f;
+ } else {
+ SLT_DBG(slot, "alora %zu activated starting at %zu\n", enabled_ids[0], alora_invocation_start);
+ slot.alora_invocation_start = alora_invocation_start;
+ }
+ }
+
if (!slot.prompt_tokens.validate(ctx)) {
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
return false;
int32_t n_ubatch = llama_n_ubatch(ctx);
// next, batch any pending prompts without exceeding n_batch
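+ // when an alora has uncached tokens before its invocation sequence, its scale is
+ // stashed in alora_scale and zeroed so those tokens are decoded without the adapter,
+ // then restored once the adapters have been applied for the batch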
+ float alora_scale = -1.0f;
+ size_t alora_disabled_id = 0;
if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
// check if we can batch this slot with the previous one
// reuse any previously computed tokens that are common with the new prompt
slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);
+ // if an alora is invoked, don't reuse the cache past the invocation start
+ if (slot.alora_invocation_start >= 0) {
+ SLT_DBG(slot, "only caching to alora invocation start (n_past=%d, alora_invocation_start=%d)\n", slot.n_past, slot.alora_invocation_start);
+ slot.n_past = std::min(slot.n_past, slot.alora_invocation_start - 1);
+ }
+
// reuse chunks from the cached prompt by shifting their KV cache in the new position
if (params_base.n_cache_reuse > 0) {
size_t head_c = slot.n_past; // cache
slot.n_prompt_tokens_processed += n_pos;
}
+ // If using an alora, there may be uncached tokens that come
+ // before the invocation sequence. When this happens, the
+ // tokens before the invocation sequence need to be
+ // processed without the adapter in a separate batch, then
+ // the adapter needs to be enabled for the remaining tokens.
+ if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.n_past) {
+ SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
+ const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
+ GGML_ASSERT(enabled_loras.size() == 1);
+ alora_scale = slot.lora[enabled_loras[0]].scale;
+ slot.lora[enabled_loras[0]].scale = 0.0f;
+ alora_disabled_id = enabled_loras[0];
+ }
+
// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
// get next token to process
break; // end of text chunk
}
+ // if this is an alora request with pre-invocation
+ // tokens that are not cached, we need to stop filling
+ // this batch at those pre-invocation tokens.
+ if (alora_scale > 0 && slot.n_past == slot.alora_invocation_start - 1) {
+ SLT_DBG(slot, "stop prompt batch filling at (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
+ break;
+ }
+
// embedding requires all tokens in the batch to be output
const bool need_embd = server_task_type_need_embd(slot.task_type);
// apply lora, only need to do it once per batch
common_set_adapter_lora(ctx, slot_batched->lora);
+ // if the lora is temporarily disabled for an alora, re-enable it
+ // for next time
+ if (alora_scale > 0.0f) {
+ SRV_DBG("re-enabling alora with scale %f\n", alora_scale);
+ slot_batched->lora[alora_disabled_id].scale = alora_scale;
+ }
+
llama_set_embeddings(ctx, slot_batched->need_embd());
}
const auto & loras = ctx_server.params_base.lora_adapters;
for (size_t i = 0; i < loras.size(); ++i) {
auto & lora = loras[i];
- result.push_back({
+ json entry = {
{"id", i},
{"path", lora.path},
{"scale", lora.scale},
{"task_name", lora.task_name},
{"prompt_prefix", lora.prompt_prefix},
- });
+ };
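+ // aloras additionally report their invocation string and token ids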
+ std::string alora_invocation_string = "";
+ const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr);
+ std::vector<llama_token> alora_invocation_tokens;
+ if (n_alora_tokens) {
+ const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr);
+ for (uint64_t tok_idx = 0; tok_idx < n_alora_tokens; ++tok_idx) {
+ alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[tok_idx]);
+ alora_invocation_tokens.push_back(alora_tokens[tok_idx]);
+ }
+ entry["alora_invocation_string"] = alora_invocation_string;
+ entry["alora_invocation_tokens"] = alora_invocation_tokens;
+ }
+ result.push_back(std::move(entry));
}
res_ok(res, result);
res.status = 200; // HTTP OK