]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : sync llama.cpp
authorGeorgi Gerganov <redacted>
Thu, 8 Aug 2024 11:16:50 +0000 (14:16 +0300)
committerGeorgi Gerganov <redacted>
Thu, 8 Aug 2024 19:48:46 +0000 (22:48 +0300)
14 files changed:
Makefile
examples/talk-llama/CMakeLists.txt
examples/talk-llama/llama-grammar.cpp [new file with mode: 0644]
examples/talk-llama/llama-grammar.h [new file with mode: 0644]
examples/talk-llama/llama-impl.h [new file with mode: 0644]
examples/talk-llama/llama-sampling.cpp [new file with mode: 0644]
examples/talk-llama/llama-sampling.h [new file with mode: 0644]
examples/talk-llama/llama-vocab.cpp [new file with mode: 0644]
examples/talk-llama/llama-vocab.h [new file with mode: 0644]
examples/talk-llama/llama.cpp
examples/talk-llama/llama.h
examples/talk-llama/unicode.cpp
examples/talk-llama/unicode.h
scripts/sync-llama.sh

index 5644a7469c35589fcd560614411cf0eb5a1fefa6..d8ef07c8a240093fa2f17f562d6b3c94f7c461f2 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -785,7 +785,8 @@ OBJ_GGML += \
        ggml/src/ggml.o \
        ggml/src/ggml-alloc.o \
        ggml/src/ggml-backend.o \
-       ggml/src/ggml-quants.o
+       ggml/src/ggml-quants.o \
+       ggml/src/ggml-aarch64.o
 
 OBJ_WHISPER += \
        src/whisper.o
@@ -916,6 +917,13 @@ ggml/src/ggml-quants.o: \
        ggml/src/ggml-common.h
        $(CC) $(CFLAGS)    -c $< -o $@
 
+ggml/src/ggml-aarch64.o: \
+       ggml/src/ggml-aarch64.c \
+       ggml/include/ggml.h \
+       ggml/src/ggml-aarch64.h \
+       ggml/src/ggml-common.h
+       $(CC) $(CFLAGS)    -c $< -o $@
+
 ggml/src/ggml-blas.o: \
        ggml/src/ggml-blas.cpp \
        ggml/include/ggml-blas.h
@@ -1076,7 +1084,7 @@ talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp \
        $(CXX) $(CXXFLAGS) $(CFLAGS_SDL) -c $< -o $(call GET_OBJ_FILE, $<)
        $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LDFLAGS_SDL)
 
-talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp examples/talk-llama/unicode.cpp examples/talk-llama/unicode-data.cpp \
+talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp examples/talk-llama/llama-vocab.cpp examples/talk-llama/llama-grammar.cpp examples/talk-llama/llama-sampling.cpp examples/talk-llama/unicode.cpp examples/talk-llama/unicode-data.cpp \
        $(OBJ_GGML) $(OBJ_WHISPER) $(OBJ_COMMON) $(OBJ_SDL)
        $(CXX) $(CXXFLAGS) $(CFLAGS_SDL) -c $< -o $(call GET_OBJ_FILE, $<)
        $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LDFLAGS_SDL)
index f95ec372754b115df923b28d826fa260fc5bdfd6..56b4d0d75fe0d29cf6938ec8f6ebfc9c27080627 100644 (file)
@@ -1,7 +1,13 @@
 if (WHISPER_SDL2)
     # talk-llama
     set(TARGET talk-llama)
-    add_executable(${TARGET} talk-llama.cpp llama.cpp unicode.cpp unicode-data.cpp)
+    add_executable(${TARGET} talk-llama.cpp
+        llama.cpp
+        llama-vocab.cpp
+        llama-grammar.cpp
+        llama-sampling.cpp
+        unicode.cpp
+        unicode-data.cpp)
     target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
 
     if (WHISPER_CLBLAST)
diff --git a/examples/talk-llama/llama-grammar.cpp b/examples/talk-llama/llama-grammar.cpp
new file mode 100644 (file)
index 0000000..b123d73
--- /dev/null
@@ -0,0 +1,539 @@
+#include "llama-grammar.h"
+
+#include "llama-vocab.h"
+#include "llama-sampling.h"
+
+#include <algorithm>
+
+// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
+// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
+std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
+        const std::string & src,
+        llama_partial_utf8 partial_start) {
+    static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
+    const char          * pos      = src.c_str();
+    std::vector<uint32_t> code_points;
+
+    // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
+    code_points.reserve(src.size() + 1);
+    uint32_t value    = partial_start.value;
+    int      n_remain = partial_start.n_remain;
+
+    // continue previous decode, if applicable
+    while (*pos != 0 && n_remain > 0) {
+        uint8_t next_byte = static_cast<uint8_t>(*pos);
+        if ((next_byte >> 6) != 2) {
+            // invalid sequence, abort
+            code_points.push_back(0);
+            return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
+        }
+        value = (value << 6) + (next_byte & 0x3F);
+        ++pos;
+        --n_remain;
+    }
+
+    if (partial_start.n_remain > 0 && n_remain == 0) {
+        code_points.push_back(value);
+    }
+
+    // decode any subsequent utf-8 sequences, which may end in an incomplete one
+    while (*pos != 0) {
+        uint8_t first_byte = static_cast<uint8_t>(*pos);
+        uint8_t highbits   = first_byte >> 4;
+                n_remain   = lookup[highbits] - 1;
+
+        if (n_remain < 0) {
+            // invalid sequence, abort
+            code_points.clear();
+            code_points.push_back(0);
+            return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
+        }
+
+        uint8_t mask  = (1 << (7 - n_remain)) - 1;
+                value = first_byte & mask;
+
+        ++pos;
+        while (*pos != 0 && n_remain > 0) {
+            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+            ++pos;
+            --n_remain;
+        }
+        if (n_remain == 0) {
+            code_points.push_back(value);
+        }
+    }
+    code_points.push_back(0);
+
+    return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
+}
+
+const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
+    return grammar->rules;
+}
+
+llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
+    return grammar->stacks;
+}
+
+// returns true iff pos points to the end of one of the definitions of a rule
+static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
+    switch (pos->type) {
+        case LLAMA_GRETYPE_END: return true;  // NOLINT
+        case LLAMA_GRETYPE_ALT: return true;  // NOLINT
+        default:                return false;
+    }
+}
+
+// returns true iff chr satisfies the char range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
+        const llama_grammar_element * pos,
+        const uint32_t                chr) {
+
+    bool found            = false;
+    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
+
+    GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT
+
+    do {
+        if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
+            // inclusive range, e.g. [a-z]
+            found = found || (pos->value <= chr && chr <= pos[1].value);
+            pos += 2;
+        } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
+            // Any character matches "."
+            found = true;
+            pos += 1;
+        } else {
+            // exact char match, e.g. [a] or "a"
+            found = found || pos->value == chr;
+            pos += 1;
+        }
+    } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
+
+    return std::make_pair(found == is_positive_char, pos);
+}
+
+// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
+// range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static bool llama_grammar_match_partial_char(
+        const llama_grammar_element * pos,
+        const llama_partial_utf8      partial_utf8) {
+    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
+    GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
+
+    uint32_t partial_value = partial_utf8.value;
+    int      n_remain      = partial_utf8.n_remain;
+
+    // invalid sequence or 7-bit char split across 2 bytes (overlong)
+    if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
+        return false;
+    }
+
+    // range of possible code points this partial UTF-8 sequence could complete to
+    uint32_t low  = partial_value << (n_remain * 6);
+    uint32_t high = low | ((1 << (n_remain * 6)) - 1);
+
+    if (low == 0) {
+        if (n_remain == 2) {
+            low = 1 << 11;
+        } else if (n_remain == 3) {
+            low = 1 << 16;
+        }
+    }
+
+    do {
+        if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
+            // inclusive range, e.g. [a-z]
+            if (pos->value <= high && low <= pos[1].value) {
+                return is_positive_char;
+            }
+            pos += 2;
+        } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
+            // Any character matches "."
+            return true;
+        } else {
+            // exact char match, e.g. [a] or "a"
+            if (low <= pos->value && pos->value <= high) {
+                return is_positive_char;
+            }
+            pos += 1;
+        }
+    } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
+
+    return !is_positive_char;
+}
+
+// transforms a grammar pushdown stack into N possible stacks, all ending
+// at a character range (terminal element)
+static void llama_grammar_advance_stack(
+        const llama_grammar_rules  & rules,
+        const llama_grammar_stack  & stack,
+              llama_grammar_stacks & new_stacks) {
+    if (stack.empty()) {
+        if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
+            new_stacks.emplace_back(stack);
+        }
+        return;
+    }
+
+    const llama_grammar_element * pos = stack.back();
+
+    switch (pos->type) {
+        case LLAMA_GRETYPE_RULE_REF: {
+            const size_t                  rule_id = static_cast<size_t>(pos->value);
+            const llama_grammar_element * subpos  = rules[rule_id].data();
+            do {
+                // init new stack without the top (pos)
+                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
+                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
+                    // if this rule ref is followed by another element, add that to stack
+                    new_stack.push_back(pos + 1);
+                }
+                if (!llama_grammar_is_end_of_sequence(subpos)) {
+                    // if alternate is nonempty, add to stack
+                    new_stack.push_back(subpos);
+                }
+                llama_grammar_advance_stack(rules, new_stack, new_stacks);
+                while (!llama_grammar_is_end_of_sequence(subpos)) {
+                    // scan to end of alternate def
+                    subpos++;
+                }
+                if (subpos->type == LLAMA_GRETYPE_ALT) {
+                    // there's another alternate def of this rule to process
+                    subpos++;
+                } else {
+                    break;
+                }
+            } while (true);
+            break;
+        }
+        case LLAMA_GRETYPE_CHAR:
+        case LLAMA_GRETYPE_CHAR_NOT:
+        case LLAMA_GRETYPE_CHAR_ANY:
+            if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
+                // only add the stack if it's not a duplicate of one we already have
+                new_stacks.emplace_back(stack);
+            }
+            break;
+        default:
+            // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
+            // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
+            // those
+            GGML_ABORT("fatal error");
+    }
+}
+
+// takes a set of possible pushdown stacks on a grammar, which are required to
+// be positioned at a character range (see `llama_grammar_advance_stack`), and
+// produces the N possible stacks if the given char is accepted at those
+// positions
+void llama_grammar_accept(
+        const llama_grammar_rules  & rules,
+        const llama_grammar_stacks & stacks,
+        const uint32_t               chr,
+              llama_grammar_stacks & new_stacks) {
+    new_stacks.clear();
+
+    for (const auto & stack : stacks) {
+        if (stack.empty()) {
+            continue;
+        }
+
+        auto match = llama_grammar_match_char(stack.back(), chr);
+        if (match.first) {
+            const llama_grammar_element * pos = match.second;
+
+            // update top of stack to next element, if any
+            llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
+            if (!llama_grammar_is_end_of_sequence(pos)) {
+                new_stack.push_back(pos);
+            }
+            llama_grammar_advance_stack(rules, new_stack, new_stacks);
+        }
+    }
+}
+
+static llama_grammar_candidates llama_grammar_reject_candidates(
+        const llama_grammar_rules  & rules,
+        const llama_grammar_stacks & stacks,
+        const llama_grammar_candidates & candidates) {
+    GGML_ASSERT(!stacks.empty()); // REVIEW
+
+    if (candidates.empty()) {
+        return {};
+    }
+
+    auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
+
+    for (size_t i = 1, size = stacks.size(); i < size; ++i) {
+        rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
+    }
+    return rejects;
+}
+
+llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
+        const llama_grammar_rules      & rules,
+        const llama_grammar_stack      & stack,
+        const llama_grammar_candidates & candidates) {
+
+    llama_grammar_candidates rejects;
+    rejects.reserve(candidates.size());
+
+    if (stack.empty()) {
+        for (const auto & tok : candidates) {
+            if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
+                rejects.push_back(tok);
+            }
+        }
+        return rejects;
+    }
+
+    const llama_grammar_element * stack_pos = stack.back();
+
+    llama_grammar_candidates next_candidates;
+    next_candidates.reserve(candidates.size());
+
+    for (const auto & tok : candidates) {
+        if (*tok.code_points == 0) {
+            // reached end of full codepoints in token, reject iff it ended in a partial sequence
+            // that cannot satisfy this position in grammar
+            if (tok.partial_utf8.n_remain != 0 &&
+                    !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
+                rejects.push_back(tok);
+            }
+        } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
+            next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
+        } else {
+            rejects.push_back(tok);
+        }
+    }
+
+    const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;
+
+    // update top of stack to next element, if any
+    llama_grammar_stack stack_after(stack.begin(), stack.end() - 1);
+    if (!llama_grammar_is_end_of_sequence(stack_pos_after)) {
+        stack_after.push_back(stack_pos_after);
+    }
+    llama_grammar_stacks next_stacks;
+    llama_grammar_advance_stack(rules, stack_after, next_stacks);
+
+    auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
+    for (const auto & tok : next_rejects) {
+        rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
+    }
+
+    return rejects;
+}
+
+static bool llama_grammar_detect_left_recursion(
+        const llama_grammar_rules & rules,
+        size_t rule_index,
+        std::vector<bool> * rules_visited,
+        std::vector<bool> * rules_in_progress,
+        std::vector<bool> * rules_may_be_empty) {
+    if ((*rules_in_progress)[rule_index]) {
+        return true;
+    }
+
+    (*rules_in_progress)[rule_index] = true;
+
+    const llama_grammar_rule & rule = rules[rule_index];
+
+    // First check if the rule might produce the empty string. This could be done combined with the second
+    // step but it's more readable as two steps.
+    bool at_rule_start = true;
+    for (size_t i = 0; i < rule.size(); i++) {
+        if (llama_grammar_is_end_of_sequence(&rule[i])) {
+            if (at_rule_start) {
+                (*rules_may_be_empty)[rule_index] = true;
+                break;
+            }
+            at_rule_start = true;
+        } else {
+            at_rule_start = false;
+        }
+    }
+
+    // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
+    // be empty)
+    bool recurse_into_nonterminal = true;
+    for (size_t i = 0; i < rule.size(); i++) {
+        if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
+            if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
+                return true;
+            }
+            if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
+                recurse_into_nonterminal = false;
+            }
+        } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
+            recurse_into_nonterminal = true;
+        } else {
+            recurse_into_nonterminal = false;
+        }
+    }
+
+    (*rules_in_progress)[rule_index] = false;
+    (*rules_visited)[rule_index] = true;
+    return false;
+}
+
+//
+// grammar - external
+//
+
+struct llama_grammar * llama_grammar_init_impl(
+            const llama_grammar_element ** rules,
+                                 size_t    n_rules,
+                                 size_t    start_rule_index) {
+    const llama_grammar_element * pos;
+
+    // copy rule definitions into vectors
+    llama_grammar_rules vec_rules(n_rules);
+    for (size_t i = 0; i < n_rules; i++) {
+        for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
+            vec_rules[i].push_back(*pos);
+        }
+        vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
+    }
+
+    // Check for left recursion
+    std::vector<bool> rules_visited(n_rules);
+    std::vector<bool> rules_in_progress(n_rules);
+    std::vector<bool> rules_may_be_empty(n_rules);
+    for (size_t i = 0; i < n_rules; i++) {
+        if (rules_visited[i]) {
+            continue;
+        }
+        if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
+            LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
+            return nullptr;
+        }
+    }
+
+    // loop over alternates of start rule to build initial stacks
+    llama_grammar_stacks stacks;
+    pos = vec_rules[start_rule_index].data();
+    do {
+        llama_grammar_stack stack;
+        if (!llama_grammar_is_end_of_sequence(pos)) {
+            // if alternate is nonempty, add to stack
+            stack.push_back(pos);
+        }
+        llama_grammar_advance_stack(vec_rules, stack, stacks);
+        while (!llama_grammar_is_end_of_sequence(pos)) {
+            // scan to end of alternate def
+            pos++;
+        }
+        if (pos->type == LLAMA_GRETYPE_ALT) {
+            // there's another alternate def of this rule to process
+            pos++;
+        } else {
+            break;
+        }
+    } while (true);
+
+    // Important: vec_rules has to be moved here, not copied, because stacks contains
+    // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
+    // then the pointers would be invalidated when the local vec_rules goes out of scope.
+    return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
+}
+
+void llama_grammar_free_impl(struct llama_grammar * grammar) {
+    delete grammar;
+}
+
+struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) {
+    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
+
+    // redirect elements in stacks to point to new rules
+    for (size_t is = 0; is < result->stacks.size(); is++) {
+        for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
+            for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
+                for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
+                    if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
+                         result->stacks[is][ie]  =  &result->rules[ir0][ir1];
+                    }
+                }
+            }
+        }
+    }
+
+    return result;
+}
+
+void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
+    GGML_ASSERT(grammar);
+    GGML_ASSERT(vocab);
+
+    int64_t t_start_sample_us = ggml_time_us();
+
+    bool allow_eog = false;
+    for (const auto & stack : grammar->stacks) {
+        if (stack.empty()) {
+            allow_eog = true;
+            break;
+        }
+    }
+
+    std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
+    candidates_decoded.reserve(candidates->size);
+
+    llama_grammar_candidates candidates_grammar;
+    candidates_grammar.reserve(candidates->size);
+
+    for (size_t i = 0; i < candidates->size; ++i) {
+        const llama_token id      = candidates->data[i].id;
+        const std::string & piece = vocab->cache_token_to_piece.at(id);
+
+        if (llama_token_is_eog_impl(*vocab, id)) {
+            if (!allow_eog) {
+                candidates->data[i].logit = -INFINITY;
+            }
+        } else if (piece.empty() || piece[0] == 0) {
+            candidates->data[i].logit = -INFINITY;
+        } else {
+            candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
+            candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
+        }
+    }
+
+    const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
+    for (const auto & reject : rejects) {
+        candidates->data[reject.index].logit = -INFINITY;
+    }
+
+    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+}
+
+void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    if (llama_token_is_eog_impl(*vocab, token)) {
+        for (const auto & stack : grammar->stacks) {
+            if (stack.empty()) {
+                return;
+            }
+        }
+        GGML_ABORT("fatal error");
+    }
+
+    const std::string & piece = vocab->cache_token_to_piece.at(token);
+
+    // Note terminating 0 in decoded string
+    const auto   decoded     = decode_utf8(piece, grammar->partial_utf8);
+    const auto & code_points = decoded.first;
+
+    llama_grammar_stacks tmp_new_stacks;
+    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+        llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
+        grammar->stacks = tmp_new_stacks;
+    }
+
+    grammar->partial_utf8 = decoded.second;
+    GGML_ASSERT(!grammar->stacks.empty());
+
+    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+}
diff --git a/examples/talk-llama/llama-grammar.h b/examples/talk-llama/llama-grammar.h
new file mode 100644 (file)
index 0000000..695ea06
--- /dev/null
@@ -0,0 +1,39 @@
+#pragma once
+
+#include "llama-impl.h"
+
+struct llama_vocab;
+struct llama_sampling;
+
+struct llama_grammar {
+    const llama_grammar_rules  rules;
+          llama_grammar_stacks stacks;
+
+    // buffer for partially generated UTF-8 sequence from accepted tokens
+    llama_partial_utf8 partial_utf8;
+};
+
+//
+// internal API
+//
+
+struct llama_grammar * llama_grammar_init_impl(
+            const llama_grammar_element ** rules,
+                                 size_t    n_rules,
+                                 size_t    start_rule_index);
+
+void llama_grammar_free_impl(struct llama_grammar * grammar);
+
+struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar);
+
+void llama_grammar_sample_impl(
+        const struct llama_grammar * grammar,
+          const struct llama_vocab * vocab,
+       const struct llama_sampling * smpl,
+            llama_token_data_array * candidates);
+
+void llama_grammar_accept_token_impl(
+              struct llama_grammar * grammar,
+          const struct llama_vocab * vocab,
+       const struct llama_sampling * smpl,
+                       llama_token   token);
diff --git a/examples/talk-llama/llama-impl.h b/examples/talk-llama/llama-impl.h
new file mode 100644 (file)
index 0000000..dcc8c1c
--- /dev/null
@@ -0,0 +1,26 @@
+#pragma once
+
+#define LLAMA_API_INTERNAL
+#include "llama.h"
+
+#ifdef __GNUC__
+#ifdef __MINGW32__
+#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#endif
+#else
+#define LLAMA_ATTRIBUTE_FORMAT(...)
+#endif
+
+//
+// logging
+//
+
+LLAMA_ATTRIBUTE_FORMAT(2, 3)
+void llama_log_internal        (ggml_log_level level, const char * format, ...);
+void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data);
+
+#define LLAMA_LOG_INFO(...)  llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
+#define LLAMA_LOG_WARN(...)  llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
+#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
diff --git a/examples/talk-llama/llama-sampling.cpp b/examples/talk-llama/llama-sampling.cpp
new file mode 100644 (file)
index 0000000..8910f6d
--- /dev/null
@@ -0,0 +1,635 @@
+#include "llama-sampling.h"
+
+#include <algorithm>
+#include <cstring>
+#include <ctime>
+#include <cfloat>
+#include <numeric>
+#include <unordered_map>
+
+static void llama_log_softmax(float * array, size_t size) {
+    float max_l = *std::max_element(array, array + size);
+    float sum = 0.f;
+    for (size_t i = 0; i < size; ++i) {
+        float p = expf(array[i] - max_l);
+        sum += p;
+        array[i] = p;
+    }
+
+    for (size_t i = 0; i < size; ++i) {
+        array[i] = logf(array[i] / sum);
+    }
+}
+
+void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
+    if (seed == LLAMA_DEFAULT_SEED) {
+        seed = time(NULL);
+    }
+
+    smpl->rng.seed(seed);
+}
+
+void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+    GGML_ASSERT(candidates->size > 0);
+
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    // Sort the logits in descending order
+    if (!candidates->sorted) {
+        std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+            return a.logit > b.logit;
+        });
+        candidates->sorted = true;
+    }
+
+    float max_l = candidates->data[0].logit;
+    float cum_sum = 0.0f;
+    for (size_t i = 0; i < candidates->size; ++i) {
+        float p = expf(candidates->data[i].logit - max_l);
+        candidates->data[i].p = p;
+        cum_sum += p;
+    }
+    for (size_t i = 0; i < candidates->size; ++i) {
+        candidates->data[i].p /= cum_sum;
+    }
+
+    if (smpl) {
+        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
+
+void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
+    // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
+    // if (k >= (int32_t)candidates->size) {
+    //     return;
+    // }
+
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    if (k <= 0) {
+        k = candidates->size;
+    }
+
+    k = std::max(k, (int) min_keep);
+    k = std::min(k, (int) candidates->size);
+
+    // Sort scores in descending order
+    if (!candidates->sorted) {
+        auto comp = [](const llama_token_data & a, const llama_token_data & b) {
+            return a.logit > b.logit;
+        };
+        if (k <= 128) {
+            std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
+        } else {
+            constexpr int   nbuckets     = 128;
+            constexpr float bucket_low   = -10.0f;
+            constexpr float bucket_high  =  10.0f;
+            constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
+            constexpr float bucker_inter = -bucket_low * bucket_scale;
+
+            std::vector<int> bucket_idx(candidates->size);
+            std::vector<int> histo(nbuckets, 0);
+
+            for (int i = 0; i < (int)candidates->size; ++i) {
+                const float val = candidates->data[i].logit;
+                int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
+                ib = std::max(0, std::min(nbuckets-1, ib));
+                bucket_idx[i] = ib;
+                ++histo[ib];
+            }
+            int nhave = 0;
+            int ib = nbuckets - 1;
+            for ( ; ib >= 0; --ib) {
+                nhave += histo[ib];
+                if (nhave >= k) break;
+            }
+            std::vector<llama_token_data> tmp_tokens(nhave);
+            auto ptr = tmp_tokens.data();
+            std::vector<llama_token_data*> bucket_ptrs;
+            bucket_ptrs.reserve(nbuckets - ib);
+            for (int j = nbuckets - 1; j >= ib; --j) {
+                bucket_ptrs.push_back(ptr);
+                ptr += histo[j];
+            }
+            for (int i = 0; i < (int)candidates->size; ++i) {
+                int j = bucket_idx[i];
+                if (j >= ib) {
+                    *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
+                }
+            }
+
+            ptr = tmp_tokens.data();
+            int ndone = 0;
+            for (int j = nbuckets-1; j > ib; --j) {
+                std::sort(ptr, ptr + histo[j], comp);
+                ptr += histo[j];
+                ndone += histo[j];
+            }
+            std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
+
+            std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
+
+        }
+        candidates->sorted = true;
+    }
+    candidates->size = k;
+
+    if (smpl) {
+        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
+
+void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
+    if (p >= 1.0f) {
+        return;
+    }
+
+    llama_sample_softmax_impl(smpl, candidates);
+
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    // Compute the cumulative probabilities
+    float cum_sum = 0.0f;
+    size_t last_idx = candidates->size;
+
+    for (size_t i = 0; i < candidates->size; ++i) {
+        cum_sum += candidates->data[i].p;
+
+        // Check if the running sum is at least p or if we have kept at least min_keep tokens
+        // we set the last index to i+1 to indicate that the current iterate should be included in the set
+        if (cum_sum >= p && i + 1 >= min_keep) {
+            last_idx = i + 1;
+            break;
+        }
+    }
+
+    // Resize the output vector to keep only the top-p tokens
+    candidates->size = last_idx;
+
+    if (smpl) {
+        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
+
+void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
+    if (p <= 0.0f || !candidates->size) {
+        return;
+    }
+
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    bool min_p_applied = false;
+
+    // if the candidates aren't sorted, try the unsorted implementation first
+    if (!candidates->sorted) {
+        std::vector<llama_token_data> filtered_tokens;
+
+        float max_logit = -FLT_MAX;
+        for (size_t i = 0; i < candidates->size; ++i) {
+            max_logit = std::max(max_logit, candidates->data[i].logit);
+        }
+        const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
+
+        for (size_t i = 0; i < candidates->size; ++i) {
+            if (candidates->data[i].logit >= min_logit) {
+                filtered_tokens.push_back(candidates->data[i]);
+            }
+        }
+
+        // if we have enough values the operation was a success
+        if (filtered_tokens.size() >= min_keep) {
+            memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
+            candidates->size = filtered_tokens.size();
+            min_p_applied = true;
+        }
+    }
+
+    // if the candidates are sorted or the unsorted implementation failed, use this implementation
+    if (!min_p_applied) {
+        // Sort the logits in descending order
+        if (!candidates->sorted) {
+            std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+                return a.logit > b.logit;
+            });
+            candidates->sorted = true;
+        }
+
+        const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
+        size_t i = 1; // first token always matches
+
+        for (; i < candidates->size; ++i) {
+            if (candidates->data[i].logit < min_logit && i >= min_keep) {
+                break; // prob too small
+            }
+        }
+
+        // Resize the output vector to keep only the matching tokens
+        candidates->size = i;
+    }
+
+    if (smpl) {
+        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
+
+void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
+    if (z >= 1.0f || candidates->size <= 2) {
+        return;
+    }
+
+    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    // Compute the first and second derivatives
+    std::vector<float> first_derivatives(candidates->size - 1);
+    std::vector<float> second_derivatives(candidates->size - 2);
+
+    for (size_t i = 0; i < first_derivatives.size(); ++i) {
+        first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
+    }
+    for (size_t i = 0; i < second_derivatives.size(); ++i) {
+        second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
+    }
+
+    // Calculate absolute value of second derivatives
+    for (size_t i = 0; i < second_derivatives.size(); ++i) {
+        second_derivatives[i] = std::abs(second_derivatives[i]);
+    }
+
+    // Normalize the second derivatives
+    {
+        const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
+
+        if (second_derivatives_sum > 1e-6f) {
+            for (float & value : second_derivatives) {
+                value /= second_derivatives_sum;
+            }
+        } else {
+            for (float & value : second_derivatives) {
+                value = 1.0f / second_derivatives.size();
+            }
+        }
+    }
+
+    float cum_sum = 0.0f;
+    size_t last_idx = candidates->size;
+    for (size_t i = 0; i < second_derivatives.size(); ++i) {
+        cum_sum += second_derivatives[i];
+
+        // Check if the running sum is greater than z or if we have kept at least min_keep tokens
+        if (cum_sum > z && i >= min_keep) {
+            last_idx = i;
+            break;
+        }
+    }
+
+    // Resize the output vector to keep only the tokens above the tail location
+    candidates->size = last_idx;
+
+    if (smpl) {
+        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
+
+void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
+    // Reference implementation:
+    // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
+    if (p >= 1.0f) {
+        return;
+    }
+
+    // Compute the softmax of logits and calculate entropy
+    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    float entropy = 0.0f;
+    for (size_t i = 0; i < candidates->size; ++i) {
+        entropy += -candidates->data[i].p * logf(candidates->data[i].p);
+    }
+
+    // Compute the absolute difference between negative log probability and entropy for each candidate
+    std::vector<float> shifted_scores;
+    for (size_t i = 0; i < candidates->size; ++i) {
+        float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
+        shifted_scores.push_back(shifted_score);
+    }
+
+    // Sort tokens based on the shifted_scores and their corresponding indices
+    std::vector<size_t> indices(candidates->size);
+    std::iota(indices.begin(), indices.end(), 0);
+
+    std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
+        return shifted_scores[a] < shifted_scores[b];
+    });
+
+    // Compute the cumulative probabilities
+    float cum_sum = 0.0f;
+    size_t last_idx = indices.size();
+
+    for (size_t i = 0; i < indices.size(); ++i) {
+        size_t idx = indices[i];
+        cum_sum += candidates->data[idx].p;
+
+        // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
+        if (cum_sum > p && i >= min_keep - 1) {
+            last_idx = i + 1;
+            break;
+        }
+    }
+
+    // Resize the output vector to keep only the locally typical tokens
+    std::vector<llama_token_data> new_candidates;
+    for (size_t i = 0; i < last_idx; ++i) {
+        size_t idx = indices[i];
+        new_candidates.push_back(candidates->data[idx]);
+    }
+
+    // Replace the data in candidates with the new_candidates data
+    std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
+    candidates->size = new_candidates.size();
+    candidates->sorted = false;
+
+    if (smpl) {
+        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
+
+void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    // no need to do anything if there is only one (or zero) candidates
+    if(candidates->size <= 1) {
+        return;
+    }
+
+    // Calculate maximum possible entropy
+    float max_entropy = -logf(1.0f / candidates->size);
+
+    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+
+    // Calculate entropy of the softmax probabilities
+    float entropy = 0.0f;
+    for (size_t i = 0; i < candidates->size; ++i) {
+        float prob = candidates->data[i].p;
+        if (prob > 0.0f) { // Ensure no log(0)
+            entropy -= prob * logf(prob);
+        }
+    }
+
+    // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above)
+    float normalized_entropy = entropy / max_entropy;
+
+    // Map the normalized entropy to the desired temperature range using the power function
+    float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
+
+#ifdef DEBUG
+    LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
+    LLAMA_LOG_INFO("Entropy: %f\n", entropy);
+    LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
+    LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
+    LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
+    LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
+#endif
+
+    // Apply the dynamically calculated temperature scaling
+    for (size_t i = 0; i < candidates->size; ++i) {
+        candidates->data[i].logit /= dyn_temp;
+    }
+
+    // Re-compute softmax probabilities after scaling logits with dynamic temperature
+    double max_l_double = candidates->data[0].logit;
+    double cum_sum_double = 0.0;
+    for (size_t i = 0; i < candidates->size; ++i) {
+        double p = exp(candidates->data[i].logit - max_l_double);
+        candidates->data[i].p = p; // Store the scaled probability
+        cum_sum_double += p;
+    }
+    for (size_t i = 0; i < candidates->size; ++i) {
+        candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
+    }
+
+#ifdef DEBUG
+    // Print the updated top 25 probabilities after temperature scaling
+    LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
+    for (size_t i = 0; i < 25 && i < candidates->size; ++i) {
+        LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
+    }
+#endif
+
+    if (smpl) {
+        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
+
+void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    for (size_t i = 0; i < candidates->size; ++i) {
+        candidates->data[i].logit /= temp;
+    }
+
+    if (smpl) {
+        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
+
+void llama_sample_repetition_penalties_impl(
+        struct llama_sampling * smpl,
+       llama_token_data_array * candidates,
+            const llama_token * last_tokens,
+                       size_t   penalty_last_n,
+                       float   penalty_repeat,
+                       float   penalty_freq,
+                       float   penalty_present) {
+    if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
+        return;
+    }
+
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    // Create a frequency map to count occurrences of each token in last_tokens
+    std::unordered_map<llama_token, int> token_count;
+    for (size_t i = 0; i < penalty_last_n; ++i) {
+        token_count[last_tokens[i]]++;
+    }
+
+    // Apply frequency and presence penalties to the candidates
+    for (size_t i = 0; i < candidates->size; ++i) {
+        const auto token_iter = token_count.find(candidates->data[i].id);
+        if (token_iter == token_count.end()) {
+            continue;
+        }
+
+        const int count = token_iter->second;
+
+        // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
+        // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
+        if (candidates->data[i].logit <= 0) {
+            candidates->data[i].logit *= penalty_repeat;
+        } else {
+            candidates->data[i].logit /= penalty_repeat;
+        }
+
+        candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
+    }
+
+    candidates->sorted = false;
+
+    if (smpl) {
+        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
+
+void llama_sample_apply_guidance_impl(
+        struct llama_sampling * smpl,
+                        float * logits,
+                        float * logits_guidance,
+                        float   scale) {
+    GGML_ASSERT(smpl);
+
+    const auto t_start_sample_us = ggml_time_us();
+    const auto n_vocab = smpl->n_vocab;
+
+    llama_log_softmax(logits, n_vocab);
+    llama_log_softmax(logits_guidance, n_vocab);
+
+    for (int i = 0; i < n_vocab; ++i) {
+              auto & l = logits[i];
+        const auto & g = logits_guidance[i];
+
+        l = scale * (l - g) + g;
+    }
+
+    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+}
+
+llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
+    GGML_ASSERT(smpl);
+
+    const int32_t n_vocab = float(smpl->n_vocab);
+
+    int64_t t_start_sample_us = ggml_time_us();
+
+    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+
+    // Estimate s_hat using the most probable m tokens
+    float s_hat = 0.0;
+    float sum_ti_bi = 0.0;
+    float sum_ti_sq = 0.0;
+    for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
+        float t_i = logf(float(i + 2) / float(i + 1));
+        float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
+        sum_ti_bi += t_i * b_i;
+        sum_ti_sq += t_i * t_i;
+    }
+    s_hat = sum_ti_bi / sum_ti_sq;
+
+    // Compute k from the estimated s_hat and target surprise value
+    float epsilon_hat = s_hat - 1;
+    float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
+
+    // Sample the next word X using top-k sampling
+    llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
+    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    llama_token X = llama_sample_token_impl(smpl, candidates);
+    t_start_sample_us = ggml_time_us();
+
+    // Compute error as the difference between observed surprise and target surprise value
+    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
+        return candidate.id == X;
+    }));
+    float observed_surprise = -log2f(candidates->data[X_idx].p);
+    float e = observed_surprise - tau;
+
+    // Update mu using the learning rate and error
+    *mu = *mu - eta * e;
+
+    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    return X;
+}
+
+llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
+    int64_t t_start_sample_us;
+    t_start_sample_us = ggml_time_us();
+
+    llama_sample_softmax_impl(smpl, candidates);
+
+    // Truncate the words with surprise values greater than mu
+    candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
+        return -log2f(candidate.p) > *mu;
+    }));
+
+    if (candidates->size == 0) {
+        candidates->size = 1;
+    }
+
+    if (smpl) {
+        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+
+    // Normalize the probabilities of the remaining words
+    llama_sample_softmax_impl(smpl, candidates);
+
+    // Sample the next word X from the remaining words
+    llama_token X = llama_sample_token_impl(smpl, candidates);
+    t_start_sample_us = ggml_time_us();
+
+    // Compute error as the difference between observed surprise and target surprise value
+    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
+        return candidate.id == X;
+    }));
+    float observed_surprise = -log2f(candidates->data[X_idx].p);
+    float e = observed_surprise - tau;
+
+    // Update mu using the learning rate and error
+    *mu = *mu - eta * e;
+
+    if (smpl) {
+        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+    return X;
+}
+
+llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    // Find max element
+    auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit < b.logit;
+    });
+
+    llama_token result = max_iter->id;
+    if (smpl) {
+        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+        smpl->n_sample++;
+    }
+    return result;
+}
+
+llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
+    GGML_ASSERT(smpl);
+
+    const int64_t t_start_sample_us = ggml_time_us();
+    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+
+    std::vector<float> probs;
+    probs.reserve(candidates->size);
+    for (size_t i = 0; i < candidates->size; ++i) {
+        probs.push_back(candidates->data[i].p);
+    }
+
+    std::discrete_distribution<> dist(probs.begin(), probs.end());
+    int idx = dist(rng);
+
+    llama_token result = candidates->data[idx].id;
+
+    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    smpl->n_sample++;
+
+    return result;
+}
+
+llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+    return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
+}
diff --git a/examples/talk-llama/llama-sampling.h b/examples/talk-llama/llama-sampling.h
new file mode 100644 (file)
index 0000000..f7f8e3e
--- /dev/null
@@ -0,0 +1,56 @@
+#pragma once
+
+#include "llama-impl.h"
+
+struct llama_sampling {
+    llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
+
+    std::mt19937 rng;
+
+    int32_t n_vocab = 0;
+
+    mutable int64_t t_sample_us = 0;
+    mutable int32_t n_sample = 0;
+
+    void reset_timings() const {
+        t_sample_us = 0;
+        n_sample = 0;
+    }
+};
+
+//
+// internal API
+//
+
+void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed);
+
+void llama_sample_softmax_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates);
+void llama_sample_top_k_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
+void llama_sample_top_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
+void llama_sample_min_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
+void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
+void llama_sample_typical_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
+void llama_sample_entropy_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
+void llama_sample_temp_impl     (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
+
+void llama_sample_repetition_penalties_impl(
+        struct llama_sampling * smpl,
+       llama_token_data_array * candidates,
+            const llama_token * last_tokens,
+                       size_t   penalty_last_n,
+                        float   penalty_repeat,
+                        float   penalty_freq,
+                        float   penalty_present);
+
+void llama_sample_apply_guidance_impl(
+        struct llama_sampling * smpl,
+                        float * logits,
+                        float * logits_guidance,
+                        float   scale);
+
+llama_token llama_sample_token_mirostat_impl   (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
+llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
+llama_token llama_sample_token_greedy_impl     (struct llama_sampling * smpl, llama_token_data_array * candidates);
+llama_token llama_sample_token_with_rng_impl   (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
+llama_token llama_sample_token_impl            (struct llama_sampling * smpl, llama_token_data_array * candidates);
+
diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp
new file mode 100644 (file)
index 0000000..e6d6059
--- /dev/null
@@ -0,0 +1,1729 @@
+#include "llama-vocab.h"
+
+#include "unicode.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cfloat>
+#include <climits>
+#include <cstdarg>
+#include <cstring>
+#include <forward_list>
+#include <queue>
+#include <sstream>
+
+//
+// helpers
+//
+
+static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
+    std::string result;
+    for (size_t pos = 0; ; pos += search.length()) {
+        auto new_pos = s.find(search, pos);
+        if (new_pos == std::string::npos) {
+            result += s.substr(pos, s.size() - pos);
+            break;
+        }
+        result += s.substr(pos, new_pos - pos) + replace;
+        pos = new_pos;
+    }
+    s = std::move(result);
+}
+
+LLAMA_ATTRIBUTE_FORMAT(1, 2)
+static std::string format(const char * fmt, ...) {
+    va_list ap;
+    va_list ap2;
+    va_start(ap, fmt);
+    va_copy(ap2, ap);
+    int size = vsnprintf(NULL, 0, fmt, ap);
+    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    std::vector<char> buf(size + 1);
+    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
+    GGML_ASSERT(size2 == size);
+    va_end(ap2);
+    va_end(ap);
+    return std::string(buf.data(), size);
+}
+
+struct naive_trie {
+    naive_trie() : has_value(false), value(0) {
+    }
+    void insert(const char * key, size_t len, int32_t value = 0) {
+        if (len == 0) {
+            this->has_value = true;
+            this->value = value;
+            return;
+        }
+        char c = key[0];
+        auto res = children.find(c);
+        if (res != children.end()) {
+            res->second.insert(key + 1, len - 1, value);
+        } else {
+            auto res = children.insert(std::make_pair(c, naive_trie()));
+            res.first->second.insert(key + 1, len - 1, value);
+        }
+    }
+    std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) {
+        if (len == 0 || offset == len) {
+            return std::make_pair(key, offset);
+        }
+        char c = key[offset];
+        auto res = children.find(c);
+        if (res != children.end()) {
+            return res->second.get_longest_prefix(key, len, offset + 1);
+        } else {
+            return std::make_pair(key, offset);
+        }
+    }
+    struct naive_trie * traverse(const char c) {
+        auto res = children.find(c);
+        if (res != children.end()) {
+            return &res->second;
+        } else {
+            return NULL;
+        }
+    }
+    std::map<char, struct naive_trie> children;
+    bool has_value;
+    llama_token value;
+};
+
+//
+// impl
+//
+
+int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
+    GGML_ASSERT(token_left.find(' ')   == std::string::npos);
+    GGML_ASSERT(token_left.find('\n')  == std::string::npos);
+    GGML_ASSERT(token_right.find(' ')  == std::string::npos);
+    GGML_ASSERT(token_right.find('\n') == std::string::npos);
+
+    auto it = bpe_ranks.find(std::make_pair(token_left, token_right));
+    if (it == bpe_ranks.end()) {
+        return -1;
+    }
+
+    return it->second;
+}
+
+static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
+    return vocab.type;
+}
+
+static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) {
+    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
+}
+
+static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) {
+    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
+}
+
+static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) {
+    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
+}
+
+static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
+    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
+}
+
+static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token id) {
+    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
+}
+
+static bool llama_is_unused_token(const llama_vocab & vocab, llama_token id) {
+    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED;
+}
+
+static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) {
+    GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
+    GGML_ASSERT(llama_is_byte_token(vocab, id));
+    const auto & token_data = vocab.id_to_token.at(id);
+    switch (llama_vocab_get_type(vocab)) {
+        case LLAMA_VOCAB_TYPE_SPM:
+        case LLAMA_VOCAB_TYPE_UGM: {
+            auto buf = token_data.text.substr(3, 2);
+            return strtol(buf.c_str(), NULL, 16);
+        }
+        case LLAMA_VOCAB_TYPE_BPE: {
+            GGML_ABORT("fatal error");
+            //return unicode_utf8_to_byte(token_data.text); // TODO: why is this here after GGML_ASSERT?
+        }
+        case LLAMA_VOCAB_TYPE_WPM: {
+            GGML_ABORT("fatal error");
+        }
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
+
+static void llama_escape_whitespace(std::string & text) {
+    replace_all(text, " ", "\xe2\x96\x81");
+}
+
+static void llama_unescape_whitespace(std::string & word) {
+    replace_all(word, "\xe2\x96\x81", " ");
+}
+
+struct llm_symbol {
+    using index = int;
+    index prev;
+    index next;
+    const char * text;
+    size_t n;
+};
+
+static_assert(std::is_trivially_copyable<llm_symbol>::value, "llm_symbol is not trivially copyable");
+
+//
+// SPM tokenizer
+// original implementation:
+// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
+//
+
+struct llm_bigram_spm {
+    struct comparator {
+        bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) {
+            return (l.score < r.score) || (l.score == r.score && l.left > r.left);
+        }
+    };
+    using queue_storage = std::vector<llm_bigram_spm>;
+    using queue = std::priority_queue<llm_bigram_spm, queue_storage, comparator>;
+    llm_symbol::index left;
+    llm_symbol::index right;
+    float score;
+    size_t size;
+};
+
+struct llm_tokenizer_spm {
+    llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {}
+
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        // split string into utf8 chars
+        int index = 0;
+        size_t offs = 0;
+        while (offs < text.size()) {
+            llm_symbol sym;
+            size_t len = unicode_len_utf8(text[offs]);
+            sym.text = text.c_str() + offs;
+            sym.n = std::min(len, text.size() - offs);
+            offs += sym.n;
+            sym.prev = index - 1;
+            sym.next = offs == text.size() ? -1 : index + 1;
+            index++;
+            symbols.emplace_back(sym);
+        }
+
+        // seed the work queue with all possible 2-character tokens.
+        for (size_t i = 1; i < symbols.size(); ++i) {
+            try_add_bigram(i - 1, i);
+        }
+
+        // keep substituting the highest frequency pairs for as long as we can.
+        while (!work_queue.empty()) {
+            auto bigram = work_queue.top();
+            work_queue.pop();
+
+            auto & left_sym = symbols[bigram.left];
+            auto & right_sym = symbols[bigram.right];
+
+            // if one of the symbols already got merged, skip it.
+            if (left_sym.n == 0 || right_sym.n == 0 ||
+                left_sym.n + right_sym.n != bigram.size) {
+                continue;
+            }
+
+            // merge the right sym into the left one
+            left_sym.n += right_sym.n;
+            right_sym.n = 0;
+
+            //LLAMA_LOG_INFO("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size);
+
+            // remove the right sym from the chain
+            left_sym.next = right_sym.next;
+            if (right_sym.next >= 0) {
+                symbols[right_sym.next].prev = bigram.left;
+            }
+
+            // find more substitutions
+            try_add_bigram(left_sym.prev, bigram.left);
+            try_add_bigram(bigram.left, left_sym.next);
+        }
+
+        for (int i = 0; i != -1; i = symbols[i].next) {
+            auto & symbol = symbols[i];
+            resegment(symbol, output);
+        }
+    }
+
+private:
+    void resegment(llm_symbol & symbol, std::vector<llama_vocab::id> & output) {
+        auto text = std::string(symbol.text, symbol.n);
+        auto token = vocab.token_to_id.find(text);
+
+        // Do we need to support is_unused?
+        if (token != vocab.token_to_id.end()) {
+            output.push_back((*token).second);
+            return;
+        }
+
+        const auto p = rev_merge.find(text);
+
+        if (p == rev_merge.end()) {
+            // output any symbols that did not form tokens as bytes.
+            output.reserve(output.size() + symbol.n);
+            for (int j = 0; j < (int)symbol.n; ++j) {
+                llama_vocab::id token_id = llama_byte_to_token_impl(vocab, symbol.text[j]);
+                output.push_back(token_id);
+            }
+            return;
+        }
+
+        resegment(symbols[p->second.first],  output);
+        resegment(symbols[p->second.second], output);
+    }
+
+    void try_add_bigram(int left, int right) {
+        if (left == -1 || right == -1) {
+            return;
+        }
+
+        const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
+        auto token = vocab.token_to_id.find(text);
+
+        if (token == vocab.token_to_id.end()) {
+            return;
+        }
+
+        if (static_cast<size_t>((*token).second) >= vocab.id_to_token.size()) {
+            return;
+        }
+
+        const auto & tok_data = vocab.id_to_token[(*token).second];
+
+        llm_bigram_spm bigram;
+        bigram.left  = left;
+        bigram.right = right;
+        bigram.score = tok_data.score;
+        bigram.size  = text.size();
+
+        work_queue.push(bigram);
+
+        // Do we need to support is_unused?
+        rev_merge[text] = std::make_pair(left, right);
+    }
+
+    const llama_vocab & vocab;
+
+    std::vector<llm_symbol> symbols;
+    llm_bigram_spm::queue work_queue;
+
+    std::map<std::string, std::pair<int, int>> rev_merge;
+};
+
+//
+// BPE tokenizer
+// adapted from https://github.com/cmp-nct/ggllm.cpp [MIT License]
+// tried to simplify unicode stuff, so most likely does not work 100% correctly!
+//
+
+// TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused
+
+struct llm_bigram_bpe {
+    struct comparator {
+        bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const {
+            return l.rank > r.rank || (l.rank == r.rank && l.left > r.left);
+        }
+    };
+
+    using queue_storage = std::vector<llm_bigram_bpe>;
+    using queue = std::priority_queue<llm_bigram_bpe, queue_storage, comparator>;
+    llm_symbol::index left;
+    llm_symbol::index right;
+    std::string text;
+    int rank;
+    size_t size;
+};
+
+struct llm_tokenizer_bpe {
+    llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {
+        GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
+        switch (vocab.type_pre) {
+            case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
+                regex_exprs = {
+                    // original regex from tokenizer.json
+                    //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+
+                    // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
+                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_DBRX:
+            case LLAMA_VOCAB_PRE_TYPE_SMAUG:
+                regex_exprs = {
+                    // same as llama3
+                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
+                regex_exprs = {
+                    "[\r\n]",
+                    "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
+                    "\\s?[!-/:-~!-/:-~‘-‟ -。]+",
+                    "\\s+$",
+                    "[一-龥ࠀ-一가-퟿]+",
+                    "\\p{N}+",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
+                regex_exprs = {
+                    "[\r\n]",
+                    "\\s?\\p{L}+",
+                    "\\s?\\p{P}+",
+                    "[一-龥ࠀ-一가-퟿]+",
+                    "\\p{N}",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_FALCON:
+                regex_exprs = {
+                    "[\\p{P}\\$\\+<=>\\^~\\|`]+",
+                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                    "[0-9][0-9][0-9]",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_STARCODER:
+            case LLAMA_VOCAB_PRE_TYPE_REFACT:
+            case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
+            case LLAMA_VOCAB_PRE_TYPE_SMOLLM:
+            case LLAMA_VOCAB_PRE_TYPE_CODESHELL:
+                regex_exprs = {
+                    "\\p{N}",
+                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_GPT2:
+            case LLAMA_VOCAB_PRE_TYPE_MPT:
+            case LLAMA_VOCAB_PRE_TYPE_OLMO:
+            case LLAMA_VOCAB_PRE_TYPE_JAIS:
+                regex_exprs = {
+                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
+            case LLAMA_VOCAB_PRE_TYPE_QWEN2:
+                regex_exprs = {
+                    // original regex from tokenizer.json
+                    // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_PORO:
+                regex_exprs = {
+                    " ?[^(\\s|.,!?…。,、।۔،)]+",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_CHATGLM4:
+                regex_exprs = {
+                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_VIKING:
+                regex_exprs = {
+                    " ?[^(\\s|.,!?…。,、।۔،)]+",
+                    "\\p{N}",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_TEKKEN:
+                // original regex from tokenizer.json
+                // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+                regex_exprs = {
+                    "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
+            default:
+                // default regex for BPE tokenization pre-processing
+                regex_exprs = {
+                    "[\\p{P}\\$\\+<=>\\^~\\|]+",
+                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                    "\\p{N}+",
+                    "[0-9][0-9][0-9]",
+                };
+                break;
+        }
+    }
+
+    void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) const {
+        output.push_back(token_id);
+    }
+
+    bool append_bos(std::vector<llama_vocab::id> & output) const {
+        if (vocab.tokenizer_add_bos) {
+            GGML_ASSERT(vocab.special_bos_id != -1);
+            output.push_back(vocab.special_bos_id);
+            return true;
+        }
+        return false;
+    }
+
+    bool append_eos(std::vector<llama_vocab::id> & output) const {
+        if (vocab.tokenizer_add_eos) {
+            GGML_ASSERT(vocab.special_eos_id != -1);
+            output.push_back(vocab.special_eos_id);
+            return true;
+        }
+        return false;
+    }
+
+    void check_double_bos_eos(const std::vector<llama_vocab::id> & output) const {
+        if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+            LLAMA_LOG_WARN(
+                "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                "Are you sure this is what you want?\n", __FUNCTION__);
+        }
+        if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) {
+            LLAMA_LOG_WARN(
+                "%s: Added a EOS token to the prompt as specified by the model but the prompt "
+                "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. "
+                "Are you sure this is what you want?\n", __FUNCTION__);
+        }
+    }
+
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        int final_prev_index = -1;
+
+        const auto word_collection = unicode_regex_split(text, regex_exprs);
+
+        symbols_final.clear();
+
+        for (auto & word : word_collection) {
+            work_queue = llm_bigram_bpe::queue();
+            symbols.clear();
+
+            int index = 0;
+            size_t offset = 0;
+
+            if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
+                symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
+                offset = word.size();
+            }
+
+            while (offset < word.size()) {
+                llm_symbol sym;
+                size_t char_len = std::min(word.size() - offset, (size_t) unicode_len_utf8(word[offset]));
+                sym.text = word.c_str() + offset;
+                sym.n = char_len;
+                offset += sym.n;
+                sym.prev = index - 1;
+                sym.next = offset == word.size() ? -1 : index + 1;
+                index++;
+                symbols.emplace_back(sym);
+            }
+            for (size_t i = 1; i < symbols.size(); ++i) {
+                add_new_bigram(i - 1, i);
+            }
+
+            // build token(s)
+            while (!work_queue.empty()) {
+                auto bigram = work_queue.top();
+                work_queue.pop();
+
+                auto & left_symbol = symbols[bigram.left];
+                auto & right_symbol = symbols[bigram.right];
+
+                if (left_symbol.n == 0 || right_symbol.n == 0) {
+                    continue;
+                }
+                std::string left_token = std::string(left_symbol.text, left_symbol.n);
+                std::string right_token = std::string(right_symbol.text, right_symbol.n);
+                if (left_token + right_token != bigram.text) {
+                    continue;  // Skip this bigram if it's outdated
+                }
+
+                // merge the right sym into the left one
+                left_symbol.n += right_symbol.n;
+                right_symbol.n = 0;
+
+                // remove the right sym from the chain
+                left_symbol.next = right_symbol.next;
+                if (right_symbol.next >= 0) {
+                    symbols[right_symbol.next].prev = bigram.left;
+                }
+
+                add_new_bigram(left_symbol.prev, bigram.left);  // left side of current symbol
+                add_new_bigram(bigram.left, left_symbol.next);  // right side of current symbol
+            }
+
+            // add the finished tokens to the final list keeping correct order for next and prev
+            for (auto & sym : symbols) {
+                if (sym.n > 0) {
+                    sym.prev = final_prev_index;
+                    sym.next = -1;
+                    if (final_prev_index != -1) {
+                        symbols_final[final_prev_index].next = symbols_final.size();
+                    }
+                    symbols_final.emplace_back(sym);
+                    final_prev_index = symbols_final.size() - 1;
+                }
+            }
+        }
+
+        symbols = symbols_final;
+
+        if (!symbols.empty()) {
+            for (int i = 0; i != -1; i = symbols[i].next) {
+                auto & symbol = symbols[i];
+                if (symbol.n == 0) {
+                    continue;
+                }
+
+                const std::string str = std::string(symbol.text, symbol.n);
+                const auto token = vocab.token_to_id.find(str);
+
+                if (token == vocab.token_to_id.end()) {
+                    for (auto j = str.begin(); j != str.end(); ++j) {
+                        std::string byte_str(1, *j);
+                        auto token_multibyte = vocab.token_to_id.find(byte_str);
+                        if (token_multibyte != vocab.token_to_id.end()) {
+                            output.push_back(token_multibyte->second);
+                        }
+                    }
+                } else {
+                    output.push_back((*token).second);
+                }
+            }
+        }
+    }
+
+private:
+    void add_new_bigram(int left, int right) {
+        if (left == -1 || right == -1) {
+            return;
+        }
+
+        std::string left_token  = std::string(symbols[left].text,  symbols[left].n);
+        std::string right_token = std::string(symbols[right].text, symbols[right].n);
+
+        int rank_found = -1;
+
+        rank_found = vocab.find_bpe_rank(left_token, right_token);
+
+        if (rank_found < 0) {
+            return;
+        }
+
+        llm_bigram_bpe bigram;
+
+        bigram.left  = left;
+        bigram.right = right;
+        bigram.text  = left_token + right_token;
+        bigram.size  = left_token.size() + right_token.size();
+        bigram.rank  = rank_found;
+
+        work_queue.push(bigram);
+    }
+
+    const llama_vocab & vocab;
+
+    std::vector<std::string> regex_exprs;
+
+    std::vector<llm_symbol> symbols;
+    std::vector<llm_symbol> symbols_final;
+
+    llm_bigram_bpe::queue work_queue;
+};
+
+//
+// WPM tokenizer
+//
+
+struct llm_tokenizer_wpm {
+    llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
+
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) const {
+        const auto & token_map = vocab.token_to_id;
+
+        // normalize and split by whitespace
+        std::vector<std::string> words = preprocess(text);
+
+        // bos token prepended already
+
+        // find the longest tokens that form the words
+        for (const std::string & word : words) {
+            // skip empty words
+            if (word.size() == 0) {
+                continue;
+            }
+
+            // prepend phantom space
+            const std::string word1 = "\xe2\x96\x81" + word;
+            const int n = word1.size();
+
+            const size_t current_tokens = output.size();
+
+            // we're at the start of a new word
+            // move through character position in word
+            for (int i = 0; i < n; ++i) {
+                // loop through possible match length
+                bool match = false;
+                for (int j = std::min(n, i + vocab.max_token_len + 1); j > i; j--) {
+                    auto it = token_map.find(word1.substr(i, j - i));
+                    if (it != token_map.end()) {
+                        output.push_back(it->second);
+                        match = true;
+                        i = j - 1;
+                        break;
+                    }
+                }
+
+                if (!match) { // discard all
+                    output.resize(current_tokens);
+                    break;  // and discard next tokens
+                }
+            }
+
+            // we didn't find any matches for this word
+            if (current_tokens == output.size()) {
+                output.push_back(vocab.special_unk_id);
+            }
+        }
+    }
+
+    // TODO: reduce string copies by using cpts_offs array
+    std::vector<std::string> preprocess(const std::string & text) const {
+        const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
+        std::vector<std::string> words(1, "");
+
+        for (const uint32_t cpt : cpts_nfd) {
+            const auto flags = unicode_cpt_flags(cpt);
+
+            if (flags.is_whitespace) {
+                if (words.back().size()) {  // finish previous word if any
+                    words.emplace_back();
+                }
+                continue;
+            }
+
+            assert (!flags.is_separator);
+            if (cpt == 0 || cpt == 0xFFFD || flags.is_control) {
+                continue;
+            }
+
+            const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt));
+            if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
+                if (words.back().size()) {  // finish previous word if any
+                    words.emplace_back();
+                }
+                words.back() = s;       // single char word
+                words.emplace_back();   // start a new word
+            } else {
+                words.back() += s;  // append char to word
+            }
+        }
+
+        if (!words.back().size()) {
+            words.pop_back();
+        }
+
+        return words;
+    }
+
+    static bool is_chinese_char(uint32_t cpt) {
+        return
+            (cpt >= 0x04E00 && cpt <= 0x09FFF) ||
+            (cpt >= 0x03400 && cpt <= 0x04DBF) ||
+            (cpt >= 0x20000 && cpt <= 0x2A6DF) ||
+            (cpt >= 0x2A700 && cpt <= 0x2B73F) ||
+            (cpt >= 0x2B740 && cpt <= 0x2B81F) ||
+            (cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
+            (cpt >= 0x0F900 && cpt <= 0x0FAFF) ||
+            (cpt >= 0x2F800 && cpt <= 0x2FA1F);
+            //(cpt >= 0x3000  && cpt <= 0x303F)  ||
+            //(cpt >= 0xFF00  && cpt <= 0xFFEF);
+    }
+
+    const llama_vocab & vocab;
+};
+
+//
+// UGM tokenizer
+//
+
+struct llm_tokenizer_ugm {
+    llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) {
+        if (vocab.precompiled_charsmap.size() > 0) {
+            size_t charsmap_offset = 0;
+
+            // First four bytes of precompiled_charsmap contains length of binary
+            // blob containing XOR-compressed compact double array (XCDA) entries
+            uint32_t xcda_blob_size = *(const uint32_t *) &vocab.precompiled_charsmap[0];
+            charsmap_offset += sizeof(xcda_blob_size);
+            if (xcda_blob_size + charsmap_offset >= vocab.precompiled_charsmap.size()) {
+                throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
+            }
+
+            // Next xcda_blob_size bytes contain entries of XOR-compressed compact
+            // double array (XCDA). Each entry is bit-packed into a 32-bit integer.
+            xcda_array = (const uint32_t *) &vocab.precompiled_charsmap[charsmap_offset];
+            xcda_array_size = xcda_blob_size / sizeof(uint32_t);
+            charsmap_offset += xcda_blob_size;
+
+            // Remaining bytes of precompiled charsmap contain null-terminated
+            // replacement strings for prefixes matched by the XCDA.
+            prefix_replacements = &vocab.precompiled_charsmap[charsmap_offset];
+            prefix_replacements_size = vocab.precompiled_charsmap.size() - charsmap_offset;
+        }
+
+        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
+            const auto &token_data = vocab.id_to_token[id];
+
+            if (llama_is_normal_token(vocab, id)) {
+                min_score = std::min<float>(min_score, token_data.score);
+                max_score = std::max<float>(max_score, token_data.score);
+            }
+
+            if (llama_is_normal_token(vocab, id) ||
+                llama_is_user_defined_token(vocab, id) ||
+                llama_is_unused_token(vocab, id)) {
+                token_matcher.insert(token_data.text.data(), token_data.text.size(), id);
+            }
+
+            if (llama_is_user_defined_token(vocab, id)) {
+                user_defined_token_matcher.insert(token_data.text.data(), token_data.text.size());
+            }
+        }
+
+        unknown_token_score = min_score - unknown_token_score_penalty;
+    }
+
+    /* This implementation is based on SentencePiece optimized Viterbi algorithm for
+     * unigram language models. The general idea is to:
+     * - move along the input sequence in steps of one UTF code point,
+     * - at each step find all possible tokenizations of the prefix by
+     *   traversing the tokens trie,
+     * - for each tokenization store the best one so far (by higher score)
+     * - use the position in sequence after given token as an index to store
+     *   results
+     * - if there was no valid tokenization of the current UTF code point
+     *   then use unknown token with additional score penalty
+     * After processing the whole sequence we backtrack from the end to get
+     * the best tokenization.
+    */
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        // get current size of output (for reversal later)
+        size_t output_size = output.size();
+
+        // normalize the input first
+        std::string normalized;
+        normalize(text, &normalized);
+        size_t input_len = normalized.size();
+        if (input_len == 0) {
+            return;
+        }
+
+        // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
+        std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.special_unk_id, 0, -FLT_MAX});
+        // at the beginning tokenization score is zero
+        tokenization_results[0] = { vocab.special_unk_id, 0, 0 };
+
+        for (size_t input_offset = 0; input_offset < input_len;) {
+            size_t prefix_offset = input_offset;
+            // calculate how many code units are in the currently processed UTF code point
+            size_t n_utf8_code_units = std::min<size_t>(unicode_len_utf8(normalized[input_offset]), input_len - input_offset);
+
+            // traverse the token matcher trie to find a matching token
+            bool single_codepoint_token_found = false;
+            const struct best_tokenization & current_best = tokenization_results[input_offset];
+            struct naive_trie * node  = token_matcher.traverse(normalized[prefix_offset++]);
+
+            while (prefix_offset <= input_len && node != NULL) {
+                // check if we found valid token in prefix
+                if (node->has_value) {
+                    // check if it corresponds to the whole UTF code point
+                    if (prefix_offset - input_offset == n_utf8_code_units) {
+                        single_codepoint_token_found = true;
+                    }
+                    llama_token token_id = node->value;
+                    const auto & token_data = vocab.id_to_token[token_id];
+
+                    // we set the user-defined token scores to 0 to make them more likely to be selected
+                    // (normal token scores are log probabilities, so they are negative)
+                    // score type is double here to make tokenization results exactly
+                    // the same as in the HF tokenizer using SentencePiece
+                    const double token_score = llama_is_user_defined_token(vocab, token_id) ? 0.0 : token_data.score;
+                    const double challenger_score = current_best.score_sum + token_score;
+                    struct best_tokenization & current_champ = tokenization_results[prefix_offset];
+                    if (challenger_score > current_champ.score_sum) {
+                        struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score };
+                        current_champ = challenger;
+                    }
+                }
+                node = node->traverse(normalized[prefix_offset++]);
+            }
+
+            // if we didn't find a valid token corresponding to the whole UTF code point
+            // then use unknown token as the tokenization of this UTF code point
+            if (!single_codepoint_token_found) {
+                const double challenger_score = current_best.score_sum + unknown_token_score;
+                prefix_offset = input_offset + n_utf8_code_units;
+                struct best_tokenization & current_champ = tokenization_results[prefix_offset];
+                if (challenger_score > current_champ.score_sum) {
+                    struct best_tokenization challenger = { vocab.special_unk_id, input_offset, (float) challenger_score };
+                    current_champ = challenger;
+                }
+            }
+
+            // move to the next UTF code point
+            input_offset += n_utf8_code_units;
+        }
+
+        // now backtrack from the end to gather token ids of the best tokenization
+        // merge sequences of consecutive unknown tokens into single unknown tokens
+        bool is_prev_unknown = false;
+        for (struct best_tokenization & tokenization = tokenization_results[input_len]; ; tokenization = tokenization_results[tokenization.input_offset]) {
+            bool is_unknown = tokenization.token_id == vocab.special_unk_id;
+            if (!(is_prev_unknown && is_unknown)) {
+                output.push_back(tokenization.token_id);
+            }
+            if (tokenization.input_offset == 0) {
+                break;
+            }
+            is_prev_unknown = is_unknown;
+        }
+
+        // reverse the output since we added tokens starting from the end of the input
+        std::reverse(output.begin() + output_size, output.end());
+    }
+
+private:
+    const llama_vocab & vocab;
+
+    // helper structure for returning normalization results
+    struct normalization_result {
+        const char * normalized;
+        size_t normalized_len;
+        size_t consumed_input;
+    };
+
+    void normalize(const std::string& input, std::string * normalized) {
+        normalized->clear();
+        normalized->reserve(input.size() * 3);
+
+        const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " ";
+
+        bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
+        bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
+        bool shall_merge_spaces = vocab.tokenizer_remove_extra_whitespaces;
+
+        bool is_space_prepended = false;
+        bool processing_non_ws = false;
+
+        size_t input_len = input.size();
+
+        for (size_t input_offset = 0; input_offset < input_len; ) {
+            auto norm_res = normalize_prefix(input, input_offset);
+            for (size_t i = 0; i < norm_res.normalized_len; i++) {
+                char c = norm_res.normalized[i];
+                if (c != ' ') {
+                    if (!processing_non_ws) {
+                        processing_non_ws = true;
+                        if ((shall_prepend_space && !is_space_prepended) || shall_merge_spaces) {
+                            normalized->append(space);
+                            is_space_prepended = true;
+                        }
+                    }
+                    normalized->push_back(c);
+                } else {
+                    if (processing_non_ws) {
+                        processing_non_ws = false;
+                    }
+                    if (!shall_merge_spaces) {
+                        normalized->append(space);
+                    }
+                }
+            }
+
+            input_offset += norm_res.consumed_input;
+        }
+
+        if (shall_append_space) {
+            normalized->append(space);
+        }
+    }
+
+    /*
+     * This structure is a view wrapper for XOR-compressed double array (XCDA)
+     * See Shunsuke Kanda (2018). Space- and Time-Efficient String Dictionaries.
+     * Eeach bit-packed entry contains:
+     * - BASE array value in bits 10-30
+     * - LCHECK array value in bits 0-7
+     * - LEAF array value in bit 9
+     * Entries containing indexes of replacement sequences have set bit 31
+     */
+    struct xcda_array_view {
+    public:
+        xcda_array_view(const uint32_t * xcda_array, size_t xcda_array_size) : xcda_array(xcda_array), xcda_array_size(xcda_array_size) {
+        }
+        uint32_t get_base(size_t index) {
+            uint32_t packed_node = get_node(index);
+            return (packed_node >> 10) << ((packed_node & (1U << 9)) >> 6);
+        }
+        uint32_t get_lcheck(size_t index) {
+            uint32_t packed_node = get_node(index);
+            return packed_node & ((1U << 31) | 0xff);
+        }
+        bool get_leaf(size_t index) {
+            uint32_t packed_node = get_node(index);
+            return (packed_node >> 8) & 1;
+        }
+        uint32_t get_value(size_t index) {
+            uint32_t packed_node = get_node(index);
+            return packed_node & ((1U << 31) - 1);
+        }
+    private:
+        uint32_t get_node(size_t index) {
+            if (index > xcda_array_size) {
+                throw std::runtime_error("Index out of array bounds in XCDA array!");
+            }
+            return xcda_array[index];
+        }
+        const uint32_t * xcda_array;
+        size_t xcda_array_size;
+    };
+
+    struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
+        if (input_offset == input.size()) {
+            return { &input[input_offset], 0, 0 };
+        }
+
+        // if input prefix matches some user-defined token return this token as normalization result
+        auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
+        if (user_defined_token_match.second > 0) {
+            return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second };
+        }
+
+        size_t longest_prefix_length = 0;
+        size_t longest_prefix_offset = 0;
+
+        if (xcda_array_size > 0) {
+            struct xcda_array_view xcda_view(xcda_array, xcda_array_size);
+
+            // Find the longest normalized sequence matching the input prefix by walking
+            // the XOR-compressed compact double array (XCDA) starting from the root node
+            // We find the index of the next node by calculating BASE[s] ^ c where s is
+            // the index of the previous node and c is a numerical character value
+            uint32_t node_index = 0;
+            // get BASE of the root node
+            node_index = xcda_view.get_base(node_index);
+            for (size_t prefix_offset = input_offset; prefix_offset < input.size(); prefix_offset++) {
+                unsigned char c = input[prefix_offset];
+                if (c == 0) {
+                    break;
+                }
+                node_index ^= c;
+                // if value of LCHECK is not c it means that this is not a child of
+                // the previous node, so we stop matching
+                if (xcda_view.get_lcheck(node_index) != c) {
+                    break;
+                }
+                bool is_leaf = xcda_view.get_leaf(node_index);
+                // get BASE of the current node
+                node_index ^= xcda_view.get_base(node_index);
+                // if LEAF of the current node is true, it means that its BASE points to the node
+                // containing index of replacement sequence for currently matched input prefix
+                if (is_leaf)
+                {
+                    longest_prefix_length = prefix_offset - input_offset + 1;
+                    // get index of replacement sequence for currently matched input prefix
+                    longest_prefix_offset = xcda_view.get_value(node_index);
+                }
+            }
+        }
+
+        if (longest_prefix_length > 0) {
+            // we have a match, so return the replacement sequence
+            if (longest_prefix_offset >= prefix_replacements_size) {
+                throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
+            }
+            const char * prefix_replacement = &prefix_replacements[longest_prefix_offset];
+            return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length };
+        } else {
+            // check if the input prefix contains a valid sequence of UTF-8 code units
+            try {
+                // if yes, return this sequence unmodified
+                size_t prefix_offset = input_offset;
+                unicode_cpt_from_utf8(input, prefix_offset);
+                return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset };
+            } catch (std::invalid_argument & /*ex*/) {
+                // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER
+                return { "\xEF\xBF\xBD", 3, 1 };
+            }
+        }
+    }
+
+    // escaped space symbol - U+2581 (Lower One Eighth Block)
+    const std::string escaped_space = "\xE2\x96\x81";
+
+    const char * prefix_replacements = NULL;
+    size_t prefix_replacements_size = 0;
+
+    const uint32_t * xcda_array = NULL;
+    size_t xcda_array_size = 0;
+
+    struct naive_trie user_defined_token_matcher;
+
+    // this structure stores the best tokenization so far at input_offset
+    struct best_tokenization {
+        llama_token token_id;
+        size_t input_offset;
+        float score_sum;
+    };
+
+    float min_score = FLT_MAX;
+    float max_score = -FLT_MAX;
+
+    float unknown_token_score_penalty = 10.0;
+    float unknown_token_score;
+
+    struct naive_trie token_matcher;
+};
+
+//
+// (de-) tokenize
+//
+
+typedef enum FRAGMENT_BUFFER_VARIANT_TYPE {
+    FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN,
+    FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT
+} FRAGMENT_BUFFER_VARIANT_TYPE;
+
+struct fragment_buffer_variant {
+    fragment_buffer_variant(llama_vocab::id _token)
+    :
+        type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN),
+        token(_token),
+        raw_text(_dummy),
+        offset(0),
+        length(0) {}
+
+    fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length)
+    :
+        type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT),
+        token((llama_vocab::id) - 1),
+        raw_text(_raw_text),
+        offset(_offset),
+        length(_length){
+            GGML_ASSERT(_offset >= 0);
+            GGML_ASSERT(_length >= 1);
+            GGML_ASSERT(offset + length <= raw_text.length());
+        }
+
+    const FRAGMENT_BUFFER_VARIANT_TYPE type;
+    const llama_vocab::id token;
+    const std::string _dummy;
+    const std::string & raw_text;
+    const uint64_t offset;
+    const uint64_t length;
+};
+
+// #define PRETOKENIZERDEBUG
+
+static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) {
+    // for each special token
+    for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
+        const auto & data = vocab.id_to_token[special_id];
+        const auto & special_token = data.text;
+
+        if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) {
+            // Ignore control and unknown tokens when parse_special == false
+            continue;
+            // User-defined tokens are still pre-tokenized before everything else
+            // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726
+            // This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.)
+        }
+
+        // for each text fragment
+        std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
+        while (it != buffer.end()) {
+            auto & fragment = (*it);
+
+            // if a fragment is text ( not yet processed )
+            if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                auto & raw_text = fragment.raw_text;
+
+                auto raw_text_base_offset = fragment.offset;
+                auto raw_text_base_length = fragment.length;
+
+                // loop over the text
+                while (true) {
+                    // find the first occurrence of a given special token in this fragment
+                    //  passing offset argument only limit the "search area" but match coordinates
+                    //  are still relative to the source full raw_text
+                    auto match = raw_text.find(special_token, raw_text_base_offset);
+
+                    // no occurrences found, stop processing this fragment for a given special token
+                    if (match == std::string::npos) break;
+
+                    // check if match is within bounds of offset <-> length
+                    if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break;
+
+#ifdef PRETOKENIZERDEBUG
+                    LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+#endif
+                    auto source = std::distance(buffer.begin(), it);
+
+                    // if match is further than base offset
+                    //  then we have some text to the left of it
+                    if (match > raw_text_base_offset) {
+                        // left
+                        const int64_t left_reminder_offset = raw_text_base_offset + 0;
+                        int64_t left_reminder_length = match - raw_text_base_offset;
+
+                        if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) {
+                            while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) {
+                                left_reminder_length--;
+                            }
+                        }
+
+                        if (left_reminder_length > 0) {
+                            buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
+                            it++;
+                        }
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
+#endif
+                    }
+
+                    // special token
+                    buffer.emplace_after(it, special_id);
+                    it++;
+
+                    // right
+                    if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
+                        int64_t right_reminder_offset = match + special_token.length();
+                        int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
+
+                        if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
+                            while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
+                                right_reminder_offset++;
+                                right_reminder_length--;
+                            }
+                        }
+
+                        if (right_reminder_length > 0) {
+                            buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
+                            it++;
+                        }
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
+#endif
+
+                        if (source == 0) {
+                            buffer.erase_after(buffer.before_begin());
+                        } else {
+                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                        }
+
+                        // repeat for the right side
+                        raw_text_base_offset = right_reminder_offset;
+                        raw_text_base_length = right_reminder_length;
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+#endif
+                    } else {
+                        if (source == 0) {
+                            buffer.erase_after(buffer.before_begin());
+                        } else {
+                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                        }
+                        break;
+                    }
+                }
+            }
+            it++;
+        }
+    }
+}
+
+std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) {
+    std::vector<llama_vocab::id> output;
+    std::forward_list<fragment_buffer_variant> fragment_buffer;
+
+    if (!raw_text.empty()) {
+        fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
+        tokenizer_st_partition(vocab, fragment_buffer, parse_special);
+    }
+
+    switch (vocab.type) {
+        case LLAMA_VOCAB_TYPE_SPM:
+            {
+                // OG tokenizer behavior:
+                //
+                // tokenizer.encode('', add_special_tokens=True)  returns [1]
+                // tokenizer.encode('', add_special_tokens=False) returns []
+
+                bool is_prev_special = true;  // prefix with space if first token
+
+                if (add_special && vocab.tokenizer_add_bos) {
+                    GGML_ASSERT(vocab.special_bos_id != -1);
+                    output.push_back(vocab.special_bos_id);
+                    is_prev_special = true;
+                }
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+                        // prefix with space if previous is special
+                        if (vocab.tokenizer_add_space_prefix && is_prev_special) {
+                            raw_text = " " + raw_text;
+                        }
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+                        llm_tokenizer_spm tokenizer(vocab);
+                        llama_escape_whitespace(raw_text);
+                        tokenizer.tokenize(raw_text, output);
+                        is_prev_special = false;
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                        is_prev_special = true;
+                    }
+                }
+
+                if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+                    LLAMA_LOG_WARN(
+                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                        "Are you sure this is what you want?\n", __FUNCTION__);
+                }
+
+                if (add_special && vocab.tokenizer_add_eos) {
+                    GGML_ASSERT(vocab.special_eos_id != -1);
+                    output.push_back(vocab.special_eos_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_BPE:
+            {
+                llm_tokenizer_bpe tokenizer(vocab);
+
+                if (add_special) {
+                    tokenizer.append_bos(output);
+                }
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+                        tokenizer.tokenize(raw_text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        tokenizer.append(fragment.token, output);
+                    }
+                }
+
+                if (add_special) {
+                    tokenizer.append_eos(output);
+                    tokenizer.check_double_bos_eos(output);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_WPM:
+            {
+                if (add_special) {
+                    GGML_ASSERT(vocab.special_cls_id != -1);
+                    output.push_back(vocab.special_cls_id);
+                }
+
+                llm_tokenizer_wpm tokenizer(vocab);
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+                        tokenizer.tokenize(raw_text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+
+                if (add_special) {
+                    GGML_ASSERT(vocab.special_sep_id != -1);
+                    output.push_back(vocab.special_sep_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_UGM:
+            {
+                llm_tokenizer_ugm tokenizer(vocab);
+
+                if (add_special && vocab.tokenizer_add_bos != 0) {
+                    GGML_ASSERT(vocab.special_bos_id != -1);
+                    output.push_back(vocab.special_bos_id);
+                }
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+                        tokenizer.tokenize(raw_text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+
+                if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+                    LLAMA_LOG_WARN(
+                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                        "Are you sure this is what you want?\n", __FUNCTION__);
+                }
+
+                if (add_special && vocab.tokenizer_add_eos == 1) {
+                    GGML_ASSERT(vocab.special_eos_id != -1);
+                    output.push_back(vocab.special_eos_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_NONE:
+            GGML_ABORT("fatal error");
+    }
+
+    return output;
+}
+
+llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch) {
+    GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
+    static const char * hex = "0123456789ABCDEF";
+    switch (llama_vocab_get_type(vocab)) {
+        case LLAMA_VOCAB_TYPE_SPM:
+        case LLAMA_VOCAB_TYPE_UGM: {
+            const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
+            auto token = vocab.token_to_id.find(buf);
+            if (token != vocab.token_to_id.end()) {
+                return (*token).second;
+            }
+            // Try to fall back to just the byte as a string
+            const char buf2[2] = { (char)ch, 0 };
+            return vocab.token_to_id.at(buf2);
+        }
+        case LLAMA_VOCAB_TYPE_WPM:
+        case LLAMA_VOCAB_TYPE_BPE: {
+            return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
+        }
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
+
+const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token) {
+    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
+    return vocab.id_to_token[token].text.c_str();
+}
+
+float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token) {
+    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
+    return vocab.id_to_token[token].score;
+}
+
+llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token) {
+    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
+    return vocab.id_to_token[token].attr;
+}
+
+bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
+    return token != -1 && (
+        token == llama_token_eos_impl(vocab) ||
+        token == llama_token_eot_impl(vocab) ||
+        token == llama_token_eom_impl(vocab)
+    );
+}
+
+bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
+    return llama_is_control_token(vocab, token);
+}
+
+llama_token llama_token_bos_impl(const struct llama_vocab & vocab) {
+    return vocab.special_bos_id;
+}
+
+llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
+    return vocab.special_eos_id;
+}
+
+llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
+    return vocab.special_cls_id;
+}
+
+llama_token llama_token_sep_impl(const struct llama_vocab & vocab) {
+    return vocab.special_sep_id;
+}
+
+llama_token llama_token_nl_impl(const struct llama_vocab & vocab) {
+    return vocab.linefeed_id;
+}
+
+llama_token llama_token_pad_impl(const struct llama_vocab & vocab) {
+    return vocab.special_pad_id;
+}
+
+int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab) {
+    return vocab.tokenizer_add_bos;
+}
+
+int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab) {
+    return vocab.tokenizer_add_eos;
+}
+
+llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
+    return vocab.special_prefix_id;
+}
+
+llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
+    return vocab.special_middle_id;
+}
+
+llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
+    return vocab.special_suffix_id;
+}
+
+llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
+    return vocab.special_eot_id;
+}
+
+llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
+    return vocab.special_eom_id;
+}
+
+int32_t llama_tokenize_impl(
+    const struct llama_vocab & vocab,
+                  const char * text,
+                     int32_t   text_len,
+                 llama_token * tokens,
+                     int32_t   n_tokens_max,
+                        bool   add_special,
+                        bool   parse_special) {
+    auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special);
+    if (n_tokens_max < (int) res.size()) {
+        // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
+        return -((int) res.size());
+    }
+
+    for (size_t i = 0; i < res.size(); i++) {
+        tokens[i] = res[i];
+    }
+
+    return res.size();
+}
+
+static std::string llama_decode_text(const std::string & text) {
+    std::string decoded_text;
+
+    const auto cpts = unicode_cpts_from_utf8(text);
+    for (const auto cpt : cpts) {
+        const auto utf8 = unicode_cpt_to_utf8(cpt);
+        try {
+            decoded_text += unicode_utf8_to_byte(utf8);
+        } catch (const std::out_of_range & /*e*/) {
+            decoded_text += "[UNK_BYTE_0x";
+            for (const auto c : utf8) {
+                decoded_text += format("%02x", (uint8_t) c);
+            }
+            decoded_text += text + "]";
+        }
+    }
+
+    return decoded_text;
+}
+
+// does not write null-terminator to buf
+int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) {
+    // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
+    static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
+    const llama_token_attr attr = llama_token_get_attr_impl(vocab, token);
+    if (!special && (attr & attr_special)) {
+        return 0;
+    }
+
+    // copy piece chars to output text buffer
+    // skip up to 'lstrip' leading spaces before copying
+    auto _try_copy = [=] (const char * token, size_t size) -> int32_t {
+        for (int32_t i = 0; i < lstrip && size && *token == ' '; ++i) {
+            token++;
+            size--;
+        }
+        if (length < (int32_t)size) {
+            return -(int32_t) size;
+        }
+        memcpy(buf, token, size);
+        return (int32_t) size;
+    };
+
+    // if we have a cache - use it
+    {
+        const auto & cache = vocab.cache_token_to_piece;
+
+        if (!cache.empty()) {
+            const auto & result = cache.at(token);
+            return _try_copy(result.data(), result.size());
+        }
+    }
+
+    if (0 <= token && token < (int32_t) vocab.id_to_token.size()) {
+        const std::string & token_text = vocab.id_to_token[token].text;
+        switch (llama_vocab_get_type(vocab)) {
+            case LLAMA_VOCAB_TYPE_WPM:
+            case LLAMA_VOCAB_TYPE_SPM:
+            case LLAMA_VOCAB_TYPE_UGM: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
+                    return _try_copy(token_text.data(), token_text.size());
+                } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+                    std::string result = token_text;
+                    llama_unescape_whitespace(result);
+                    return _try_copy(result.data(), result.size());
+                } else if (attr & LLAMA_TOKEN_ATTR_BYTE) {
+                    char byte = (char) llama_token_to_byte(vocab, token);
+                    return _try_copy((char*) &byte, 1);
+                }
+                break;
+            }
+            case LLAMA_VOCAB_TYPE_BPE: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
+                    return _try_copy(token_text.data(), token_text.size());
+                } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+                    std::string result = llama_decode_text(token_text);
+                    return _try_copy(result.data(), result.size());
+                }
+                break;
+            }
+            default:
+                GGML_ABORT("fatal error");
+        }
+    }
+
+    return 0;
+}
+
+int32_t llama_detokenize_impl(
+        const struct llama_vocab & vocab,
+               const llama_token * tokens,
+                         int32_t   n_tokens,
+                            char * text,
+                         int32_t   text_len_max,
+                            bool   remove_special,
+                            bool   unparse_special) {
+    int32_t avail = text_len_max;
+    int32_t total = 0;
+
+    // remove the leading space
+    bool remove_space = vocab.tokenizer_add_space_prefix;
+
+    if (remove_special && vocab.tokenizer_add_bos) {
+        if (n_tokens > 0 && tokens[0] == vocab.special_bos_id) {
+            remove_space = false;
+            n_tokens--;
+            tokens++;
+        }
+    }
+
+    if (remove_special && vocab.tokenizer_add_eos) {
+        if (n_tokens > 0 && tokens[n_tokens-1] == vocab.special_eos_id) {
+            n_tokens--;
+        }
+    }
+
+    for (int32_t i = 0; i < n_tokens; ++i) {
+        GGML_ASSERT(avail >= 0);
+        int32_t n_chars = llama_token_to_piece_impl(vocab, tokens[i], text, avail, remove_space, unparse_special);
+        remove_space = false;
+        if (n_chars < 0) {
+            avail = 0;
+            total -= n_chars;
+        } else if (n_chars > 0) {
+            avail -= n_chars;
+            text  += n_chars;
+            total += n_chars;
+        }
+    }
+
+    if (total > text_len_max) {
+        return -total;
+    }
+
+    if (vocab.tokenizer_clean_spaces) {
+        text -= total;  // restart text
+
+        // first pass: characters ?!.,  //TODO: where do these characters come from?
+        const int32_t total1 = total;
+        total = total ? 1 : 0;
+        for (int32_t i = 1; i < total1; ++i) {
+            const char x = text[i];
+            if (text[i - 1] == ' ') {
+                if (x == '?' || x == '!' || x == '.' || x == ',') {  // " ?", " !", " .", " ,"
+                    total--;  // remove space
+                }
+            }
+            text[total++] = x;
+        }
+
+        // second pass: strip single apostrophe between spaces
+        const int32_t total2 = total;
+        total = total ? 1 : 0;
+        for (int32_t i = 1; i < total2; ++i) {
+            const char x = text[i];
+            if (x == '\'' && i + 1 < total2 && text[i - 1] == ' ' && text[i + 1] == ' ') {  // " ' "
+                total--;           // remove prev space
+                text[++i] = '\0';  // remove next space
+            }
+            text[total++] = x;
+        }
+
+        // third pass: apostrophe contractions  //NOTE: this makes sense?
+        const int32_t total3 = total;
+        total = total ? 1 : 0;
+        for (int32_t i = 1; i < total3; ++i) {
+            const char x = text[i];
+            if (text[i - 1] == ' ') {
+                if (x == '\'' && i + 1 < total3) {
+                    const char x1 = text[i + 1];
+                    if (x1 == 't' || x1 == 'd') {  // " 't", " 'd"
+                        //total--;  // remove space
+                    } else if (x1 == 's' || x1 == 'm') {  // " 's", " 'm"
+                        total--;  // remove space
+                    } else if (i + 2 < total3) {
+                        const char x2 = text[i + 2];
+                        if ((x1 == 'l' && x2 == 'l')) {  // " 'll"
+                            //total--;  // remove space
+                        } else if ((x1 == 'r' && x2 == 'e') || (x1 == 'v' && x2 == 'e')) {  // " 're", " 've"
+                            total--;  // remove space
+                        } else {
+                            //total--;  // remove space
+                        }
+                    } else {
+                        //total--;  // remove space
+                    }
+                }
+            }
+            text[total++] = x;
+        }
+    }
+
+    return total <= text_len_max ? total : -total;
+}
diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h
new file mode 100644 (file)
index 0000000..7adfc16
--- /dev/null
@@ -0,0 +1,132 @@
+#pragma once
+
+#include "llama-impl.h"
+
+#include <string>
+#include <vector>
+#include <unordered_map>
+#include <map>
+
+struct llama_vocab {
+    using id    = llama_token;
+    using token = std::string;
+    using tattr = llama_token_attr;
+
+    struct token_data {
+        token text;
+        float score;
+        tattr attr;
+    };
+
+    enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
+    enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+
+    int max_token_len = 0; // used for optimizing longest token search
+
+    std::unordered_map<token, id> token_to_id;
+    std::vector<token_data>       id_to_token;
+
+    std::vector<id>    cache_special_tokens;
+    std::vector<token> cache_token_to_piece; // llama_token_to_piece(special = true);
+
+    std::map<std::pair<std::string, std::string>, int> bpe_ranks;
+
+    // default LLaMA special tokens
+    id special_bos_id  = 1;
+    id special_eos_id  = 2;
+    id special_unk_id  = 0;
+    id special_sep_id  = -1;
+    id special_pad_id  = -1;
+    id special_cls_id  = -1;
+    id special_mask_id = -1;
+
+    id linefeed_id       = 13;
+    id special_prefix_id = -1;
+    id special_suffix_id = -1;
+    id special_middle_id = -1;
+    id special_eot_id    = -1; // TODO: move above after "eos_id", and here add "file separator" token
+    id special_eom_id    = -1;
+
+    // tokenizer flags
+    bool tokenizer_add_space_prefix = false;
+    bool tokenizer_add_bos          = false;
+    bool tokenizer_add_eos          = false;
+    bool tokenizer_ignore_merges    = false;
+    bool tokenizer_clean_spaces     = false;  // clean_up_tokenization_spaces
+    bool tokenizer_remove_extra_whitespaces   = false;
+    bool tokenizer_escape_whitespaces         = true;
+    bool tokenizer_treat_whitespace_as_suffix = false;
+
+    std::vector<char> precompiled_charsmap;
+
+    int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
+};
+
+const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx);
+
+//
+// internal API
+//
+
+// TODO: rename to llama_tokenize_impl
+// TODO: This should probably be in llama.h
+std::vector<llama_vocab::id> llama_tokenize_internal(
+        const llama_vocab & vocab,
+        std::string raw_text,
+        bool add_special,
+        bool parse_special = false);
+
+llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
+
+const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
+
+float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token);
+
+llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token);
+
+bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token);
+
+bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token);
+
+llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
+llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
+llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
+llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
+llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
+llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
+
+int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab);
+int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab);
+
+llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
+llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
+llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
+llama_token llama_token_eot_impl   (const struct llama_vocab & vocab);
+llama_token llama_token_eom_impl   (const struct llama_vocab & vocab);
+
+int32_t llama_tokenize_impl(
+        const struct llama_vocab & vocab,
+                      const char * text,
+                         int32_t   text_len,
+                     llama_token * tokens,
+                         int32_t   n_tokens_max,
+                            bool   add_special,
+                            bool   parse_special);
+
+// does not write null-terminator to buf
+int32_t llama_token_to_piece_impl(
+        const struct llama_vocab & vocab,
+                     llama_token   token,
+                            char * buf,
+                         int32_t   length,
+                         int32_t   lstrip,
+                            bool   special);
+
+int32_t llama_detokenize_impl(
+        const struct llama_vocab & vocab,
+               const llama_token * tokens,
+                         int32_t   n_tokens,
+                            char * text,
+                         int32_t   text_len_max,
+                            bool   remove_special,
+                            bool   unparse_special);
index 2b9ace28584572ff60a8c7844c6f06f9d5e4f4ec..a7b1c9ebd9e37d1e9017bdec315b86ecd51a738b 100644 (file)
@@ -1,5 +1,7 @@
-#define LLAMA_API_INTERNAL
-#include "llama.h"
+#include "llama-impl.h"
+#include "llama-vocab.h"
+#include "llama-grammar.h"
+#include "llama-sampling.h"
 
 #include "unicode.h"
 
@@ -19,6 +21,8 @@
 #  include "ggml-sycl.h"
 #elif defined(GGML_USE_KOMPUTE)
 #   include "ggml-kompute.h"
+#elif defined(GGML_USE_CANN)
+#   include "ggml-cann.h"
 #endif
 
 #ifdef GGML_USE_BLAS
     #include <io.h>
 #endif
 
+#if __cplusplus >= 202000L
+    #define LU8(x) (const char*)(u8##x)
+#else
+    #define LU8(x) u8##x
+#endif
+
 #include <algorithm>
 #include <array>
 #include <cassert>
@@ -71,7 +81,6 @@
 #include <cstdio>
 #include <cstring>
 #include <ctime>
-#include <forward_list>
 #include <fstream>
 #include <functional>
 #include <future>
@@ -81,9 +90,6 @@
 #include <memory>
 #include <mutex>
 #include <numeric>
-#include <queue>
-#include <random>
-#include <regex>
 #include <set>
 #include <sstream>
 #include <thread>
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-#ifdef __GNUC__
-#ifdef __MINGW32__
-#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
-#else
-#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
-#endif
-#else
-#define LLAMA_ATTRIBUTE_FORMAT(...)
-#endif
-
 // bump if necessary
-#define LLAMA_MAX_NODES   8192
-#define LLAMA_MAX_LAYERS  256
+#define LLAMA_MAX_LAYERS  512
 #define LLAMA_MAX_EXPERTS 160  // DeepSeekV2
 
-//
-// logging
-//
-
-LLAMA_ATTRIBUTE_FORMAT(2, 3)
-static void llama_log_internal        (ggml_log_level level, const char * format, ...);
-static void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data);
-
-#define LLAMA_LOG_INFO(...)  llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
-#define LLAMA_LOG_WARN(...)  llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
-#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
-
 //
 // helpers
 //
 
-static size_t utf8_len(char src) {
-    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
-    uint8_t highbits = static_cast<uint8_t>(src) >> 4;
-    return lookup[highbits];
+// trim whitespace from the beginning and end of a string
+static std::string trim(const std::string & str) {
+    size_t start = 0;
+    size_t end = str.size();
+    while (start < end && isspace(str[start])) {
+        start += 1;
+    }
+    while (end > start && isspace(str[end - 1])) {
+        end -= 1;
+    }
+    return str.substr(start, end - start);
 }
 
 static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
-    std::string result;
-    for (size_t pos = 0; ; pos += search.length()) {
-        auto new_pos = s.find(search, pos);
-        if (new_pos == std::string::npos) {
-            result += s.substr(pos, s.size() - pos);
-            break;
-        }
-        result += s.substr(pos, new_pos - pos) + replace;
-        pos = new_pos;
+    if (search.empty()) {
+        return; // Avoid infinite loop if 'search' is an empty string
+    }
+    size_t pos = 0;
+    while ((pos = s.find(search, pos)) != std::string::npos) {
+        s.replace(pos, search.length(), replace);
+        pos += replace.length();
     }
-    s = std::move(result);
 }
 
 static bool is_float_close(float a, float b, float abs_tol) {
@@ -281,6 +268,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
 };
 
 enum llm_kv {
+    LLM_KV_GENERAL_TYPE,
     LLM_KV_GENERAL_ARCHITECTURE,
     LLM_KV_GENERAL_QUANTIZATION_VERSION,
     LLM_KV_GENERAL_ALIGNMENT,
@@ -371,9 +359,14 @@ enum llm_kv {
     LLM_KV_TOKENIZER_SUFFIX_ID,
     LLM_KV_TOKENIZER_MIDDLE_ID,
     LLM_KV_TOKENIZER_EOT_ID,
+    LLM_KV_TOKENIZER_EOM_ID,
+
+    LLM_KV_ADAPTER_TYPE,
+    LLM_KV_ADAPTER_LORA_ALPHA,
 };
 
 static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
+    { LLM_KV_GENERAL_TYPE,                  "general.type"                          },
     { LLM_KV_GENERAL_ARCHITECTURE,          "general.architecture"                  },
     { LLM_KV_GENERAL_QUANTIZATION_VERSION,  "general.quantization_version"          },
     { LLM_KV_GENERAL_ALIGNMENT,             "general.alignment"                     },
@@ -464,6 +457,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_SUFFIX_ID,            "tokenizer.ggml.suffix_token_id"          },
     { LLM_KV_TOKENIZER_MIDDLE_ID,            "tokenizer.ggml.middle_token_id"          },
     { LLM_KV_TOKENIZER_EOT_ID,               "tokenizer.ggml.eot_token_id"             },
+    { LLM_KV_TOKENIZER_EOM_ID,               "tokenizer.ggml.eom_token_id"             },
+
+    { LLM_KV_ADAPTER_TYPE,                  "adapter.type"       },
+    { LLM_KV_ADAPTER_LORA_ALPHA,            "adapter.lora.alpha" },
 };
 
 struct LLM_KV {
@@ -2065,6 +2062,8 @@ struct llama_state {
         ggml_backend_metal_log_set_callback(log_callback, log_callback_user_data);
 #elif defined(GGML_USE_CUDA)
         ggml_backend_cuda_log_set_callback(log_callback, log_callback_user_data);
+#elif defined(GGML_USE_CANN)
+        ggml_backend_cann_log_set_callback(log_callback, log_callback_user_data);
 #endif
     }
 
@@ -2258,8 +2257,7 @@ struct llama_hparams {
             return n_head_arr[il];
         }
 
-        GGML_ASSERT(false);
-        return 0;
+        GGML_ABORT("fatal error");
     }
 
     uint32_t n_head_kv(uint32_t il = 0) const {
@@ -2267,8 +2265,7 @@ struct llama_hparams {
             return n_head_kv_arr[il];
         }
 
-        GGML_ASSERT(false);
-        return 0;
+        GGML_ABORT("fatal error");
     }
 
     uint32_t n_ff(uint32_t il = 0) const {
@@ -2276,8 +2273,7 @@ struct llama_hparams {
             return n_ff_arr[il];
         }
 
-        GGML_ASSERT(false);
-        return 0;
+        GGML_ABORT("fatal error");
     }
 
     uint32_t n_gqa(uint32_t il = 0) const {
@@ -2454,6 +2450,7 @@ struct llama_layer {
     // long rope factors
     struct ggml_tensor * rope_long  = nullptr;
     struct ggml_tensor * rope_short = nullptr;
+    struct ggml_tensor * rope_freqs = nullptr;
 
     // bitnet scale
     struct ggml_tensor * wq_scale;
@@ -2565,72 +2562,6 @@ struct llama_control_vector {
     }
 };
 
-struct llama_vocab {
-    using id    = int32_t;
-    using token = std::string;
-    using tattr = llama_token_attr;
-
-    struct token_data {
-        token text;
-        float score;
-        tattr attr;
-    };
-
-    enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
-    enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-
-    int max_token_len = 0; // used for optimizing longest token search
-
-    std::unordered_map<token, id> token_to_id;
-    std::vector<token_data>       id_to_token;
-
-    std::vector<id>    cache_special_tokens;
-    std::vector<token> cache_token_to_piece; // llama_token_to_piece(special = true);
-
-    std::map<std::pair<std::string, std::string>, int> bpe_ranks;
-
-    // default LLaMA special tokens
-    id special_bos_id  = 1;
-    id special_eos_id  = 2;
-    id special_unk_id  = 0;
-    id special_sep_id  = -1;
-    id special_pad_id  = -1;
-    id special_cls_id  = -1;
-    id special_mask_id = -1;
-
-    id linefeed_id       = 13;
-    id special_prefix_id = -1;
-    id special_suffix_id = -1;
-    id special_middle_id = -1;
-    id special_eot_id    = -1; // TODO: move above after "eos_id", and here add "file separator" token
-
-    // tokenizer flags
-    bool tokenizer_add_space_prefix = false;
-    bool tokenizer_add_bos          = false;
-    bool tokenizer_add_eos          = false;
-    bool tokenizer_ignore_merges    = false;
-    bool tokenizer_clean_spaces     = false;  // clean_up_tokenization_spaces
-    bool tokenizer_remove_extra_whitespaces   = false;
-    bool tokenizer_escape_whitespaces         = true;
-    bool tokenizer_treat_whitespace_as_suffix = false;
-
-    std::vector<char> precompiled_charsmap;
-
-    int find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
-        GGML_ASSERT(token_left.find(' ') == std::string::npos);
-        GGML_ASSERT(token_left.find('\n') == std::string::npos);
-        GGML_ASSERT(token_right.find(' ') == std::string::npos);
-        GGML_ASSERT(token_right.find('\n') == std::string::npos);
-
-        auto it = bpe_ranks.find(std::make_pair(token_left, token_right));
-        if (it == bpe_ranks.end()) {
-            return -1;
-        }
-
-        return it->second;
-    }
-};
-
 struct llama_model {
     e_model     type  = MODEL_UNKNOWN;
     llm_arch    arch  = LLM_ARCH_UNKNOWN;
@@ -2697,6 +2628,9 @@ struct llama_model {
     int64_t t_load_us = 0;
     int64_t t_start_us = 0;
 
+    // keep track of loaded lora adapters
+    std::set<struct llama_lora_adapter *> lora_adapters;
+
     ~llama_model() {
         for (struct ggml_context * ctx : ctxs) {
             ggml_free(ctx);
@@ -2709,11 +2643,19 @@ struct llama_model {
 #endif
             ggml_backend_buffer_free(buf);
         }
+        while (!lora_adapters.empty()) {
+            llama_lora_adapter_free(*lora_adapters.begin());
+        }
     }
 };
 
 struct llama_context {
-    llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {}
+    llama_context(const llama_model & model)
+        : model(model)
+        , sampling(llama_n_vocab(&model))
+        , t_start_us(model.t_start_us)
+        , t_load_us(model.t_load_us) {}
+
     ~llama_context() {
         ggml_backend_sched_free(sched);
 
@@ -2724,7 +2666,14 @@ struct llama_context {
         ggml_backend_buffer_free(buf_output);
     }
 
-    llama_cparams cparams;
+    const struct llama_model & model;
+
+    struct llama_cparams        cparams;
+    struct llama_sampling       sampling;
+    struct llama_kv_cache       kv_self;
+    struct llama_control_vector cvec;
+
+    std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
 
     std::vector<ggml_backend_t> backends;
 #ifdef GGML_USE_METAL
@@ -2735,26 +2684,16 @@ struct llama_context {
 #endif
     ggml_backend_t backend_cpu = nullptr;
 
-
-    const llama_model & model;
-
-    // key + value cache for the self attention
-    struct llama_kv_cache kv_self;
-
-    std::mt19937 rng;
-
     bool has_evaluated_once = false;
 
     int64_t t_start_us;
     int64_t t_load_us;
-    int64_t t_sample_us = 0;
     int64_t t_p_eval_us = 0;
     int64_t t_eval_us   = 0;
 
     int64_t t_compute_start_us = 0;
     int64_t n_queued_tokens = 0;
 
-    int32_t n_sample = 0; // number of tokens sampled
     int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
     int32_t n_eval   = 0; // number of eval calls
 
@@ -2810,9 +2749,49 @@ struct llama_context {
     struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
+};
 
-    // control vectors
-    struct llama_control_vector cvec;
+struct llama_lora_weight {
+    struct ggml_tensor * a = nullptr;
+    struct ggml_tensor * b = nullptr;
+    llama_lora_weight() = default;
+    llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b): a(a), b(b) {}
+};
+
+struct llama_lora_adapter {
+    struct llama_model * base_model;
+    // map tensor name to lora_a_b
+    std::unordered_map<std::string, struct llama_lora_weight> ab_map;
+    std::vector<struct ggml_context *> ctxs;
+    std::vector<ggml_backend_buffer_t> bufs;
+
+    float alpha;
+
+    llama_lora_adapter(struct llama_model * base_model): base_model(base_model) {
+        base_model->lora_adapters.insert(this);
+    }
+
+    llama_lora_weight * get_weight(struct ggml_tensor * w) {
+        std::string name(w->name);
+        auto pos = ab_map.find(name);
+        if (ab_map.find(name) != ab_map.end()) {
+            return &pos->second;
+        }
+        return nullptr;
+    }
+
+    ~llama_lora_adapter() {
+        for (struct ggml_context * ctx : ctxs) {
+            ggml_free(ctx);
+        }
+        for (ggml_backend_buffer_t buf : bufs) {
+            ggml_backend_buffer_free(buf);
+        }
+        auto pos = base_model->lora_adapters.find(this);
+        if (pos != base_model->lora_adapters.end()) {
+            base_model->lora_adapters.erase(pos);
+        }
+    }
 };
 
 static size_t llama_get_device_count(const llama_model & model) {
@@ -2823,6 +2802,8 @@ static size_t llama_get_device_count(const llama_model & model) {
     count = ggml_backend_sycl_get_device_count();
 #elif defined(GGML_USE_VULKAN)
     count = ggml_backend_vk_get_device_count();
+#elif defined(GGML_USE_CANN)
+    return ggml_backend_cann_get_device_count();
 #endif
 #if defined(GGML_USE_RPC)
     count += model.rpc_servers.size();
@@ -2855,6 +2836,8 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_
     if (buft == nullptr) {
         LLAMA_LOG_WARN("%s: cannot use GPU %d, check `vulkaninfo --summary`\n", __func__, gpu);
     }
+#elif defined(GGML_USE_CANN)
+    buft = ggml_backend_cann_buffer_type(gpu);
 #endif
 
     if (buft == nullptr) {
@@ -2915,6 +2898,11 @@ static size_t llama_get_device_memory(const llama_model & model, int device) {
     size_t free;
     ggml_backend_vk_get_device_memory(device, &free, &total);
     return free;
+#elif defined(GGML_USE_CANN)
+    size_t total;
+    size_t free;
+    ggml_backend_cann_get_device_memory(device, &free, &total);
+    return free;
 #else
     return 1;
 #endif
@@ -2944,7 +2932,7 @@ static bool llama_kv_cache_init(
 
     // TODO: find a nicer way to add other recurrent model architectures
     cache.recurrent = model.arch == LLM_ARCH_MAMBA;
-    cache.v_trans   = !cparams.flash_attn;
+    cache.v_trans   = !cache.recurrent && !cparams.flash_attn;
 
     cache.head = 0;
     cache.size = kv_size;
@@ -3578,6 +3566,15 @@ namespace GGUFMeta {
 
 using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;
 
+// TODO: update when needed or think of some clever automatic way to do this
+static size_t llama_model_max_nodes(const llama_model & /*model*/) {
+    //if (model.arch == LLM_ARCH_LLAMA && model.hparams.n_layer > ??) { // llama-3 405B
+    //    return 32768;
+    //}
+
+    return 8192;
+}
+
 struct llama_model_loader {
     int n_kv      = 0;
     int n_tensors = 0;
@@ -3628,7 +3625,7 @@ struct llama_model_loader {
         }
 
         if (param_overrides_p != nullptr) {
-            for (const struct llama_model_kv_override *p = param_overrides_p; p->key[0] != 0; p++) {
+            for (const struct llama_model_kv_override * p = param_overrides_p; p->key[0] != 0; p++) {
                 kv_overrides.insert({std::string(p->key), *p});
             }
         }
@@ -3782,6 +3779,9 @@ struct llama_model_loader {
                 case GGML_TYPE_IQ4_NL:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL;  break;
                 case GGML_TYPE_IQ4_XS:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS;  break;
                 case GGML_TYPE_IQ3_S:   ftype = LLAMA_FTYPE_MOSTLY_IQ3_S;   break;
+                case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break;
+                case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break;
+                case GGML_TYPE_Q4_0_8_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_8_8; break;
                 default:
                     {
                         LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
@@ -3793,7 +3793,7 @@ struct llama_model_loader {
             ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED);
 
             {
-                const int kid = gguf_find_key(meta, "general.file_type");
+                const int kid = gguf_find_key(meta, "general.file_type"); // TODO: use LLM_KV
                 if (kid >= 0) {
                     ftype = (llama_ftype) gguf_get_val_u32(meta, kid);
                 }
@@ -3925,7 +3925,9 @@ struct llama_model_loader {
                 throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
         }
 
-        GGML_ASSERT(arr_info.length <= N_MAX);
+        if (arr_info.length > N_MAX) {
+            throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
+        }
 
         std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
 
@@ -3961,8 +3963,6 @@ struct llama_model_loader {
     // get array of n <= N_MAX elements, or a single element repeated n times
     template<typename T, size_t N_MAX>
     bool get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, const bool required = true) {
-        GGML_ASSERT(n <= N_MAX);
-
         const int kid = gguf_find_key(meta, key.c_str());
 
         if (kid < 0) {
@@ -3972,6 +3972,10 @@ struct llama_model_loader {
             return false;
         }
 
+        if (n > N_MAX) {
+            throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str()));
+        }
+
         if (gguf_get_kv_type(meta, kid) == GGUF_TYPE_ARRAY) {
             struct GGUFMeta::ArrayInfo arr_info =
                 GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
@@ -4441,40 +4445,39 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
     }
 
     switch (ftype) {
-        case LLAMA_FTYPE_ALL_F32:     return "all F32";
-        case LLAMA_FTYPE_MOSTLY_F16:  return "F16";
-        case LLAMA_FTYPE_MOSTLY_BF16: return "BF16";
-        case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0";
-        case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1";
-        case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
-                                      return "Q4_1, some F16";
-        case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0";
-        case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1";
-        case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0";
-
-        // K-quants
-        case LLAMA_FTYPE_MOSTLY_Q2_K:   return "Q2_K - Medium";
-        case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small";
-        case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small";
-        case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium";
-        case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large";
-        case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small";
-        case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium";
-        case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small";
-        case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium";
-        case LLAMA_FTYPE_MOSTLY_Q6_K:   return "Q6_K";
-        case LLAMA_FTYPE_MOSTLY_IQ2_XXS:return "IQ2_XXS - 2.0625 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ2_S:  return "IQ2_S - 2.5 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ2_M:  return "IQ2_M - 2.7 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ3_XXS:return "IQ3_XXS - 3.0625 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ1_S  :return "IQ1_S - 1.5625 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ1_M  :return "IQ1_M - 1.75 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ3_S:  return "IQ3_S - 3.4375 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ3_M:  return "IQ3_S mix - 3.66 bpw";
+        case LLAMA_FTYPE_ALL_F32:         return "all F32";
+        case LLAMA_FTYPE_MOSTLY_F16:      return "F16";
+        case LLAMA_FTYPE_MOSTLY_BF16:     return "BF16";
+        case LLAMA_FTYPE_MOSTLY_Q4_0:     return "Q4_0";
+        case LLAMA_FTYPE_MOSTLY_Q4_1:     return "Q4_1";
+        case LLAMA_FTYPE_MOSTLY_Q5_0:     return "Q5_0";
+        case LLAMA_FTYPE_MOSTLY_Q5_1:     return "Q5_1";
+        case LLAMA_FTYPE_MOSTLY_Q8_0:     return "Q8_0";
+        case LLAMA_FTYPE_MOSTLY_Q2_K:     return "Q2_K - Medium";
+        case LLAMA_FTYPE_MOSTLY_Q2_K_S:   return "Q2_K - Small";
+        case LLAMA_FTYPE_MOSTLY_Q3_K_S:   return "Q3_K - Small";
+        case LLAMA_FTYPE_MOSTLY_Q3_K_M:   return "Q3_K - Medium";
+        case LLAMA_FTYPE_MOSTLY_Q3_K_L:   return "Q3_K - Large";
+        case LLAMA_FTYPE_MOSTLY_Q4_K_S:   return "Q4_K - Small";
+        case LLAMA_FTYPE_MOSTLY_Q4_K_M:   return "Q4_K - Medium";
+        case LLAMA_FTYPE_MOSTLY_Q5_K_S:   return "Q5_K - Small";
+        case LLAMA_FTYPE_MOSTLY_Q5_K_M:   return "Q5_K - Medium";
+        case LLAMA_FTYPE_MOSTLY_Q6_K:     return "Q6_K";
+        case LLAMA_FTYPE_MOSTLY_IQ2_XXS:  return "IQ2_XXS - 2.0625 bpw";
+        case LLAMA_FTYPE_MOSTLY_IQ2_XS:   return "IQ2_XS - 2.3125 bpw";
+        case LLAMA_FTYPE_MOSTLY_IQ2_S:    return "IQ2_S - 2.5 bpw";
+        case LLAMA_FTYPE_MOSTLY_IQ2_M:    return "IQ2_M - 2.7 bpw";
+        case LLAMA_FTYPE_MOSTLY_IQ3_XS:   return "IQ3_XS - 3.3 bpw";
+        case LLAMA_FTYPE_MOSTLY_IQ3_XXS:  return "IQ3_XXS - 3.0625 bpw";
+        case LLAMA_FTYPE_MOSTLY_IQ1_S:    return "IQ1_S - 1.5625 bpw";
+        case LLAMA_FTYPE_MOSTLY_IQ1_M:    return "IQ1_M - 1.75 bpw";
+        case LLAMA_FTYPE_MOSTLY_IQ4_NL:   return "IQ4_NL - 4.5 bpw";
+        case LLAMA_FTYPE_MOSTLY_IQ4_XS:   return "IQ4_XS - 4.25 bpw";
+        case LLAMA_FTYPE_MOSTLY_IQ3_S:    return "IQ3_S - 3.4375 bpw";
+        case LLAMA_FTYPE_MOSTLY_IQ3_M:    return "IQ3_S mix - 3.66 bpw";
+        case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4";
+        case LLAMA_FTYPE_MOSTLY_Q4_0_4_8: return "Q4_0_4_8";
+        case LLAMA_FTYPE_MOSTLY_Q4_0_8_8: return "Q4_0_8_8";
 
         default: return "unknown, may not work";
     }
@@ -4889,6 +4892,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_PHI3:
             {
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
                 switch (hparams.n_layer) {
@@ -4922,7 +4926,7 @@ static void llm_load_hparams(
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 switch (hparams.n_layer) {
-                    case 42: model.type = e_model::MODEL_SMALL; break;
+                    case 42: model.type = e_model::MODEL_7B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
@@ -4964,6 +4968,7 @@ static void llm_load_hparams(
                 hparams.attn_soft_cap = true;
 
                 switch (hparams.n_layer) {
+                    case 26: model.type = e_model::MODEL_2B; break;
                     case 42: model.type = e_model::MODEL_9B; break;
                     case 46: model.type = e_model::MODEL_27B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
@@ -5217,12 +5222,6 @@ static void llm_load_hparams(
     hparams.rope_type = llama_rope_type(&model);
 }
 
-// TODO: This should probably be in llama.h
-static std::vector<llama_vocab::id> llama_tokenize_internal(
-    const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special = false
-);
-static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch);
-
 static void llm_load_vocab(
         llama_model_loader & ml,
         llama_model & model) {
@@ -5284,6 +5283,7 @@ static void llm_load_vocab(
             if (merges_keyidx == -1) {
                 throw std::runtime_error("cannot find tokenizer merges in model file\n");
             }
+
             const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
             for (int i = 0; i < n_merges; i++) {
                 const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
@@ -5322,16 +5322,6 @@ static void llm_load_vocab(
             vocab.special_cls_id  = -1;
             vocab.special_mask_id = -1;
 
-            const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
-            if (add_space_prefix_keyidx != -1) {
-                vocab.tokenizer_add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
-            } // The default value of add_space_prefix is true.
-
-            const int remove_extra_whitespaces_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS).c_str());
-            if (remove_extra_whitespaces_keyidx != -1) {
-                vocab.tokenizer_remove_extra_whitespaces = gguf_get_val_bool(ctx, remove_extra_whitespaces_keyidx);
-            } // The default value of remove_extra_whitespaces is false.
-
             const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
             if (precompiled_charsmap_keyidx != -1) {
                 size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
@@ -5407,6 +5397,7 @@ static void llm_load_vocab(
             } else if (
                 tokenizer_pre == "command-r") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R;
+                vocab.tokenizer_clean_spaces = false;
             } else if (
                 tokenizer_pre == "qwen2") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2;
@@ -5438,6 +5429,19 @@ static void llm_load_vocab(
             } else if (
                 tokenizer_pre == "jais") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
+            } else if (
+                tokenizer_pre == "tekken") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TEKKEN;
+                vocab.tokenizer_clean_spaces = false;
+                vocab.tokenizer_ignore_merges = true;
+                vocab.tokenizer_add_bos = true;
+            } else if (
+                tokenizer_pre == "smollm") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMOLLM;
+                vocab.tokenizer_clean_spaces = false;
+            } else if (
+                tokenizer_pre == "codeshell") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL;
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
@@ -5461,10 +5465,8 @@ static void llm_load_vocab(
             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
         }
 
-        const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
-        if (add_space_prefix_keyidx != -1) {
-            vocab.tokenizer_add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
-        }
+        ml.get_key(LLM_KV_TOKENIZER_ADD_PREFIX,      vocab.tokenizer_add_space_prefix,         false);
+        ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.tokenizer_remove_extra_whitespaces, false);
     }
 
     const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
@@ -5556,7 +5558,7 @@ static void llm_load_vocab(
             }
         }
         try {
-            vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
+            vocab.linefeed_id = llama_byte_to_token_impl(vocab, '\n');
         } catch (const std::exception & e) {
             LLAMA_LOG_WARN("%s: SPM vocabulary, but newline token not found: %s! Using special_pad_id instead.", __func__, e.what());
             vocab.linefeed_id = vocab.special_pad_id;
@@ -5583,6 +5585,7 @@ static void llm_load_vocab(
             { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id },
             { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
             { LLM_KV_TOKENIZER_EOT_ID,    vocab.special_eot_id    },
+            { LLM_KV_TOKENIZER_EOM_ID,    vocab.special_eom_id    },
         };
 
         for (const auto & it : special_token_types) {
@@ -5635,12 +5638,23 @@ static void llm_load_vocab(
                 }
             }
         }
+
+        // find EOM token: "<|eom_id|>"
+        //
+        // TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOM_ID
+        //       for now, we apply this workaround to find the EOM token based on its text
+        if (vocab.special_eom_id == -1) {
+            const auto & t = vocab.token_to_id.find("<|eom_id|>");
+            if (t != vocab.token_to_id.end()) {
+                vocab.special_eom_id = t->second;
+            }
+        }
     }
 
     // build special tokens cache
     {
         for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
-            if (!(vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL)) {
+            if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
                 vocab.cache_special_tokens.push_back(id);
             }
         }
@@ -5871,13 +5885,6 @@ static bool llm_load_tensors(
 
     auto & hparams = model.hparams;
 
-#ifdef GGML_USE_SYCL
-    // disable MoE with SYCL until mul_mat_id is updated
-    if (hparams.n_expert > 0) {
-        n_gpu_layers = 0;
-    }
-#endif
-
     model.split_mode   = split_mode;
     model.main_gpu     = main_gpu;
     model.n_gpu_layers = n_gpu_layers;
@@ -6052,10 +6059,10 @@ static bool llm_load_tensors(
 
                         layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
 
                         // optional bias tensors
                         layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
@@ -6065,6 +6072,8 @@ static bool llm_load_tensors(
 
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
 
+                        layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+
                         if (n_expert == 0) {
                             layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
                             layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
@@ -7803,6 +7812,58 @@ static void llm_build_kv_store(
     ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
 }
 
+// do mat_mul, while optionally apply lora
+static struct ggml_tensor * llm_build_lora_mm(
+        struct llama_context & lctx,
+         struct ggml_context * ctx0,
+          struct ggml_tensor * w,
+          struct ggml_tensor * cur) {
+    struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
+    for (auto & it : lctx.lora_adapters) {
+        struct llama_lora_weight * lora = it.first->get_weight(w);
+        if (lora == nullptr) {
+            continue;
+        }
+        const float alpha = it.first->alpha;
+        const float rank  = (float) lora->b->ne[0];
+        const float scale = alpha ? it.second * alpha / rank : it.second;
+        struct ggml_tensor * ab_cur = ggml_mul_mat(
+            ctx0, lora->b,
+            ggml_mul_mat(ctx0, lora->a, cur)
+        );
+        ab_cur = ggml_scale(ctx0, ab_cur, scale);
+        res = ggml_add(ctx0, res, ab_cur);
+    }
+    return res;
+}
+
+// do mat_mul_id, while optionally apply lora
+static struct ggml_tensor * llm_build_lora_mm_id(
+        struct llama_context & lctx,
+         struct ggml_context * ctx0,
+          struct ggml_tensor * w,   // struct ggml_tensor * as
+          struct ggml_tensor * cur, // struct ggml_tensor * b
+          struct ggml_tensor * ids) {
+    struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
+    for (auto & it : lctx.lora_adapters) {
+        struct llama_lora_weight * lora = it.first->get_weight(w);
+        if (lora == nullptr) {
+            continue;
+        }
+        const float alpha = it.first->alpha;
+        const float rank  = (float) lora->b->ne[0];
+        const float scale = alpha ? it.second * alpha / rank : it.second;
+        struct ggml_tensor * ab_cur = ggml_mul_mat_id(
+            ctx0, lora->b,
+            ggml_mul_mat_id(ctx0, lora->a, cur, ids),
+            ids
+        );
+        ab_cur = ggml_scale(ctx0, ab_cur, scale);
+        res = ggml_add(ctx0, res, ab_cur);
+    }
+    return res;
+}
+
 static struct ggml_tensor * llm_build_norm(
         struct ggml_context * ctx,
          struct ggml_tensor * cur,
@@ -7837,6 +7898,7 @@ static struct ggml_tensor * llm_build_norm(
 
 static struct ggml_tensor * llm_build_ffn(
         struct ggml_context * ctx,
+       struct llama_context & lctx,
          struct ggml_tensor * cur,
          struct ggml_tensor * up,
          struct ggml_tensor * up_b,
@@ -7852,7 +7914,7 @@ static struct ggml_tensor * llm_build_ffn(
           llm_ffn_gate_type   type_gate,
          const llm_build_cb & cb,
                         int   il) {
-    struct ggml_tensor * tmp = up ? ggml_mul_mat(ctx, up, cur) : cur;
+    struct ggml_tensor * tmp = up ? llm_build_lora_mm(lctx, ctx, up, cur) : cur;
     cb(tmp, "ffn_up", il);
 
     if (up_b) {
@@ -7869,12 +7931,12 @@ static struct ggml_tensor * llm_build_ffn(
         switch (type_gate) {
             case LLM_FFN_SEQ:
                 {
-                    cur = ggml_mul_mat(ctx, gate, tmp);
+                    cur = llm_build_lora_mm(lctx, ctx, gate, tmp);
                     cb(cur, "ffn_gate", il);
                 } break;
             case LLM_FFN_PAR:
                 {
-                    cur = ggml_mul_mat(ctx, gate, cur);
+                    cur = llm_build_lora_mm(lctx, ctx, gate, cur);
                     cb(cur, "ffn_gate", il);
                 } break;
         }
@@ -7942,7 +8004,7 @@ static struct ggml_tensor * llm_build_ffn(
     }
 
     if (down) {
-        cur = ggml_mul_mat(ctx, down, cur);
+        cur = llm_build_lora_mm(lctx, ctx, down, cur);
     }
 
     if (down_b) {
@@ -7963,6 +8025,7 @@ static struct ggml_tensor * llm_build_ffn(
 
 static struct ggml_tensor * llm_build_moe_ffn(
         struct ggml_context * ctx,
+       struct llama_context & lctx,
          struct ggml_tensor * cur,
          struct ggml_tensor * gate_inp,
          struct ggml_tensor * up_exps,
@@ -7979,7 +8042,7 @@ static struct ggml_tensor * llm_build_moe_ffn(
     int64_t n_embd = cur->ne[0];
     int64_t n_tokens = cur->ne[1];
 
-    ggml_tensor * logits = ggml_mul_mat(ctx, gate_inp, cur); // [n_expert, n_tokens]
+    ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
     cb(logits, "ffn_moe_logits", il);
 
     ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
@@ -8011,10 +8074,10 @@ static struct ggml_tensor * llm_build_moe_ffn(
     }
 
     cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
-    ggml_tensor * up = ggml_mul_mat_id(ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(up, "ffn_moe_up", il);
 
-    ggml_tensor * gate = ggml_mul_mat_id(ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(gate, "ffn_moe_gate", il);
 
     switch (type_op) {
@@ -8029,13 +8092,13 @@ static struct ggml_tensor * llm_build_moe_ffn(
                 cb(gate, "ffn_moe_gelu", il);
             } break;
         default:
-            GGML_ASSERT(false);
+            GGML_ABORT("fatal error");
     }
 
     ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens]
     cb(par, "ffn_moe_gate_par", il);
 
-    ggml_tensor * experts = ggml_mul_mat_id(ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
     experts = ggml_mul(ctx, experts, weights);
@@ -8063,9 +8126,7 @@ static struct ggml_tensor * llm_build_moe_ffn(
 
 static struct ggml_tensor * llm_build_kqv(
         struct ggml_context * ctx,
-          const llama_model & model,
-        const llama_hparams & hparams,
-        const llama_cparams & cparams,
+       struct llama_context & lctx,
        const llama_kv_cache & kv,
          struct ggml_cgraph * graph,
          struct ggml_tensor * wo,
@@ -8077,6 +8138,10 @@ static struct ggml_tensor * llm_build_kqv(
                     float     kq_scale,
          const llm_build_cb & cb,
                     int       il) {
+    const llama_model   & model   = lctx.model;
+    const llama_hparams & hparams = lctx.model.hparams;
+    const llama_cparams & cparams = lctx.cparams;
+
     const int64_t n_ctx         = cparams.n_ctx;
     const int64_t n_head        = hparams.n_head(il);
     const int64_t n_head_kv     = hparams.n_head_kv(il);
@@ -8122,7 +8187,7 @@ static struct ggml_tensor * llm_build_kqv(
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
             // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
             // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
             ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
@@ -8175,7 +8240,7 @@ static struct ggml_tensor * llm_build_kqv(
     ggml_build_forward_expand(graph, cur);
 
     if (wo) {
-        cur = ggml_mul_mat(ctx, wo, cur);
+        cur = llm_build_lora_mm(lctx, ctx, wo, cur);
     }
 
     if (wo_b) {
@@ -8191,9 +8256,7 @@ static struct ggml_tensor * llm_build_kqv(
 
 static struct ggml_tensor * llm_build_kv(
         struct ggml_context * ctx,
-          const llama_model & model,
-        const llama_hparams & hparams,
-        const llama_cparams & cparams,
+       struct llama_context & lctx,
        const llama_kv_cache & kv,
          struct ggml_cgraph * graph,
          struct ggml_tensor * wo,
@@ -8208,6 +8271,8 @@ static struct ggml_tensor * llm_build_kv(
                     float     kq_scale,
          const llm_build_cb & cb,
                     int       il) {
+    const llama_hparams & hparams = lctx.model.hparams;
+    const llama_cparams & cparams = lctx.cparams;
 
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
@@ -8219,7 +8284,7 @@ static struct ggml_tensor * llm_build_kv(
 
     struct ggml_tensor * cur;
 
-    cur  = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b,
+    cur  = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b,
             q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
     cb(cur, "kqv_out", il);
 
@@ -8354,7 +8419,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_k_shift() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         GGML_ASSERT(kv_self.size == n_ctx);
 
@@ -8385,7 +8450,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_s_copy() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         GGML_ASSERT(kv_self.recurrent);
 
@@ -8408,7 +8473,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         for (uint32_t i = 0; i < ids.size(); ++i) {
             const uint32_t id = ids[i];
@@ -8486,6 +8551,10 @@ struct llm_build_context {
         // choose long/short freq factors based on the context size
         const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
 
+        if (model.layers[il].rope_freqs != nullptr) {
+            return model.layers[il].rope_freqs;
+        }
+
         if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) {
             return model.layers[il].rope_long;
         }
@@ -8590,8 +8659,8 @@ struct llm_build_context {
                 } break;
             default:
                 {
-                    GGML_ASSERT(false && "unknown pooling type");
-                } break;
+                    GGML_ABORT("unknown pooling type");
+                }
         }
 
         cb(cur, "result_embd_pooled", -1);
@@ -8649,7 +8718,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_llama() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -8680,22 +8749,25 @@ struct llm_build_context {
 
             // self-attention
             {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                struct ggml_tensor * rope_factors = build_rope_factors(il);
+
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -8703,20 +8775,20 @@ struct llm_build_context {
                 }
 
                 Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -8739,7 +8811,7 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
                         model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
@@ -8753,7 +8825,7 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_moe_ffn(ctx0, cur,
+                cur = llm_build_moe_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_gate_inp,
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
@@ -8783,7 +8855,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -8792,7 +8864,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_baichuan() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8819,13 +8891,13 @@ struct llm_build_context {
 
             // self-attention
             {
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 switch (model.type) {
@@ -8846,12 +8918,12 @@ struct llm_build_context {
                         Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens);
                         break;
                     default:
-                        GGML_ASSERT(false);
+                        GGML_ABORT("fatal error");
                 }
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -8873,7 +8945,7 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         model.layers[il].ffn_gate, NULL, NULL,
                         model.layers[il].ffn_down, NULL, NULL,
@@ -8898,7 +8970,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -8907,7 +8979,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_xverse() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8934,13 +9006,13 @@ struct llm_build_context {
 
             // self-attention
             {
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 Qcur = ggml_rope_ext(
@@ -8956,7 +9028,7 @@ struct llm_build_context {
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -8978,7 +9050,7 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         model.layers[il].ffn_gate, NULL, NULL,
                         model.layers[il].ffn_down, NULL, NULL,
@@ -9001,7 +9073,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -9010,7 +9082,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_falcon() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -9050,7 +9122,7 @@ struct llm_build_context {
                     cur = attn_norm;
                 }
 
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
@@ -9077,7 +9149,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -9094,7 +9166,7 @@ struct llm_build_context {
 
             // feed forward
             {
-                cur = llm_build_ffn(ctx0, attn_norm, // !! use the attn norm, not the result
+                cur = llm_build_ffn(ctx0, lctx, attn_norm, // !! use the attn norm, not the result
                         model.layers[il].ffn_up,   NULL, NULL,
                         NULL,                      NULL, NULL,
                         model.layers[il].ffn_down, NULL, NULL,
@@ -9121,7 +9193,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -9130,7 +9202,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_grok() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -9166,21 +9238,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -9201,7 +9273,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
@@ -9233,7 +9305,7 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_moe_ffn(ctx0, cur,
+            cur = llm_build_moe_ffn(ctx0, lctx, cur,
                     model.layers[il].ffn_gate_inp,
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
@@ -9272,7 +9344,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
 
         // Grok
         // multiply logits by output_multiplier_scale of 0.5773502691896257
@@ -9287,7 +9359,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_dbrx() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -9323,7 +9395,7 @@ struct llm_build_context {
                 struct ggml_tensor * Kcur = nullptr;
                 struct ggml_tensor * Vcur = nullptr;
 
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
@@ -9351,7 +9423,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -9374,7 +9446,7 @@ struct llm_build_context {
                                  LLM_NORM, cb, il);
             cb(cur, "attn_out_norm", il);
 
-            cur = llm_build_moe_ffn(ctx0, cur,
+            cur = llm_build_moe_ffn(ctx0, lctx, cur,
                     model.layers[il].ffn_gate_inp,
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
@@ -9403,7 +9475,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
 
         cb(cur, "result_output", -1);
 
@@ -9413,7 +9485,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_starcoder() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -9445,7 +9517,7 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -9461,7 +9533,7 @@ struct llm_build_context {
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -9485,7 +9557,7 @@ struct llm_build_context {
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
                         NULL,                      NULL,                        NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
@@ -9508,7 +9580,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -9517,7 +9589,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_refact() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -9540,13 +9612,13 @@ struct llm_build_context {
 
             // self-attention
             {
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
@@ -9555,7 +9627,7 @@ struct llm_build_context {
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 cb(Qcur, "Qcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -9577,7 +9649,7 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         model.layers[il].ffn_gate, NULL, NULL,
                         model.layers[il].ffn_down, NULL, NULL,
@@ -9602,7 +9674,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -9611,7 +9683,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_bert() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -9654,7 +9726,7 @@ struct llm_build_context {
 
             // self-attention
             if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) {
-                Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq);
+                Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur), model.layers[il].bq);
                 cb(Qcur, "Qcur", il);
 
                 if (model.layers[il].attn_q_norm) {
@@ -9664,7 +9736,7 @@ struct llm_build_context {
                             LLM_NORM, cb, il);
                 }
 
-                Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk);
+                Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur), model.layers[il].bk);
                 cb(Kcur, "Kcur", il);
 
                 if (model.layers[il].attn_k_norm) {
@@ -9673,14 +9745,14 @@ struct llm_build_context {
                             model.layers[il].attn_k_norm_b,
                             LLM_NORM, cb, il);
                 }
-                Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv);
+                Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur), model.layers[il].bv);
                 cb(Vcur, "Vcur", il);
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
             } else {
                 // compute Q and K and RoPE them
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
@@ -9729,7 +9801,7 @@ struct llm_build_context {
 
             ggml_build_forward_expand(gf, cur);
 
-            cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
+            cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
             if (model.layers[il].bo) {
                 cb(cur, "kqv_wo", il);
             }
@@ -9762,21 +9834,21 @@ struct llm_build_context {
 
             // feed-forward network
             if (model.arch == LLM_ARCH_BERT) {
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
                         NULL,                      NULL,                        NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
             } else if (model.arch == LLM_ARCH_JINA_BERT_V2) {
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL,                        NULL,
                         model.layers[il].ffn_gate, NULL,                        NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
             } else {
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         model.layers[il].ffn_gate, NULL, NULL,
                         model.layers[il].ffn_down, NULL, NULL,
@@ -9805,7 +9877,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_bloom() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -9834,7 +9906,7 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -9850,7 +9922,7 @@ struct llm_build_context {
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -9874,7 +9946,7 @@ struct llm_build_context {
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
                         NULL,                      NULL,                        NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
@@ -9897,7 +9969,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -9906,7 +9978,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_mpt() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -9944,7 +10016,7 @@ struct llm_build_context {
             {
                 cur = attn_norm;
 
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 if (model.layers[il].bqkv){
@@ -9982,13 +10054,13 @@ struct llm_build_context {
                     Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
-                    cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                    cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                             model.layers[il].wo, model.layers[il].bo,
                             Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 } else {
                     Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                    cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                    cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                             model.layers[il].wo, model.layers[il].bo,
                             Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 }
@@ -10012,7 +10084,7 @@ struct llm_build_context {
                         model.layers[il].ffn_norm_b,
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
                         NULL,                      NULL,                        NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
@@ -10037,7 +10109,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -10077,21 +10149,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -10133,7 +10205,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -10161,7 +10233,7 @@ struct llm_build_context {
                     // parallel residual
                     cur = inpSA;
                 }
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         model.layers[il].ffn_gate, NULL, NULL,
                         model.layers[il].ffn_down, NULL, NULL,
@@ -10187,7 +10259,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -10196,7 +10268,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10222,7 +10294,7 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -10252,7 +10324,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -10274,7 +10346,7 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         model.layers[il].ffn_gate, NULL, NULL,
                         model.layers[il].ffn_down, NULL, NULL,
@@ -10299,7 +10371,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -10308,7 +10380,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10337,17 +10409,17 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                 cb(Vcur, "Vcur", il);
@@ -10366,7 +10438,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -10387,7 +10459,7 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, cur,
+            cur = llm_build_ffn(ctx0, lctx, cur,
                     model.layers[il].ffn_up,   NULL, NULL,
                     model.layers[il].ffn_gate, NULL, NULL,
                     model.layers[il].ffn_down, NULL, NULL,
@@ -10411,7 +10483,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -10420,7 +10492,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen2moe() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -10452,17 +10524,17 @@ struct llm_build_context {
             // self_attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                 cb(Vcur, "Vcur", il);
@@ -10481,7 +10553,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -10504,7 +10576,7 @@ struct llm_build_context {
             cb(cur, "ffn_norm", il);
 
             ggml_tensor * moe_out =
-                    llm_build_moe_ffn(ctx0, cur,
+                    llm_build_moe_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_gate_inp,
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
@@ -10517,14 +10589,14 @@ struct llm_build_context {
 
             // FFN shared expert
             {
-                ggml_tensor * cur_gate_inp = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp_shexp, cur);
+                ggml_tensor * cur_gate_inp = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_gate_inp_shexp, cur);
                 cb(cur_gate_inp, "ffn_shexp_gate_inp", il);
 
                 // sigmoid
                 ggml_tensor * cur_gate = ggml_div(ctx0, ggml_silu(ctx0, cur_gate_inp), cur_gate_inp);
                 cb(cur_gate, "ffn_shexp_gate", il);
 
-                ggml_tensor * cur_ffn = llm_build_ffn(ctx0, cur,
+                ggml_tensor * cur_ffn = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up_shexp,   NULL, NULL,
                         model.layers[il].ffn_gate_shexp, NULL, NULL,
                         model.layers[il].ffn_down_shexp, NULL, NULL,
@@ -10557,7 +10629,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -10566,7 +10638,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_phi2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -10599,7 +10671,7 @@ struct llm_build_context {
                 struct ggml_tensor * Vcur = nullptr;
 
                 if (model.layers[il].wqkv) {
-                    cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
+                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, attn_norm_output);
                     cb(cur, "wqkv", il);
 
                     cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -10609,9 +10681,9 @@ struct llm_build_context {
                     Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
                     Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
                 } else {
-                    Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
-                    Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
-                    Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
+                    Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
+                    Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
+                    Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
                 }
 
                 cb(Qcur, "Qcur", il);
@@ -10638,7 +10710,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
@@ -10653,7 +10725,7 @@ struct llm_build_context {
 
             // FF
             {
-                ffn_output = llm_build_ffn(ctx0, attn_norm_output,
+                ffn_output = llm_build_ffn(ctx0, lctx, attn_norm_output,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
                         NULL,                      NULL,                        NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
@@ -10677,7 +10749,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output_no_bias", -1);
 
         cur = ggml_add(ctx0, cur, model.output_b);
@@ -10687,7 +10759,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_phi3() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -10702,7 +10774,7 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
 
         for (int il = 0; il < n_layer; ++il) {
             auto residual = inpL;
@@ -10723,7 +10795,7 @@ struct llm_build_context {
                 struct ggml_tensor * Vcur = nullptr;
 
                 if (model.layers[il].wqkv) {
-                    cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
+                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, attn_norm_output);
                     cb(cur, "wqkv", il);
 
                     Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
@@ -10731,9 +10803,9 @@ struct llm_build_context {
                     Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
                 }
                 else {
-                    Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
-                    Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
-                    Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
+                    Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
+                    Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
+                    Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
                 }
 
                 cb(Qcur, "Qcur", il);
@@ -10758,9 +10830,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask_swa, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -10782,7 +10854,7 @@ struct llm_build_context {
             // special-case: the up and gate tensors are merged into a single tensor
             // TOOD: support into llm_build_ffn
             {
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         NULL,                      NULL, NULL,
                         model.layers[il].ffn_down, NULL, NULL,
@@ -10805,7 +10877,7 @@ struct llm_build_context {
             LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -10845,13 +10917,13 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 Qcur = ggml_rope_ext(
@@ -10866,7 +10938,7 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -10884,7 +10956,7 @@ struct llm_build_context {
 
             // feed-forward network
             {
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         model.layers[il].ffn_gate, NULL, NULL,
                         model.layers[il].ffn_down, NULL, NULL,
@@ -10910,7 +10982,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -10919,7 +10991,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gpt2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -10952,7 +11024,7 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -10968,7 +11040,7 @@ struct llm_build_context {
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -10992,7 +11064,7 @@ struct llm_build_context {
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
                         NULL,                      NULL,                        NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
@@ -11015,7 +11087,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -11024,7 +11096,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_codeshell() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -11051,7 +11123,7 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -11079,7 +11151,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -11103,7 +11175,7 @@ struct llm_build_context {
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
                         NULL,                      NULL,                        NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
@@ -11126,7 +11198,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -11135,7 +11207,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_orion() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11164,21 +11236,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 // if (model.layers[il].bq) {
                 //     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                 //     cb(Qcur, "Qcur", il);
                 // }
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 // if (model.layers[il].bk) {
                 //     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                 //     cb(Kcur, "Kcur", il);
                 // }
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 // if (model.layers[il].bv) {
                 //     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -11199,7 +11271,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -11220,7 +11292,7 @@ struct llm_build_context {
                     LLM_NORM, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, cur,
+            cur = llm_build_ffn(ctx0, lctx, cur,
                     model.layers[il].ffn_up,   NULL, NULL,
                     model.layers[il].ffn_gate, NULL, NULL,
                     model.layers[il].ffn_down, NULL, NULL,
@@ -11244,7 +11316,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -11253,7 +11325,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_internlm2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11282,21 +11354,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -11317,7 +11389,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -11338,7 +11410,7 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, cur,
+            cur = llm_build_ffn(ctx0, lctx, cur,
                     model.layers[il].ffn_up,   NULL, NULL,
                     model.layers[il].ffn_gate, NULL, NULL,
                     model.layers[il].ffn_down, NULL, NULL,
@@ -11362,7 +11434,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -11374,7 +11446,7 @@ struct llm_build_context {
     //      https://github.com/ggerganov/llama.cpp/issues/5276#issuecomment-1925774738
     // based on the original build_llama() function
     struct ggml_cgraph * build_minicpm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11413,21 +11485,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -11448,7 +11520,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -11475,7 +11547,7 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         model.layers[il].ffn_gate, NULL, NULL,
                         model.layers[il].ffn_down, NULL, NULL,
@@ -11509,7 +11581,7 @@ struct llm_build_context {
         cb(cur, "lmhead_scaling", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -11518,7 +11590,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gemma() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head_k = hparams.n_embd_head_k;
 
@@ -11546,13 +11618,13 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 Qcur = ggml_rope_ext(
@@ -11570,7 +11642,7 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
@@ -11592,7 +11664,7 @@ struct llm_build_context {
 
             // feed-forward network
             {
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         model.layers[il].ffn_gate, NULL, NULL,
                         model.layers[il].ffn_down, NULL, NULL,
@@ -11617,7 +11689,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -11626,7 +11698,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gemma2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head_k = hparams.n_embd_head_k;
 
@@ -11659,13 +11731,13 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 Qcur = ggml_rope_ext(
@@ -11674,7 +11746,13 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head)));
+                // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
+                switch (model.type) {
+                    case e_model::MODEL_2B:
+                    case e_model::MODEL_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
+                    case e_model::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
+                    default: GGML_ABORT("fatal error");
+                };
                 cb(Qcur, "Qcur_scaled", il);
 
                 Kcur = ggml_rope_ext(
@@ -11683,7 +11761,7 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
@@ -11710,7 +11788,7 @@ struct llm_build_context {
 
             // feed-forward network
             {
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         model.layers[il].ffn_gate, NULL, NULL,
                         model.layers[il].ffn_down, NULL, NULL,
@@ -11740,7 +11818,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
 
         // final logit soft-capping
         cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
@@ -11756,7 +11834,7 @@ struct llm_build_context {
 
 
     struct ggml_cgraph * build_starcoder2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11785,21 +11863,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -11820,7 +11898,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -11842,7 +11920,7 @@ struct llm_build_context {
                     LLM_NORM, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, cur,
+            cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
                         NULL,                      NULL,                        NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
@@ -11866,7 +11944,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -11875,7 +11953,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_mamba() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t d_model = n_embd;
         const int64_t d_conv  = hparams.ssm_d_conv;
@@ -11918,7 +11996,7 @@ struct llm_build_context {
             cb(cur, "attn_norm", il);
 
             // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens}
-            struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur);
+            struct ggml_tensor * xz = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_in, cur);
             // split the above in two
             // => {d_inner, n_tokens}
             struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
@@ -11957,14 +12035,14 @@ struct llm_build_context {
             // ssm
             {
                 // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens}
-                struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
+                struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_x, x);
                 // split
                 struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
                 struct ggml_tensor * B  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
                 struct ggml_tensor * C  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
 
                 // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
-                dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
+                dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt);
                 dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
 
                 // Custom operator to optimize the parallel associative scan
@@ -11995,7 +12073,7 @@ struct llm_build_context {
                 y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
 
                 // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens}
-                cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_out, y);
             }
 
             // residual
@@ -12014,7 +12092,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -12024,7 +12102,7 @@ struct llm_build_context {
 
     struct ggml_cgraph * build_command_r() {
 
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -12053,21 +12131,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -12113,7 +12191,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -12130,7 +12208,7 @@ struct llm_build_context {
 
             // feed-forward network
             {
-                cur = llm_build_ffn(ctx0, ffn_inp,
+                cur = llm_build_ffn(ctx0, lctx, ffn_inp,
                         model.layers[il].ffn_up,   NULL, NULL,
                         model.layers[il].ffn_gate, NULL, NULL,
                         model.layers[il].ffn_down, NULL, NULL,
@@ -12157,7 +12235,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
 
         if (f_logit_scale) {
             cur = ggml_scale(ctx0, cur, f_logit_scale);
@@ -12178,7 +12256,7 @@ struct llm_build_context {
     //   * removed bias
     //   * removed MoE
     struct ggml_cgraph * build_olmo() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -12210,21 +12288,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (hparams.f_clamp_kqv > 0.0f) {
                     Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (hparams.f_clamp_kqv > 0.0f) {
                     Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (hparams.f_clamp_kqv > 0.0f) {
                     Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
@@ -12245,7 +12323,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, nullptr,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -12267,7 +12345,7 @@ struct llm_build_context {
                     LLM_NORM, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, cur,
+            cur = llm_build_ffn(ctx0, lctx, cur,
                     model.layers[il].ffn_up,   NULL, NULL,
                     model.layers[il].ffn_gate, NULL, NULL,
                     model.layers[il].ffn_down, NULL, NULL,
@@ -12293,7 +12371,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -12302,7 +12380,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_openelm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -12333,7 +12411,7 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens);
@@ -12372,7 +12450,7 @@ struct llm_build_context {
                 Vcur = ggml_reshape_2d(ctx0, Vcur, n_embd_head * n_head_kv, n_tokens);
                 cb(Qcur, "Vcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -12394,7 +12472,7 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         model.layers[il].ffn_gate, NULL, NULL,
                         model.layers[il].ffn_down, NULL, NULL,
@@ -12418,7 +12496,7 @@ struct llm_build_context {
                 LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -12427,7 +12505,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gptneox() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -12453,7 +12531,7 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -12481,7 +12559,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -12506,7 +12584,7 @@ struct llm_build_context {
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
                         NULL,                      NULL,                        NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
@@ -12537,7 +12615,7 @@ struct llm_build_context {
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
                         NULL,                      NULL,                        NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
@@ -12560,7 +12638,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -12569,7 +12647,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_arctic() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -12601,13 +12679,13 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 Qcur = ggml_rope_ext(
@@ -12624,7 +12702,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -12646,7 +12724,7 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, cur,
+            cur = llm_build_ffn(ctx0, lctx, cur,
                     model.layers[il].ffn_up,   NULL, NULL,
                     model.layers[il].ffn_gate, NULL, NULL,
                     model.layers[il].ffn_down, NULL, NULL,
@@ -12663,7 +12741,7 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm_exps", il);
 
-            cur = llm_build_moe_ffn(ctx0, cur,
+            cur = llm_build_moe_ffn(ctx0, lctx, cur,
                     model.layers[il].ffn_gate_inp,
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
@@ -12692,7 +12770,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -12701,7 +12779,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_deepseek2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -12846,7 +12924,7 @@ struct llm_build_context {
                 struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
                 cb(k_states, "k_states", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
             }
@@ -12862,13 +12940,13 @@ struct llm_build_context {
             struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
 
-            if ((uint32_t) il < hparams.n_layer_dense_lead) {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
+            if ((uint32_t) il < hparams.n_layer_dense_lead) {
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         model.layers[il].ffn_gate, NULL, NULL,
                         model.layers[il].ffn_down, NULL, NULL,
@@ -12877,13 +12955,8 @@ struct llm_build_context {
                 cb(cur, "ffn_out", il);
             } else {
                 // MoE branch
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
                 ggml_tensor * moe_out =
-                        llm_build_moe_ffn(ctx0, cur,
+                        llm_build_moe_ffn(ctx0, lctx, cur,
                             model.layers[il].ffn_gate_inp,
                             model.layers[il].ffn_up_exps,
                             model.layers[il].ffn_gate_exps,
@@ -12896,7 +12969,7 @@ struct llm_build_context {
 
                 // FFN shared expert
                 {
-                    ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, cur,
+                    ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur,
                             model.layers[il].ffn_up_shexp,   NULL, NULL,
                             model.layers[il].ffn_gate_shexp, NULL, NULL,
                             model.layers[il].ffn_down_shexp, NULL, NULL,
@@ -12934,7 +13007,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_bitnet() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -12961,7 +13034,7 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
@@ -12970,7 +13043,7 @@ struct llm_build_context {
                 }
 
                 // B1.K
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
@@ -12979,7 +13052,7 @@ struct llm_build_context {
                 }
 
                 // B1.V
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
@@ -13001,7 +13074,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         NULL, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
 
@@ -13010,7 +13083,7 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "attn_sub_norm", il);
 
-                cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
                 cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
                 if (model.layers[il].bo) {
                     cur = ggml_add(ctx0, cur, model.layers[il].bo);
@@ -13034,7 +13107,7 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, cur,
+            cur = llm_build_ffn(ctx0, lctx, cur,
                     model.layers[il].ffn_up,   NULL, model.layers[il].ffn_up_scale,
                     model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale,
                     NULL,                      NULL, NULL,
@@ -13047,7 +13120,7 @@ struct llm_build_context {
                             LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_sub_norm", il);
 
-            cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
+            cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_down, cur);
             cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
             cb(cur, "ffn_down", il);
 
@@ -13066,7 +13139,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.tok_embd, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.tok_embd, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -13074,7 +13147,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_t5() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -13168,7 +13241,7 @@ struct llm_build_context {
                     cb(cur, "ffn_norm", il);
 
                     // T5 uses relu, flan-T5 uses gelu-gated
-                    cur = llm_build_ffn(ctx0, cur,
+                    cur = llm_build_ffn(ctx0, lctx, cur,
                             model.layers[il].ffn_up_enc,   NULL, NULL,
                             model.layers[il].ffn_gate_enc, NULL, NULL,
                             model.layers[il].ffn_down_enc, NULL, NULL,
@@ -13200,6 +13273,8 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, -1);
             cb(cur, "result_norm", -1);
         } else {
+            GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first");
+
             struct ggml_tensor * embd_enc       = llm_build_inp_embd_enc();
             struct ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true);
 
@@ -13346,7 +13421,7 @@ struct llm_build_context {
                     cb(cur, "ffn_norm", il);
 
                     // T5 uses relu, flan-T5 uses gelu-gated
-                    cur = llm_build_ffn(ctx0, cur,
+                    cur = llm_build_ffn(ctx0, lctx, cur,
                             model.layers[il].ffn_up,   NULL, NULL,
                             model.layers[il].ffn_gate, NULL, NULL,
                             model.layers[il].ffn_down, NULL, NULL,
@@ -13389,7 +13464,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_jais() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -13412,7 +13487,7 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -13428,7 +13503,7 @@ struct llm_build_context {
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/float(n_embd_head), cb, il);
             }
@@ -13452,7 +13527,7 @@ struct llm_build_context {
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
                         model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
@@ -13471,7 +13546,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
 
         cb(cur, "result_output", -1);
 
@@ -13481,7 +13556,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_chatglm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -13513,7 +13588,7 @@ struct llm_build_context {
                 struct ggml_tensor * Kcur = nullptr;
                 struct ggml_tensor * Vcur = nullptr;
 
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -13541,7 +13616,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur_rope", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
 
@@ -13566,7 +13641,7 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
+                cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         NULL,                      NULL, NULL,
                         model.layers[il].ffn_down, NULL, NULL,
@@ -13586,7 +13661,7 @@ struct llm_build_context {
                 LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -13841,7 +13916,7 @@ static struct ggml_cgraph * llama_build_graph(
                 result = llm.build_jais();
             } break;
         default:
-            GGML_ASSERT(false);
+            GGML_ABORT("fatal error");
     }
 
     // add on pooling layer
@@ -13965,18 +14040,23 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         "causal attention is not supported by this model"
     );
 
-    if (lctx.inp_KQ_mask) {
+    if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
         // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
         if (cparams.causal_attn && !lctx.is_encoding) {
             const int64_t n_kv     = kv_self.n;
             const int64_t n_tokens = batch.n_tokens;
 
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
-            float * data     = (float *) lctx.inp_KQ_mask->data;
+            float * data     = nullptr;
             float * data_swa = nullptr;
 
+            if (lctx.inp_KQ_mask) {
+                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+                data = (float *) lctx.inp_KQ_mask->data;
+            }
+
             if (lctx.inp_KQ_mask_swa) {
+                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
                 data_swa = (float *) lctx.inp_KQ_mask_swa->data;
             }
 
@@ -13994,12 +14074,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                             f = -INFINITY;
                         } else {
                             if (hparams.use_alibi) {
-                                f = -fabs(lctx.kv_self.cells[i].pos - pos);
+                                f = -std::abs(lctx.kv_self.cells[i].pos - pos);
                             } else {
                                 f = 0.0f;
                             }
                         }
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+
+                        if (data) {
+                            data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                        }
 
                         // may need to cut off old tokens for sliding window
                         if (data_swa) {
@@ -14011,9 +14094,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                     }
                 }
 
-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                if (data) {
+                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                        for (int j = 0; j < n_kv; ++j) {
+                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        }
+                    }
+                }
+
+                if (data_swa) {
+                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                        for (int j = 0; j < n_kv; ++j) {
+                            data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        }
                     }
                 }
             }
@@ -14035,7 +14128,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                         for (int s = 0; s < batch.n_seq_id[i]; ++s) {
                             if (batch.seq_id[i][s] == seq_id) {
                                 if (hparams.use_alibi) {
-                                    f = -fabs(batch.pos[i] - batch.pos[j]);
+                                    f = -std::abs(batch.pos[i] - batch.pos[j]);
                                 } else {
                                     f = 0.0f;
                                 }
@@ -14622,8 +14715,8 @@ static int llama_decode_internal(
                     } break;
                 case LLAMA_POOLING_TYPE_UNSPECIFIED:
                     {
-                        GGML_ASSERT(false && "unknown pooling type");
-                    } break;
+                        GGML_ABORT("unknown pooling type");
+                    }
             }
         }
         n_outputs_prev += lctx.n_outputs;
@@ -14808,9 +14901,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     // each move requires 6*n_layer tensors (see build_defrag)
     //   - source view, destination view, copy operation
     //   - x2 for keys and values
-    //const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer);
+    //const uint32_t max_moves = llama_model_max_nodes(model)/(6*n_layer);
     // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (LLAMA_MAX_NODES - 2*n_layer)/(6*n_layer);
+    const uint32_t max_moves = (llama_model_max_nodes(lctx.model) - 2*n_layer)/(6*n_layer);
 
     // determine which KV cells to move where
     //
@@ -15013,6 +15106,10 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
 
     // apply K-shift if needed
     if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
+        if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA
+            GGML_ABORT("Deepseek2 does not support K-shift");
+        }
+
         {
             ggml_backend_sched_reset(lctx.sched);
 
@@ -15091,2541 +15188,35 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
 }
 
 //
-// tokenizer
+// quantization
 //
 
-static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
-    return vocab.type;
-}
+struct quantize_state_internal {
+    const llama_model                 & model;
+    const llama_model_quantize_params * params;
 
-static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
-}
+    int n_attention_wv    = 0;
+    int n_ffn_down        = 0;
+    int n_ffn_gate        = 0;
+    int n_ffn_up          = 0;
+    int i_attention_wv    = 0;
+    int i_ffn_down        = 0;
+    int i_ffn_gate        = 0;
+    int i_ffn_up          = 0;
 
-static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
-}
+    int n_k_quantized     = 0;
+    int n_fallback        = 0;
 
-static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
-}
+    bool has_imatrix      = false;
 
-static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
-}
+    // used to figure out if a model shares tok_embd with the output weight
+    bool has_output       = false;
 
-static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
-}
-
-static bool llama_is_unused_token(const llama_vocab& vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED;
-}
-
-static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
-    GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
-    GGML_ASSERT(llama_is_byte_token(vocab, id));
-    const auto & token_data = vocab.id_to_token.at(id);
-    switch (llama_vocab_get_type(vocab)) {
-        case LLAMA_VOCAB_TYPE_SPM:
-        case LLAMA_VOCAB_TYPE_UGM: {
-            auto buf = token_data.text.substr(3, 2);
-            return strtol(buf.c_str(), NULL, 16);
-        }
-        case LLAMA_VOCAB_TYPE_BPE: {
-            GGML_ASSERT(false);
-            return unicode_utf8_to_byte(token_data.text); // TODO: why is this here after GGML_ASSERT?
-        }
-        case LLAMA_VOCAB_TYPE_WPM: {
-            GGML_ASSERT(false);
-        }
-        default:
-            GGML_ASSERT(false);
-    }
-}
-
-static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
-    GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
-    static const char * hex = "0123456789ABCDEF";
-    switch (llama_vocab_get_type(vocab)) {
-        case LLAMA_VOCAB_TYPE_SPM:
-        case LLAMA_VOCAB_TYPE_UGM: {
-            const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
-            auto token = vocab.token_to_id.find(buf);
-            if (token != vocab.token_to_id.end()) {
-                return (*token).second;
-            }
-            // Try to fall back to just the byte as a string
-            const char buf2[2] = { (char)ch, 0 };
-            return vocab.token_to_id.at(buf2);
-        }
-        case LLAMA_VOCAB_TYPE_WPM:
-        case LLAMA_VOCAB_TYPE_BPE: {
-            return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
-        }
-        default:
-            GGML_ASSERT(false);
-    }
-}
-
-static void llama_escape_whitespace(std::string & text) {
-    replace_all(text, " ", "\xe2\x96\x81");
-}
-
-static void llama_unescape_whitespace(std::string & word) {
-    replace_all(word, "\xe2\x96\x81", " ");
-}
-
-struct llm_symbol {
-    using index = int;
-    index prev;
-    index next;
-    const char * text;
-    size_t n;
-};
-
-static_assert(std::is_trivially_copyable<llm_symbol>::value, "llm_symbol is not trivially copyable");
-
-// SPM tokenizer
-// original implementation:
-// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
-
-struct llm_bigram_spm {
-    struct comparator {
-        bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) {
-            return (l.score < r.score) || (l.score == r.score && l.left > r.left);
-        }
-    };
-    using queue_storage = std::vector<llm_bigram_spm>;
-    using queue = std::priority_queue<llm_bigram_spm, queue_storage, comparator>;
-    llm_symbol::index left;
-    llm_symbol::index right;
-    float score;
-    size_t size;
-};
-
-struct llm_tokenizer_spm {
-    llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {}
-
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
-        // split string into utf8 chars
-        int index = 0;
-        size_t offs = 0;
-        while (offs < text.size()) {
-            llm_symbol sym;
-            size_t len = utf8_len(text[offs]);
-            sym.text = text.c_str() + offs;
-            sym.n = std::min(len, text.size() - offs);
-            offs += sym.n;
-            sym.prev = index - 1;
-            sym.next = offs == text.size() ? -1 : index + 1;
-            index++;
-            symbols.emplace_back(sym);
-        }
-
-        // seed the work queue with all possible 2-character tokens.
-        for (size_t i = 1; i < symbols.size(); ++i) {
-            try_add_bigram(i - 1, i);
-        }
-
-        // keep substituting the highest frequency pairs for as long as we can.
-        while (!work_queue.empty()) {
-            auto bigram = work_queue.top();
-            work_queue.pop();
-
-            auto & left_sym = symbols[bigram.left];
-            auto & right_sym = symbols[bigram.right];
-
-            // if one of the symbols already got merged, skip it.
-            if (left_sym.n == 0 || right_sym.n == 0 ||
-                left_sym.n + right_sym.n != bigram.size) {
-                continue;
-            }
-
-            // merge the right sym into the left one
-            left_sym.n += right_sym.n;
-            right_sym.n = 0;
-
-            //LLAMA_LOG_INFO("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size);
-
-            // remove the right sym from the chain
-            left_sym.next = right_sym.next;
-            if (right_sym.next >= 0) {
-                symbols[right_sym.next].prev = bigram.left;
-            }
-
-            // find more substitutions
-            try_add_bigram(left_sym.prev, bigram.left);
-            try_add_bigram(bigram.left, left_sym.next);
-        }
-
-        for (int i = 0; i != -1; i = symbols[i].next) {
-            auto & symbol = symbols[i];
-            resegment(symbol, output);
-        }
-    }
-
-private:
-    void resegment(llm_symbol & symbol, std::vector<llama_vocab::id> & output) {
-        auto text = std::string(symbol.text, symbol.n);
-        auto token = vocab.token_to_id.find(text);
-
-        // Do we need to support is_unused?
-        if (token != vocab.token_to_id.end()) {
-            output.push_back((*token).second);
-            return;
-        }
-
-        const auto p = rev_merge.find(text);
-
-        if (p == rev_merge.end()) {
-            // output any symbols that did not form tokens as bytes.
-            output.reserve(output.size() + symbol.n);
-            for (int j = 0; j < (int)symbol.n; ++j) {
-                llama_vocab::id token_id = llama_byte_to_token(vocab, symbol.text[j]);
-                output.push_back(token_id);
-            }
-            return;
-        }
-
-        resegment(symbols[p->second.first],  output);
-        resegment(symbols[p->second.second], output);
-    }
-
-    void try_add_bigram(int left, int right) {
-        if (left == -1 || right == -1) {
-            return;
-        }
-
-        const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
-        auto token = vocab.token_to_id.find(text);
-
-        if (token == vocab.token_to_id.end()) {
-            return;
-        }
-
-        if (static_cast<size_t>((*token).second) >= vocab.id_to_token.size()) {
-            return;
-        }
-
-        const auto & tok_data = vocab.id_to_token[(*token).second];
-
-        llm_bigram_spm bigram;
-        bigram.left  = left;
-        bigram.right = right;
-        bigram.score = tok_data.score;
-        bigram.size  = text.size();
-
-        work_queue.push(bigram);
-
-        // Do we need to support is_unused?
-        rev_merge[text] = std::make_pair(left, right);
-    }
-
-    const llama_vocab & vocab;
-
-    std::vector<llm_symbol> symbols;
-    llm_bigram_spm::queue work_queue;
-
-    std::map<std::string, std::pair<int, int>> rev_merge;
-};
-
-// BPE tokenizer
-// adapted from https://github.com/cmp-nct/ggllm.cpp [MIT License]
-// tried to simplify unicode stuff, so most likely does not work 100% correctly!
-
-// TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused
-
-struct llm_bigram_bpe {
-    struct comparator {
-        bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const {
-            return l.rank > r.rank || (l.rank == r.rank && l.left > r.left);
-        }
-    };
-
-    using queue_storage = std::vector<llm_bigram_bpe>;
-    using queue = std::priority_queue<llm_bigram_bpe, queue_storage, comparator>;
-    llm_symbol::index left;
-    llm_symbol::index right;
-    std::string text;
-    int rank;
-    size_t size;
-};
-
-struct llm_tokenizer_bpe {
-    llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {
-        GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
-        switch (vocab.type_pre) {
-            case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
-                regex_exprs = {
-                    // original regex from tokenizer.json
-                    //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
-
-                    // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
-                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
-                };
-                break;
-            case LLAMA_VOCAB_PRE_TYPE_DBRX:
-            case LLAMA_VOCAB_PRE_TYPE_SMAUG:
-                regex_exprs = {
-                    // same as llama3
-                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
-                };
-                break;
-            case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
-                regex_exprs = {
-                    "[\r\n]",
-                    "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
-                    "\\s?[!-/:-~!-/:-~‘-‟ -。]+",
-                    "\\s+$",
-                    "[一-龥ࠀ-一가-퟿]+",
-                    "\\p{N}+",
-                };
-                break;
-            case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
-                regex_exprs = {
-                    "[\r\n]",
-                    "\\s?\\p{L}+",
-                    "\\s?\\p{P}+",
-                    "[一-龥ࠀ-一가-퟿]+",
-                    "\\p{N}",
-                };
-                break;
-            case LLAMA_VOCAB_PRE_TYPE_FALCON:
-                regex_exprs = {
-                    "[\\p{P}\\$\\+<=>\\^~\\|`]+",
-                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                    "[0-9][0-9][0-9]",
-                };
-                break;
-            case LLAMA_VOCAB_PRE_TYPE_MPT:
-                // TODO: MPT pre-tokenization regexes are unknown
-                //       the following are close, but not exact. run the following:
-                //       ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf
-                GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed");
-                regex_exprs = {
-                    "\\s?\\p{L}+",
-                    "\\s?\\p{P}+",
-                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                };
-                break;
-            case LLAMA_VOCAB_PRE_TYPE_STARCODER:
-            case LLAMA_VOCAB_PRE_TYPE_REFACT:
-            case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
-                regex_exprs = {
-                    "\\p{N}",
-                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                };
-                break;
-            case LLAMA_VOCAB_PRE_TYPE_GPT2:
-            case LLAMA_VOCAB_PRE_TYPE_OLMO:
-            case LLAMA_VOCAB_PRE_TYPE_JAIS:
-                regex_exprs = {
-                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                };
-                break;
-            case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
-            case LLAMA_VOCAB_PRE_TYPE_QWEN2:
-                regex_exprs = {
-                    // original regex from tokenizer.json
-                    // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
-                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
-                };
-                break;
-            case LLAMA_VOCAB_PRE_TYPE_PORO:
-                regex_exprs = {
-                    " ?[^(\\s|.,!?…。,、।۔،)]+",
-                };
-                break;
-            case LLAMA_VOCAB_PRE_TYPE_CHATGLM4:
-                regex_exprs = {
-                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
-                };
-                break;
-            case LLAMA_VOCAB_PRE_TYPE_VIKING:
-                regex_exprs = {
-                    "\\p{N}",
-                    " ?[^(\\s|.,!?…。,、।۔،)]+",
-                };
-                break;
-            default:
-                // default regex for BPE tokenization pre-processing
-                regex_exprs = {
-                    "[\\p{P}\\$\\+<=>\\^~\\|]+",
-                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                    "\\p{N}+",
-                    "[0-9][0-9][0-9]",
-                };
-                break;
-        }
-    }
-
-    void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) const {
-        output.push_back(token_id);
-    }
-
-    bool append_bos(std::vector<llama_vocab::id> & output) const {
-        if (vocab.tokenizer_add_bos) {
-            GGML_ASSERT(vocab.special_bos_id != -1);
-            output.push_back(vocab.special_bos_id);
-            return true;
-        }
-        return false;
-    }
-
-    bool append_eos(std::vector<llama_vocab::id> & output) const {
-        if (vocab.tokenizer_add_eos) {
-            GGML_ASSERT(vocab.special_eos_id != -1);
-            output.push_back(vocab.special_eos_id);
-            return true;
-        }
-        return false;
-    }
-
-    void check_double_bos_eos(const std::vector<llama_vocab::id> & output) const {
-        if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-            LLAMA_LOG_WARN(
-                "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                "Are you sure this is what you want?\n", __FUNCTION__);
-        }
-        if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) {
-            LLAMA_LOG_WARN(
-                "%s: Added a EOS token to the prompt as specified by the model but the prompt "
-                "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. "
-                "Are you sure this is what you want?\n", __FUNCTION__);
-        }
-    }
-
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
-        int final_prev_index = -1;
-
-        const auto word_collection = unicode_regex_split(text, regex_exprs);
-
-        symbols_final.clear();
-
-        for (auto & word : word_collection) {
-            work_queue = llm_bigram_bpe::queue();
-            symbols.clear();
-
-            int index = 0;
-            size_t offset = 0;
-
-            if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
-                symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
-                offset = word.size();
-            }
-
-            while (offset < word.size()) {
-                llm_symbol sym;
-                size_t char_len = std::min(word.size() - offset, (size_t) ::utf8_len(word[offset]));
-                sym.text = word.c_str() + offset;
-                sym.n = char_len;
-                offset += sym.n;
-                sym.prev = index - 1;
-                sym.next = offset == word.size() ? -1 : index + 1;
-                index++;
-                symbols.emplace_back(sym);
-            }
-            for (size_t i = 1; i < symbols.size(); ++i) {
-                add_new_bigram(i - 1, i);
-            }
-
-            // build token(s)
-            while (!work_queue.empty()) {
-                auto bigram = work_queue.top();
-                work_queue.pop();
-
-                auto & left_symbol = symbols[bigram.left];
-                auto & right_symbol = symbols[bigram.right];
-
-                if (left_symbol.n == 0 || right_symbol.n == 0) {
-                    continue;
-                }
-                std::string left_token = std::string(left_symbol.text, left_symbol.n);
-                std::string right_token = std::string(right_symbol.text, right_symbol.n);
-                if (left_token + right_token != bigram.text) {
-                    continue;  // Skip this bigram if it's outdated
-                }
-
-                // merge the right sym into the left one
-                left_symbol.n += right_symbol.n;
-                right_symbol.n = 0;
-
-                // remove the right sym from the chain
-                left_symbol.next = right_symbol.next;
-                if (right_symbol.next >= 0) {
-                    symbols[right_symbol.next].prev = bigram.left;
-                }
-
-                add_new_bigram(left_symbol.prev, bigram.left);  // left side of current symbol
-                add_new_bigram(bigram.left, left_symbol.next);  // right side of current symbol
-            }
-
-            // add the finished tokens to the final list keeping correct order for next and prev
-            for (auto & sym : symbols) {
-                if (sym.n > 0) {
-                    sym.prev = final_prev_index;
-                    sym.next = -1;
-                    if (final_prev_index != -1) {
-                        symbols_final[final_prev_index].next = symbols_final.size();
-                    }
-                    symbols_final.emplace_back(sym);
-                    final_prev_index = symbols_final.size() - 1;
-                }
-            }
-        }
-
-        symbols = symbols_final;
-
-        if (!symbols.empty()) {
-            for (int i = 0; i != -1; i = symbols[i].next) {
-                auto & symbol = symbols[i];
-                if (symbol.n == 0) {
-                    continue;
-                }
-
-                const std::string str = std::string(symbol.text, symbol.n);
-                const auto token = vocab.token_to_id.find(str);
-
-                if (token == vocab.token_to_id.end()) {
-                    for (auto j = str.begin(); j != str.end(); ++j) {
-                        std::string byte_str(1, *j);
-                        auto token_multibyte = vocab.token_to_id.find(byte_str);
-                        if (token_multibyte != vocab.token_to_id.end()) {
-                            output.push_back(token_multibyte->second);
-                        }
-                    }
-                } else {
-                    output.push_back((*token).second);
-                }
-            }
-        }
-    }
-
-private:
-    void add_new_bigram(int left, int right) {
-        if (left == -1 || right == -1) {
-            return;
-        }
-
-        std::string left_token  = std::string(symbols[left].text,  symbols[left].n);
-        std::string right_token = std::string(symbols[right].text, symbols[right].n);
-
-        int rank_found = -1;
-
-        rank_found = vocab.find_bpe_rank(left_token, right_token);
-
-        if (rank_found < 0) {
-            return;
-        }
-
-        llm_bigram_bpe bigram;
-
-        bigram.left  = left;
-        bigram.right = right;
-        bigram.text  = left_token + right_token;
-        bigram.size  = left_token.size() + right_token.size();
-        bigram.rank  = rank_found;
-
-        work_queue.push(bigram);
-    }
-
-    const llama_vocab & vocab;
-
-    std::vector<std::string> regex_exprs;
-
-    std::vector<llm_symbol> symbols;
-    std::vector<llm_symbol> symbols_final;
-
-    llm_bigram_bpe::queue work_queue;
-};
-
-struct llm_tokenizer_wpm {
-    llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
-
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) const {
-        const auto & token_map = vocab.token_to_id;
-
-        // normalize and split by whitespace
-        std::vector<std::string> words = preprocess(text);
-
-        // bos token prepended already
-
-        // find the longest tokens that form the words
-        for (const std::string & word : words) {
-            // skip empty words
-            if (word.size() == 0) {
-                continue;
-            }
-
-            // prepend phantom space
-            const std::string word1 = "\xe2\x96\x81" + word;
-            const int n = word1.size();
-
-            const size_t current_tokens = output.size();
-
-            // we're at the start of a new word
-            // move through character position in word
-            for (int i = 0; i < n; ++i) {
-                // loop through possible match length
-                bool match = false;
-                for (int j = std::min(n, i + vocab.max_token_len + 1); j > i; j--) {
-                    auto it = token_map.find(word1.substr(i, j - i));
-                    if (it != token_map.end()) {
-                        output.push_back(it->second);
-                        match = true;
-                        i = j - 1;
-                        break;
-                    }
-                }
-
-                if (!match) { // discard all
-                    output.resize(current_tokens);
-                    break;  // and discard next tokens
-                }
-            }
-
-            // we didn't find any matches for this word
-            if (current_tokens == output.size()) {
-                output.push_back(vocab.special_unk_id);
-            }
-        }
-    }
-
-    // TODO: reduce string copies by using cpts_offs array
-    std::vector<std::string> preprocess(const std::string & text) const {
-        const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
-        std::vector<std::string> words(1, "");
-
-        for (const uint32_t cpt : cpts_nfd) {
-            const auto flags = unicode_cpt_flags(cpt);
-
-            if (flags.is_whitespace) {
-                if (words.back().size()) {  // finish previous word if any
-                    words.emplace_back();
-                }
-                continue;
-            }
-
-            assert (!flags.is_separator);
-            if (cpt == 0 || cpt == 0xFFFD || flags.is_control) {
-                continue;
-            }
-
-            const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt));
-            if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
-                if (words.back().size()) {  // finish previous word if any
-                    words.emplace_back();
-                }
-                words.back() = s;       // single char word
-                words.emplace_back();   // start a new word
-            } else {
-                words.back() += s;  // append char to word
-            }
-        }
-
-        if (!words.back().size()) {
-            words.pop_back();
-        }
-
-        return words;
-    }
-
-    static bool is_chinese_char(uint32_t cpt) {
-        return
-            (cpt >= 0x04E00 && cpt <= 0x09FFF) ||
-            (cpt >= 0x03400 && cpt <= 0x04DBF) ||
-            (cpt >= 0x20000 && cpt <= 0x2A6DF) ||
-            (cpt >= 0x2A700 && cpt <= 0x2B73F) ||
-            (cpt >= 0x2B740 && cpt <= 0x2B81F) ||
-            (cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
-            (cpt >= 0x0F900 && cpt <= 0x0FAFF) ||
-            (cpt >= 0x2F800 && cpt <= 0x2FA1F);
-            //(cpt >= 0x3000  && cpt <= 0x303F)  ||
-            //(cpt >= 0xFF00  && cpt <= 0xFFEF);
-    }
-
-    const llama_vocab & vocab;
-};
-
-struct naive_trie {
-    naive_trie() : has_value(false), value(0) {
-    }
-    void insert(const char * key, size_t len, int32_t value = 0) {
-        if (len == 0) {
-            this->has_value = true;
-            this->value = value;
-            return;
-        }
-        char c = key[0];
-        auto res = children.find(c);
-        if (res != children.end()) {
-            res->second.insert(key + 1, len - 1, value);
-        } else {
-            auto res = children.insert(std::make_pair(c, naive_trie()));
-            res.first->second.insert(key + 1, len - 1, value);
-        }
-    }
-    std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) {
-        if (len == 0 || offset == len) {
-            return std::make_pair(key, offset);
-        }
-        char c = key[offset];
-        auto res = children.find(c);
-        if (res != children.end()) {
-            return res->second.get_longest_prefix(key, len, offset + 1);
-        } else {
-            return std::make_pair(key, offset);
-        }
-    }
-    struct naive_trie * traverse(const char c) {
-        auto res = children.find(c);
-        if (res != children.end()) {
-            return &res->second;
-        } else {
-            return NULL;
-        }
-    }
-    std::map<char, struct naive_trie> children;
-    bool has_value;
-    llama_token value;
-};
-
-struct llm_tokenizer_ugm {
-    llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) {
-        if (vocab.precompiled_charsmap.size() > 0) {
-            size_t charsmap_offset = 0;
-
-            // First four bytes of precompiled_charsmap contains length of binary
-            // blob containing XOR-compressed compact double array (XCDA) entries
-            uint32_t xcda_blob_size = *(const uint32_t *) &vocab.precompiled_charsmap[0];
-            charsmap_offset += sizeof(xcda_blob_size);
-            if (xcda_blob_size + charsmap_offset >= vocab.precompiled_charsmap.size()) {
-                throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
-            }
-
-            // Next xcda_blob_size bytes contain entries of XOR-compressed compact
-            // double array (XCDA). Each entry is bit-packed into a 32-bit integer.
-            xcda_array = (const uint32_t *) &vocab.precompiled_charsmap[charsmap_offset];
-            xcda_array_size = xcda_blob_size / sizeof(uint32_t);
-            charsmap_offset += xcda_blob_size;
-
-            // Remaining bytes of precompiled charsmap contain null-terminated
-            // replacement strings for prefixes matched by the XCDA.
-            prefix_replacements = &vocab.precompiled_charsmap[charsmap_offset];
-            prefix_replacements_size = vocab.precompiled_charsmap.size() - charsmap_offset;
-        }
-
-        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
-            const auto &token_data = vocab.id_to_token[id];
-
-            if (llama_is_normal_token(vocab, id)) {
-                min_score = std::min<float>(min_score, token_data.score);
-                max_score = std::max<float>(max_score, token_data.score);
-            }
-
-            if (llama_is_normal_token(vocab, id) ||
-                llama_is_user_defined_token(vocab, id) ||
-                llama_is_unused_token(vocab, id)) {
-                token_matcher.insert(token_data.text.data(), token_data.text.size(), id);
-            }
-
-            if (llama_is_user_defined_token(vocab, id)) {
-                user_defined_token_matcher.insert(token_data.text.data(), token_data.text.size());
-            }
-        }
-
-        unknown_token_score = min_score - unknown_token_score_penalty;
-    }
-
-    /* This implementation is based on SentencePiece optimized Viterbi algorithm for
-     * unigram language models. The general idea is to:
-     * - move along the input sequence in steps of one UTF code point,
-     * - at each step find all possible tokenizations of the prefix by
-     *   traversing the tokens trie,
-     * - for each tokenization store the best one so far (by higher score)
-     * - use the position in sequence after given token as an index to store
-     *   results
-     * - if there was no valid tokenization of the current UTF code point
-     *   then use unknown token with additional score penalty
-     * After processing the whole sequence we backtrack from the end to get
-     * the best tokenization.
-    */
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
-        // normalize the input first
-        std::string normalized;
-        normalize(text, &normalized);
-        size_t input_len = normalized.size();
-        if (input_len == 0) {
-            return;
-        }
-
-        // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
-        std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.special_unk_id, 0, -FLT_MAX});
-        // at the beginning tokenization score is zero
-        tokenization_results[0] = { vocab.special_unk_id, 0, 0 };
-
-        for (size_t input_offset = 0; input_offset < input_len;) {
-            size_t prefix_offset = input_offset;
-            // calculate how many code units are in the currently processed UTF code point
-            size_t n_utf8_code_units = std::min<size_t>(utf8_len(normalized[input_offset]), input_len - input_offset);
-
-            // traverse the token matcher trie to find a matching token
-            bool single_codepoint_token_found = false;
-            const struct best_tokenization & current_best = tokenization_results[input_offset];
-            struct naive_trie * node  = token_matcher.traverse(normalized[prefix_offset++]);
-
-            while (prefix_offset <= input_len && node != NULL) {
-                // check if we found valid token in prefix
-                if (node->has_value) {
-                    // check if it corresponds to the whole UTF code point
-                    if (prefix_offset - input_offset == n_utf8_code_units) {
-                        single_codepoint_token_found = true;
-                    }
-                    llama_token token_id = node->value;
-                    const auto & token_data = vocab.id_to_token[token_id];
-
-                    // we set the user-defined token scores to 0 to make them more likely to be selected
-                    // (normal token scores are log probabilities, so they are negative)
-                    // score type is double here to make tokenization results exactly
-                    // the same as in the HF tokenizer using SentencePiece
-                    const double token_score = llama_is_user_defined_token(vocab, token_id) ? 0.0 : token_data.score;
-                    const double challenger_score = current_best.score_sum + token_score;
-                    struct best_tokenization & current_champ = tokenization_results[prefix_offset];
-                    if (challenger_score > current_champ.score_sum) {
-                        struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score };
-                        current_champ = challenger;
-                    }
-                }
-                node = node->traverse(normalized[prefix_offset++]);
-            }
-
-            // if we didn't find a valid token corresponding to the whole UTF code point
-            // then use unknown token as the tokenization of this UTF code point
-            if (!single_codepoint_token_found) {
-                const double challenger_score = current_best.score_sum + unknown_token_score;
-                prefix_offset = input_offset + n_utf8_code_units;
-                struct best_tokenization & current_champ = tokenization_results[prefix_offset];
-                if (challenger_score > current_champ.score_sum) {
-                    struct best_tokenization challenger = { vocab.special_unk_id, input_offset, (float) challenger_score };
-                    current_champ = challenger;
-                }
-            }
-
-            // move to the next UTF code point
-            input_offset += n_utf8_code_units;
-        }
-
-        // now backtrack from the end to gather token ids of the best tokenization
-        // merge sequences of consecutive unknown tokens into single unknown tokens
-        bool is_prev_unknown = false;
-        for (struct best_tokenization & tokenization = tokenization_results[input_len]; ; tokenization = tokenization_results[tokenization.input_offset]) {
-            bool is_unknown = tokenization.token_id == vocab.special_unk_id;
-            if (!(is_prev_unknown && is_unknown)) {
-                output.push_back(tokenization.token_id);
-            }
-            if (tokenization.input_offset == 0) {
-                break;
-            }
-            is_prev_unknown = is_unknown;
-        }
-
-        // reverse the output since we added tokens starting from the end of the input
-        std::reverse(output.begin(), output.end());
-    }
-
-private:
-    const llama_vocab & vocab;
-
-    // helper structure for returning normalization results
-    struct normalization_result {
-        const char * normalized;
-        size_t normalized_len;
-        size_t consumed_input;
-    };
-
-    void normalize(const std::string& input, std::string * normalized) {
-        normalized->clear();
-        normalized->reserve(input.size() * 3);
-
-        const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " ";
-
-        bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
-        bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
-        bool shall_merge_spaces = vocab.tokenizer_remove_extra_whitespaces;
-
-        bool is_space_prepended = false;
-        bool processing_non_ws = false;
-
-        size_t input_len = input.size();
-
-        for (size_t input_offset = 0; input_offset < input_len; ) {
-            auto norm_res = normalize_prefix(input, input_offset);
-            for (size_t i = 0; i < norm_res.normalized_len; i++) {
-                char c = norm_res.normalized[i];
-                if (c != ' ') {
-                    if (!processing_non_ws) {
-                        processing_non_ws = true;
-                        if ((shall_prepend_space && !is_space_prepended) || shall_merge_spaces) {
-                            normalized->append(space);
-                            is_space_prepended = true;
-                        }
-                    }
-                    normalized->push_back(c);
-                } else {
-                    if (processing_non_ws) {
-                        processing_non_ws = false;
-                    }
-                    if (!shall_merge_spaces) {
-                        normalized->append(space);
-                    }
-                }
-            }
-
-            input_offset += norm_res.consumed_input;
-        }
-
-        if (shall_append_space) {
-            normalized->append(space);
-        }
-    }
-
-    /*
-     * This structure is a view wrapper for XOR-compressed double array (XCDA)
-     * See Shunsuke Kanda (2018). Space- and Time-Efficient String Dictionaries.
-     * Eeach bit-packed entry contains:
-     * - BASE array value in bits 10-30
-     * - LCHECK array value in bits 0-7
-     * - LEAF array value in bit 9
-     * Entries containing indexes of replacement sequences have set bit 31
-     */
-    struct xcda_array_view {
-    public:
-        xcda_array_view(const uint32_t * xcda_array, size_t xcda_array_size) : xcda_array(xcda_array), xcda_array_size(xcda_array_size) {
-        }
-        uint32_t get_base(size_t index) {
-            uint32_t packed_node = get_node(index);
-            return (packed_node >> 10) << ((packed_node & (1U << 9)) >> 6);
-        }
-        uint32_t get_lcheck(size_t index) {
-            uint32_t packed_node = get_node(index);
-            return packed_node & ((1U << 31) | 0xff);
-        }
-        bool get_leaf(size_t index) {
-            uint32_t packed_node = get_node(index);
-            return (packed_node >> 8) & 1;
-        }
-        uint32_t get_value(size_t index) {
-            uint32_t packed_node = get_node(index);
-            return packed_node & ((1U << 31) - 1);
-        }
-    private:
-        uint32_t get_node(size_t index) {
-            if (index > xcda_array_size) {
-                throw std::runtime_error("Index out of array bounds in XCDA array!");
-            }
-            return xcda_array[index];
-        }
-        const uint32_t * xcda_array;
-        size_t xcda_array_size;
-    };
-
-    struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
-        if (input_offset == input.size()) {
-            return { &input[input_offset], 0, 0 };
-        }
-
-        // if input prefix matches some user-defined token return this token as normalization result
-        auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
-        if (user_defined_token_match.second > 0) {
-            return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second };
-        }
-
-        size_t longest_prefix_length = 0;
-        size_t longest_prefix_offset = 0;
-
-        if (xcda_array_size > 0) {
-            struct xcda_array_view xcda_view(xcda_array, xcda_array_size);
-
-            // Find the longest normalized sequence matching the input prefix by walking
-            // the XOR-compressed compact double array (XCDA) starting from the root node
-            // We find the index of the next node by calculating BASE[s] ^ c where s is
-            // the index of the previous node and c is a numerical character value
-            uint32_t node_index = 0;
-            // get BASE of the root node
-            node_index = xcda_view.get_base(node_index);
-            for (size_t prefix_offset = input_offset; prefix_offset < input.size(); prefix_offset++) {
-                unsigned char c = input[prefix_offset];
-                if (c == 0) {
-                    break;
-                }
-                node_index ^= c;
-                // if value of LCHECK is not c it means that this is not a child of
-                // the previous node, so we stop matching
-                if (xcda_view.get_lcheck(node_index) != c) {
-                    break;
-                }
-                bool is_leaf = xcda_view.get_leaf(node_index);
-                // get BASE of the current node
-                node_index ^= xcda_view.get_base(node_index);
-                // if LEAF of the current node is true, it means that its BASE points to the node
-                // containing index of replacement sequence for currently matched input prefix
-                if (is_leaf)
-                {
-                    longest_prefix_length = prefix_offset - input_offset + 1;
-                    // get index of replacement sequence for currently matched input prefix
-                    longest_prefix_offset = xcda_view.get_value(node_index);
-                }
-            }
-        }
-
-        if (longest_prefix_length > 0) {
-            // we have a match, so return the replacement sequence
-            if (longest_prefix_offset >= prefix_replacements_size) {
-                throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
-            }
-            const char * prefix_replacement = &prefix_replacements[longest_prefix_offset];
-            return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length };
-        } else {
-            // check if the input prefix contains a valid sequence of UTF-8 code units
-            try {
-                // if yes, return this sequence unmodified
-                size_t prefix_offset = input_offset;
-                unicode_cpt_from_utf8(input, prefix_offset);
-                return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset };
-            } catch (std::invalid_argument & /*ex*/) {
-                // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER
-                return { "\xEF\xBF\xBD", 3, 1 };
-            }
-        }
-    }
-
-    // escaped space symbol - U+2581 (Lower One Eighth Block)
-    const std::string escaped_space = "\xE2\x96\x81";
-
-    const char * prefix_replacements = NULL;
-    size_t prefix_replacements_size = 0;
-
-    const uint32_t * xcda_array = NULL;
-    size_t xcda_array_size = 0;
-
-    struct naive_trie user_defined_token_matcher;
-
-    // this structure stores the best tokenization so far at input_offset
-    struct best_tokenization {
-        llama_token token_id;
-        size_t input_offset;
-        float score_sum;
-    };
-
-    float min_score = FLT_MAX;
-    float max_score = -FLT_MAX;
-
-    float unknown_token_score_penalty = 10.0;
-    float unknown_token_score;
-
-    struct naive_trie token_matcher;
-};
-
-
-typedef enum FRAGMENT_BUFFER_VARIANT_TYPE {
-    FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN,
-    FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT
-} FRAGMENT_BUFFER_VARIANT_TYPE;
-
-struct fragment_buffer_variant {
-    fragment_buffer_variant(llama_vocab::id _token)
-    :
-        type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN),
-        token(_token),
-        raw_text(_dummy),
-        offset(0),
-        length(0) {}
-
-    fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length)
-    :
-        type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT),
-        token((llama_vocab::id) - 1),
-        raw_text(_raw_text),
-        offset(_offset),
-        length(_length){
-            GGML_ASSERT(_offset >= 0);
-            GGML_ASSERT(_length >= 1);
-            GGML_ASSERT(offset + length <= raw_text.length());
-        }
-
-    const FRAGMENT_BUFFER_VARIANT_TYPE type;
-    const llama_vocab::id token;
-    const std::string _dummy;
-    const std::string & raw_text;
-    const uint64_t offset;
-    const uint64_t length;
-};
-
-// #define PRETOKENIZERDEBUG
-
-static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
-    // for each special token
-    for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
-        const auto & data = vocab.id_to_token[special_id];
-        const auto & special_token = data.text;
-
-        // for each text fragment
-        std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
-        while (it != buffer.end()) {
-            auto & fragment = (*it);
-
-            // if a fragment is text ( not yet processed )
-            if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                auto & raw_text = fragment.raw_text;
-
-                auto raw_text_base_offset = fragment.offset;
-                auto raw_text_base_length = fragment.length;
-
-                // loop over the text
-                while (true) {
-                    // find the first occurrence of a given special token in this fragment
-                    //  passing offset argument only limit the "search area" but match coordinates
-                    //  are still relative to the source full raw_text
-                    auto match = raw_text.find(special_token, raw_text_base_offset);
-
-                    // no occurrences found, stop processing this fragment for a given special token
-                    if (match == std::string::npos) break;
-
-                    // check if match is within bounds of offset <-> length
-                    if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break;
-
-#ifdef PRETOKENIZERDEBUG
-                    LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
-#endif
-                    auto source = std::distance(buffer.begin(), it);
-
-                    // if match is further than base offset
-                    //  then we have some text to the left of it
-                    if (match > raw_text_base_offset) {
-                        // left
-                        const int64_t left_reminder_offset = raw_text_base_offset + 0;
-                        int64_t left_reminder_length = match - raw_text_base_offset;
-
-                        if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) {
-                            while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) {
-                                left_reminder_length--;
-                            }
-                        }
-
-                        if (left_reminder_length > 0) {
-                            buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
-                            it++;
-                        }
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
-#endif
-                    }
-
-                    // special token
-                    buffer.emplace_after(it, special_id);
-                    it++;
-
-                    // right
-                    if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
-                        int64_t right_reminder_offset = match + special_token.length();
-                        int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
-
-                        if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
-                            while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
-                                right_reminder_offset++;
-                                right_reminder_length--;
-                            }
-                        }
-
-                        if (right_reminder_length > 0) {
-                            buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
-                            it++;
-                        }
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
-#endif
-
-                        if (source == 0) {
-                            buffer.erase_after(buffer.before_begin());
-                        } else {
-                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
-                        }
-
-                        // repeat for the right side
-                        raw_text_base_offset = right_reminder_offset;
-                        raw_text_base_length = right_reminder_length;
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
-#endif
-                    } else {
-                        if (source == 0) {
-                            buffer.erase_after(buffer.before_begin());
-                        } else {
-                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
-                        }
-                        break;
-                    }
-                }
-            }
-            it++;
-        }
-    }
-}
-
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) {
-    std::vector<llama_vocab::id> output;
-    std::forward_list<fragment_buffer_variant> fragment_buffer;
-
-    if (!raw_text.empty()) {
-        fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
-        if (parse_special) tokenizer_st_partition(vocab, fragment_buffer);
-    }
-
-    switch (vocab.type) {
-        case LLAMA_VOCAB_TYPE_SPM:
-            {
-                // OG tokenizer behavior:
-                //
-                // tokenizer.encode('', add_special_tokens=True)  returns [1]
-                // tokenizer.encode('', add_special_tokens=False) returns []
-
-                bool is_prev_special = true;  // prefix with space if first token
-
-                if (add_special && vocab.tokenizer_add_bos) {
-                    GGML_ASSERT(vocab.special_bos_id != -1);
-                    output.push_back(vocab.special_bos_id);
-                    is_prev_special = true;
-                }
-
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-
-                        // prefix with space if previous is special
-                        if (vocab.tokenizer_add_space_prefix && is_prev_special) {
-                            raw_text = " " + raw_text;
-                        }
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        llm_tokenizer_spm tokenizer(vocab);
-                        llama_escape_whitespace(raw_text);
-                        tokenizer.tokenize(raw_text, output);
-                        is_prev_special = false;
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                        is_prev_special = true;
-                    }
-                }
-
-                if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-                    LLAMA_LOG_WARN(
-                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                        "Are you sure this is what you want?\n", __FUNCTION__);
-                }
-
-                if (add_special && vocab.tokenizer_add_eos) {
-                    GGML_ASSERT(vocab.special_eos_id != -1);
-                    output.push_back(vocab.special_eos_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_BPE:
-            {
-                llm_tokenizer_bpe tokenizer(vocab);
-
-                if (add_special) {
-                    tokenizer.append_bos(output);
-                }
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        tokenizer.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        tokenizer.append(fragment.token, output);
-                    }
-                }
-
-                if (add_special) {
-                    tokenizer.append_eos(output);
-                    tokenizer.check_double_bos_eos(output);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_WPM:
-            {
-                if (add_special) {
-                    GGML_ASSERT(vocab.special_cls_id != -1);
-                    output.push_back(vocab.special_cls_id);
-                }
-
-                llm_tokenizer_wpm tokenizer(vocab);
-
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        tokenizer.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                    }
-                }
-
-                if (add_special) {
-                    GGML_ASSERT(vocab.special_sep_id != -1);
-                    output.push_back(vocab.special_sep_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_UGM:
-            {
-                llm_tokenizer_ugm tokenizer(vocab);
-
-                if (add_special && vocab.tokenizer_add_bos != 0) {
-                    GGML_ASSERT(vocab.special_bos_id != -1);
-                    output.push_back(vocab.special_bos_id);
-                }
-
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        tokenizer.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                    }
-                }
-
-                if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-                    LLAMA_LOG_WARN(
-                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                        "Are you sure this is what you want?\n", __FUNCTION__);
-                }
-
-                if (add_special && vocab.tokenizer_add_eos == 1) {
-                    GGML_ASSERT(vocab.special_eos_id != -1);
-                    output.push_back(vocab.special_eos_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_NONE:
-            GGML_ASSERT(false);
-    }
-
-    return output;
-}
-
-//
-// grammar - internal
-//
-
-
-// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
-// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
-std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
-        const std::string & src,
-        llama_partial_utf8   partial_start) {
-    static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
-    const char          * pos      = src.c_str();
-    std::vector<uint32_t> code_points;
-    // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
-    code_points.reserve(src.size() + 1);
-    uint32_t              value    = partial_start.value;
-    int                   n_remain = partial_start.n_remain;
-
-    // continue previous decode, if applicable
-    while (*pos != 0 && n_remain > 0) {
-        uint8_t next_byte = static_cast<uint8_t>(*pos);
-        if ((next_byte >> 6) != 2) {
-            // invalid sequence, abort
-            code_points.push_back(0);
-            return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
-        }
-        value = (value << 6) + (next_byte & 0x3F);
-        ++pos;
-        --n_remain;
-    }
-
-    if (partial_start.n_remain > 0 && n_remain == 0) {
-        code_points.push_back(value);
-    }
-
-    // decode any subsequent utf-8 sequences, which may end in an incomplete one
-    while (*pos != 0) {
-        uint8_t  first_byte = static_cast<uint8_t>(*pos);
-        uint8_t  highbits   = first_byte >> 4;
-                 n_remain   = lookup[highbits] - 1;
-
-        if (n_remain < 0) {
-            // invalid sequence, abort
-            code_points.clear();
-            code_points.push_back(0);
-            return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
-        }
-
-        uint8_t  mask       = (1 << (7 - n_remain)) - 1;
-                 value      = first_byte & mask;
-        ++pos;
-        while (*pos != 0 && n_remain > 0) {
-            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
-            ++pos;
-            --n_remain;
-        }
-        if (n_remain == 0) {
-            code_points.push_back(value);
-        }
-    }
-    code_points.push_back(0);
-
-    return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
-}
-
-// returns true iff pos points to the end of one of the definitions of a rule
-static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
-    switch (pos->type) {
-        case LLAMA_GRETYPE_END: return true;  // NOLINT
-        case LLAMA_GRETYPE_ALT: return true;  // NOLINT
-        default:                return false;
-    }
-}
-
-// returns true iff chr satisfies the char range at pos (regular or inverse range)
-// asserts that pos is pointing to a char range element
-static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
-        const llama_grammar_element * pos,
-        const uint32_t                chr) {
-
-    bool found            = false;
-    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
-
-    GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT
-
-    do {
-        if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
-            // inclusive range, e.g. [a-z]
-            found = found || (pos->value <= chr && chr <= pos[1].value);
-            pos += 2;
-        } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
-            // Any character matches "."
-            found = true;
-            pos += 1;
-        } else {
-            // exact char match, e.g. [a] or "a"
-            found = found || pos->value == chr;
-            pos += 1;
-        }
-    } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
-
-    return std::make_pair(found == is_positive_char, pos);
-}
-
-// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
-// range at pos (regular or inverse range)
-// asserts that pos is pointing to a char range element
-static bool llama_grammar_match_partial_char(
-        const llama_grammar_element * pos,
-        const llama_partial_utf8      partial_utf8) {
-
-    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
-    GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
-
-    uint32_t partial_value = partial_utf8.value;
-    int      n_remain      = partial_utf8.n_remain;
-
-    // invalid sequence or 7-bit char split across 2 bytes (overlong)
-    if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
-        return false;
-    }
-
-    // range of possible code points this partial UTF-8 sequence could complete to
-    uint32_t low  = partial_value << (n_remain * 6);
-    uint32_t high = low | ((1 << (n_remain * 6)) - 1);
-
-    if (low == 0) {
-        if (n_remain == 2) {
-            low = 1 << 11;
-        } else if (n_remain == 3) {
-            low = 1 << 16;
-        }
-    }
-
-    do {
-        if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
-            // inclusive range, e.g. [a-z]
-            if (pos->value <= high && low <= pos[1].value) {
-                return is_positive_char;
-            }
-            pos += 2;
-        } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
-            // Any character matches "."
-            return true;
-        } else {
-            // exact char match, e.g. [a] or "a"
-            if (low <= pos->value && pos->value <= high) {
-                return is_positive_char;
-            }
-            pos += 1;
-        }
-    } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
-
-    return !is_positive_char;
-}
-
-
-// transforms a grammar pushdown stack into N possible stacks, all ending
-// at a character range (terminal element)
-static void llama_grammar_advance_stack(
-        const std::vector<std::vector<llama_grammar_element>>   & rules,
-        const std::vector<const llama_grammar_element *>        & stack,
-        std::vector<std::vector<const llama_grammar_element *>> & new_stacks) {
-
-    if (stack.empty()) {
-        if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
-            new_stacks.emplace_back(stack);
-        }
-        return;
-    }
-
-    const llama_grammar_element * pos = stack.back();
-
-    switch (pos->type) {
-        case LLAMA_GRETYPE_RULE_REF: {
-            const size_t                  rule_id = static_cast<size_t>(pos->value);
-            const llama_grammar_element * subpos  = rules[rule_id].data();
-            do {
-                // init new stack without the top (pos)
-                std::vector<const llama_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
-                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
-                    // if this rule ref is followed by another element, add that to stack
-                    new_stack.push_back(pos + 1);
-                }
-                if (!llama_grammar_is_end_of_sequence(subpos)) {
-                    // if alternate is nonempty, add to stack
-                    new_stack.push_back(subpos);
-                }
-                llama_grammar_advance_stack(rules, new_stack, new_stacks);
-                while (!llama_grammar_is_end_of_sequence(subpos)) {
-                    // scan to end of alternate def
-                    subpos++;
-                }
-                if (subpos->type == LLAMA_GRETYPE_ALT) {
-                    // there's another alternate def of this rule to process
-                    subpos++;
-                } else {
-                    break;
-                }
-            } while (true);
-            break;
-        }
-        case LLAMA_GRETYPE_CHAR:
-        case LLAMA_GRETYPE_CHAR_NOT:
-        case LLAMA_GRETYPE_CHAR_ANY:
-            if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
-                // only add the stack if it's not a duplicate of one we already have
-                new_stacks.emplace_back(stack);
-            }
-            break;
-        default:
-            // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
-            // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
-            // those
-            GGML_ASSERT(false);
-    }
-}
-
-// takes a set of possible pushdown stacks on a grammar, which are required to
-// be positioned at a character range (see `llama_grammar_advance_stack`), and
-// produces the N possible stacks if the given char is accepted at those
-// positions
-void llama_grammar_accept(
-        const std::vector<std::vector<llama_grammar_element>>         & rules,
-        const std::vector<std::vector<const llama_grammar_element *>> & stacks,
-        const uint32_t                                                  chr,
-        std::vector<std::vector<const llama_grammar_element *>>       & new_stacks) {
-
-    new_stacks.clear();
-
-    for (const auto & stack : stacks) {
-        if (stack.empty()) {
-            continue;
-        }
-
-        auto match = llama_grammar_match_char(stack.back(), chr);
-        if (match.first) {
-            const llama_grammar_element * pos = match.second;
-
-            // update top of stack to next element, if any
-            std::vector<const llama_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
-            if (!llama_grammar_is_end_of_sequence(pos)) {
-                new_stack.push_back(pos);
-            }
-            llama_grammar_advance_stack(rules, new_stack, new_stacks);
-        }
-    }
-}
-
-static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
-        const std::vector<std::vector<llama_grammar_element>>         & rules,
-        const std::vector<std::vector<const llama_grammar_element *>> & stacks,
-        const std::vector<llama_grammar_candidate>                    & candidates);
-
-static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
-        const std::vector<std::vector<llama_grammar_element>> & rules,
-        const std::vector<const llama_grammar_element *>      & stack,
-        const std::vector<llama_grammar_candidate>            & candidates) {
-
-    std::vector<llama_grammar_candidate> rejects;
-    rejects.reserve(candidates.size());
-
-    if (stack.empty()) {
-        for (const auto & tok : candidates) {
-            if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
-                rejects.push_back(tok);
-            }
-        }
-        return rejects;
-    }
-
-    const llama_grammar_element * stack_pos = stack.back();
-
-    std::vector<llama_grammar_candidate> next_candidates;
-    next_candidates.reserve(candidates.size());
-
-    for (const auto & tok : candidates) {
-        if (*tok.code_points == 0) {
-            // reached end of full codepoints in token, reject iff it ended in a partial sequence
-            // that cannot satisfy this position in grammar
-            if (tok.partial_utf8.n_remain != 0 &&
-                    !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
-                rejects.push_back(tok);
-            }
-        } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
-            next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
-        } else {
-            rejects.push_back(tok);
-        }
-    }
-
-    const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;
-
-    // update top of stack to next element, if any
-    std::vector<const llama_grammar_element *> stack_after(stack.begin(), stack.end() - 1);
-    if (!llama_grammar_is_end_of_sequence(stack_pos_after)) {
-        stack_after.push_back(stack_pos_after);
-    }
-    std::vector<std::vector<const llama_grammar_element *>> next_stacks;
-    llama_grammar_advance_stack(rules, stack_after, next_stacks);
-
-    auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
-    for (const auto & tok : next_rejects) {
-        rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
-    }
-
-    return rejects;
-}
-
-static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
-        const std::vector<std::vector<llama_grammar_element>>         & rules,
-        const std::vector<std::vector<const llama_grammar_element *>> & stacks,
-        const std::vector<llama_grammar_candidate>                    & candidates) {
-    GGML_ASSERT(!stacks.empty()); // REVIEW
-
-    if (candidates.empty()) {
-        return std::vector<llama_grammar_candidate>();
-    }
-
-    auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
-
-    for (size_t i = 1, size = stacks.size(); i < size; ++i) {
-        rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
-    }
-    return rejects;
-}
-
-static bool llama_grammar_detect_left_recursion(
-        const std::vector<std::vector<llama_grammar_element>> & rules,
-        size_t                                                  rule_index,
-        std::vector<bool>                                     * rules_visited,
-        std::vector<bool>                                     * rules_in_progress,
-        std::vector<bool>                                     * rules_may_be_empty) {
-    if ((*rules_in_progress)[rule_index]) {
-        return true;
-    }
-
-    (*rules_in_progress)[rule_index] = true;
-
-    const std::vector<llama_grammar_element> & rule = rules[rule_index];
-
-    // First check if the rule might produce the empty string. This could be done combined with the second
-    // step but it's more readable as two steps.
-    bool at_rule_start = true;
-    for (size_t i = 0; i < rule.size(); i++) {
-        if (llama_grammar_is_end_of_sequence(&rule[i])) {
-            if (at_rule_start) {
-                (*rules_may_be_empty)[rule_index] = true;
-                break;
-            }
-            at_rule_start = true;
-        } else {
-            at_rule_start = false;
-        }
-    }
-
-    // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
-    // be empty)
-    bool recurse_into_nonterminal = true;
-    for (size_t i = 0; i < rule.size(); i++) {
-        if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
-            if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
-                return true;
-            }
-            if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
-                recurse_into_nonterminal = false;
-            }
-        } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
-            recurse_into_nonterminal = true;
-        } else {
-            recurse_into_nonterminal = false;
-        }
-    }
-
-    (*rules_in_progress)[rule_index] = false;
-    (*rules_visited)[rule_index] = true;
-    return false;
-}
-
-//
-// grammar - external
-//
-
-struct llama_grammar * llama_grammar_init(
-            const llama_grammar_element ** rules,
-                                 size_t    n_rules,
-                                 size_t    start_rule_index) {
-    const llama_grammar_element * pos;
-
-    // copy rule definitions into vectors
-    std::vector<std::vector<llama_grammar_element>> vec_rules(n_rules);
-    for (size_t i = 0; i < n_rules; i++) {
-        for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
-            vec_rules[i].push_back(*pos);
-        }
-        vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
-    }
-
-    // Check for left recursion
-    std::vector<bool> rules_visited(n_rules);
-    std::vector<bool> rules_in_progress(n_rules);
-    std::vector<bool> rules_may_be_empty(n_rules);
-    for (size_t i = 0; i < n_rules; i++) {
-        if (rules_visited[i]) {
-            continue;
-        }
-        if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
-            LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
-            return nullptr;
-        }
-    }
-
-    // loop over alternates of start rule to build initial stacks
-    std::vector<std::vector<const llama_grammar_element *>> stacks;
-    pos = vec_rules[start_rule_index].data();
-    do {
-        std::vector<const llama_grammar_element *> stack;
-        if (!llama_grammar_is_end_of_sequence(pos)) {
-            // if alternate is nonempty, add to stack
-            stack.push_back(pos);
-        }
-        llama_grammar_advance_stack(vec_rules, stack, stacks);
-        while (!llama_grammar_is_end_of_sequence(pos)) {
-            // scan to end of alternate def
-            pos++;
-        }
-        if (pos->type == LLAMA_GRETYPE_ALT) {
-            // there's another alternate def of this rule to process
-            pos++;
-        } else {
-            break;
-        }
-    } while (true);
-
-    // Important: vec_rules has to be moved here, not copied, because stacks contains
-    // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
-    // then the pointers would be invalidated when the local vec_rules goes out of scope.
-    return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
-}
-
-void llama_grammar_free(struct llama_grammar * grammar) {
-    delete grammar;
-}
-
-struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
-    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
-
-    // redirect elements in stacks to point to new rules
-    for (size_t is = 0; is < result->stacks.size(); is++) {
-        for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
-            for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
-                for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
-                    if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
-                         result->stacks[is][ie]  =  &result->rules[ir0][ir1];
-                    }
-                }
-            }
-        }
-    }
-
-    return result;
-}
-
-//
-// sampling
-//
-
-void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
-    if (seed == LLAMA_DEFAULT_SEED) {
-        seed = time(NULL);
-    }
-    ctx->rng.seed(seed);
-}
-
-void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
-    GGML_ASSERT(candidates->size > 0);
-
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    // Sort the logits in descending order
-    if (!candidates->sorted) {
-        std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
-            return a.logit > b.logit;
-        });
-        candidates->sorted = true;
-    }
-
-    float max_l = candidates->data[0].logit;
-    float cum_sum = 0.0f;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        float p = expf(candidates->data[i].logit - max_l);
-        candidates->data[i].p = p;
-        cum_sum += p;
-    }
-    for (size_t i = 0; i < candidates->size; ++i) {
-        candidates->data[i].p /= cum_sum;
-    }
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-}
-
-void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
-    // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
-    // if (k >= (int32_t)candidates->size) {
-    //     return;
-    // }
-
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    if (k <= 0) {
-        k = candidates->size;
-    }
-
-    k = std::max(k, (int) min_keep);
-    k = std::min(k, (int) candidates->size);
-
-    // Sort scores in descending order
-    if (!candidates->sorted) {
-        auto comp = [](const llama_token_data & a, const llama_token_data & b) {
-            return a.logit > b.logit;
-        };
-        if (k <= 128) {
-            std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
-        } else {
-            constexpr int   nbuckets     = 128;
-            constexpr float bucket_low   = -10.0f;
-            constexpr float bucket_high  =  10.0f;
-            constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
-            constexpr float bucker_inter = -bucket_low * bucket_scale;
-
-            std::vector<int> bucket_idx(candidates->size);
-            std::vector<int> histo(nbuckets, 0);
-
-            for (int i = 0; i < (int)candidates->size; ++i) {
-                const float val = candidates->data[i].logit;
-                int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
-                ib = std::max(0, std::min(nbuckets-1, ib));
-                bucket_idx[i] = ib;
-                ++histo[ib];
-            }
-            int nhave = 0;
-            int ib = nbuckets - 1;
-            for ( ; ib >= 0; --ib) {
-                nhave += histo[ib];
-                if (nhave >= k) break;
-            }
-            std::vector<llama_token_data> tmp_tokens(nhave);
-            auto ptr = tmp_tokens.data();
-            std::vector<llama_token_data*> bucket_ptrs;
-            bucket_ptrs.reserve(nbuckets - ib);
-            for (int j = nbuckets - 1; j >= ib; --j) {
-                bucket_ptrs.push_back(ptr);
-                ptr += histo[j];
-            }
-            for (int i = 0; i < (int)candidates->size; ++i) {
-                int j = bucket_idx[i];
-                if (j >= ib) {
-                    *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
-                }
-            }
-
-            ptr = tmp_tokens.data();
-            int ndone = 0;
-            for (int j = nbuckets-1; j > ib; --j) {
-                std::sort(ptr, ptr + histo[j], comp);
-                ptr += histo[j];
-                ndone += histo[j];
-            }
-            std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
-
-            std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
-
-        }
-        candidates->sorted = true;
-    }
-    candidates->size = k;
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-}
-
-void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    if (p >= 1.0f) {
-        return;
-    }
-
-    llama_sample_softmax(ctx, candidates);
-
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    // Compute the cumulative probabilities
-    float cum_sum = 0.0f;
-    size_t last_idx = candidates->size;
-
-    for (size_t i = 0; i < candidates->size; ++i) {
-        cum_sum += candidates->data[i].p;
-
-        // Check if the running sum is at least p or if we have kept at least min_keep tokens
-        // we set the last index to i+1 to indicate that the current iterate should be included in the set
-        if (cum_sum >= p && i + 1 >= min_keep) {
-            last_idx = i + 1;
-            break;
-        }
-    }
-
-    // Resize the output vector to keep only the top-p tokens
-    candidates->size = last_idx;
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-}
-
-void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    if (p <= 0.0f || !candidates->size) {
-        return;
-    }
-
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    bool min_p_applied = false;
-
-    // if the candidates aren't sorted, try the unsorted implementation first
-    if (!candidates->sorted) {
-        std::vector<llama_token_data> filtered_tokens;
-
-        float max_logit = -FLT_MAX;
-        for (size_t i = 0; i < candidates->size; ++i) {
-            max_logit = std::max(max_logit, candidates->data[i].logit);
-        }
-        const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
-
-        for (size_t i = 0; i < candidates->size; ++i) {
-            if (candidates->data[i].logit >= min_logit) {
-                filtered_tokens.push_back(candidates->data[i]);
-            }
-        }
-
-        // if we have enough values the operation was a success
-        if (filtered_tokens.size() >= min_keep) {
-            memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
-            candidates->size = filtered_tokens.size();
-            min_p_applied = true;
-        }
-    }
-
-    // if the candidates are sorted or the unsorted implementation failed, use this implementation
-    if (!min_p_applied) {
-        // Sort the logits in descending order
-        if (!candidates->sorted) {
-            std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
-                return a.logit > b.logit;
-            });
-            candidates->sorted = true;
-        }
-
-        const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
-        size_t i = 1; // first token always matches
-
-        for (; i < candidates->size; ++i) {
-            if (candidates->data[i].logit < min_logit && i >= min_keep) {
-                break; // prob too small
-            }
-        }
-
-        // Resize the output vector to keep only the matching tokens
-        candidates->size = i;
-    }
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-}
-
-void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
-    if (z >= 1.0f || candidates->size <= 2) {
-        return;
-    }
-
-    llama_sample_softmax(nullptr, candidates);
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    // Compute the first and second derivatives
-    std::vector<float> first_derivatives(candidates->size - 1);
-    std::vector<float> second_derivatives(candidates->size - 2);
-
-    for (size_t i = 0; i < first_derivatives.size(); ++i) {
-        first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
-    }
-    for (size_t i = 0; i < second_derivatives.size(); ++i) {
-        second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
-    }
-
-    // Calculate absolute value of second derivatives
-    for (size_t i = 0; i < second_derivatives.size(); ++i) {
-        second_derivatives[i] = std::abs(second_derivatives[i]);
-    }
-
-    // Normalize the second derivatives
-    {
-        const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
-
-        if (second_derivatives_sum > 1e-6f) {
-            for (float & value : second_derivatives) {
-                value /= second_derivatives_sum;
-            }
-        } else {
-            for (float & value : second_derivatives) {
-                value = 1.0f / second_derivatives.size();
-            }
-        }
-    }
-
-    float cum_sum = 0.0f;
-    size_t last_idx = candidates->size;
-    for (size_t i = 0; i < second_derivatives.size(); ++i) {
-        cum_sum += second_derivatives[i];
-
-        // Check if the running sum is greater than z or if we have kept at least min_keep tokens
-        if (cum_sum > z && i >= min_keep) {
-            last_idx = i;
-            break;
-        }
-    }
-
-    // Resize the output vector to keep only the tokens above the tail location
-    candidates->size = last_idx;
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-}
-
-void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    // Reference implementation:
-    // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
-    if (p >= 1.0f) {
-        return;
-    }
-
-    // Compute the softmax of logits and calculate entropy
-    llama_sample_softmax(nullptr, candidates);
-
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    float entropy = 0.0f;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        entropy += -candidates->data[i].p * logf(candidates->data[i].p);
-    }
-
-    // Compute the absolute difference between negative log probability and entropy for each candidate
-    std::vector<float> shifted_scores;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
-        shifted_scores.push_back(shifted_score);
-    }
-
-    // Sort tokens based on the shifted_scores and their corresponding indices
-    std::vector<size_t> indices(candidates->size);
-    std::iota(indices.begin(), indices.end(), 0);
-
-    std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
-        return shifted_scores[a] < shifted_scores[b];
-    });
-
-    // Compute the cumulative probabilities
-    float cum_sum = 0.0f;
-    size_t last_idx = indices.size();
-
-    for (size_t i = 0; i < indices.size(); ++i) {
-        size_t idx = indices[i];
-        cum_sum += candidates->data[idx].p;
-
-        // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
-        if (cum_sum > p && i >= min_keep - 1) {
-            last_idx = i + 1;
-            break;
-        }
-    }
-
-    // Resize the output vector to keep only the locally typical tokens
-    std::vector<llama_token_data> new_candidates;
-    for (size_t i = 0; i < last_idx; ++i) {
-        size_t idx = indices[i];
-        new_candidates.push_back(candidates->data[idx]);
-    }
-
-    // Replace the data in candidates with the new_candidates data
-    std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
-    candidates->size = new_candidates.size();
-    candidates->sorted = false;
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-}
-
-void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    // no need to do anything if there is only one (or zero) candidates
-    if(candidates_p->size <= 1) {
-        return;
-    }
-
-    // Calculate maximum possible entropy
-    float max_entropy = -logf(1.0f / candidates_p->size);
-
-    llama_sample_softmax(nullptr, candidates_p);
-
-    // Calculate entropy of the softmax probabilities
-    float entropy = 0.0f;
-    for (size_t i = 0; i < candidates_p->size; ++i) {
-        float prob = candidates_p->data[i].p;
-        if (prob > 0.0f) { // Ensure no log(0)
-            entropy -= prob * logf(prob);
-        }
-    }
-
-    // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates_p->size != 1 above)
-    float normalized_entropy = entropy / max_entropy;
-
-    // Map the normalized entropy to the desired temperature range using the power function
-    float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
-
-#ifdef DEBUG
-    LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
-    LLAMA_LOG_INFO("Entropy: %f\n", entropy);
-    LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
-    LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
-    LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
-    LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
-#endif
-
-    // Apply the dynamically calculated temperature scaling
-    for (size_t i = 0; i < candidates_p->size; ++i) {
-        candidates_p->data[i].logit /= dyn_temp;
-    }
-
-    // Re-compute softmax probabilities after scaling logits with dynamic temperature
-    double max_l_double = candidates_p->data[0].logit;
-    double cum_sum_double = 0.0;
-    for (size_t i = 0; i < candidates_p->size; ++i) {
-        double p = exp(candidates_p->data[i].logit - max_l_double);
-        candidates_p->data[i].p = p; // Store the scaled probability
-        cum_sum_double += p;
-    }
-    for (size_t i = 0; i < candidates_p->size; ++i) {
-        candidates_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
-    }
-
-#ifdef DEBUG
-    // Print the updated top 25 probabilities after temperature scaling
-    LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
-    for (size_t i = 0; i < 25 && i < candidates_p->size; ++i) {
-        LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates_p->data[i].p * 100.0f);
-    }
-#endif
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-}
-
-void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    for (size_t i = 0; i < candidates_p->size; ++i) {
-        candidates_p->data[i].logit /= temp;
-    }
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-}
-
-void llama_sample_repetition_penalties(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-               const llama_token * last_tokens,
-                          size_t   penalty_last_n,
-                           float   penalty_repeat,
-                           float   penalty_freq,
-                           float   penalty_present) {
-    if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
-        return;
-    }
-
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    // Create a frequency map to count occurrences of each token in last_tokens
-    std::unordered_map<llama_token, int> token_count;
-    for (size_t i = 0; i < penalty_last_n; ++i) {
-        token_count[last_tokens[i]]++;
-    }
-
-    // Apply frequency and presence penalties to the candidates
-    for (size_t i = 0; i < candidates->size; ++i) {
-        const auto token_iter = token_count.find(candidates->data[i].id);
-        if (token_iter == token_count.end()) {
-            continue;
-        }
-
-        const int count = token_iter->second;
-
-        // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
-        // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
-        if (candidates->data[i].logit <= 0) {
-            candidates->data[i].logit *= penalty_repeat;
-        } else {
-            candidates->data[i].logit /= penalty_repeat;
-        }
-
-        candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
-    }
-
-    candidates->sorted = false;
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-}
-
-void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
-    GGML_ASSERT(ctx);
-    int64_t t_start_sample_us = ggml_time_us();
-
-    bool allow_eog = false;
-    for (const auto & stack : grammar->stacks) {
-        if (stack.empty()) {
-            allow_eog = true;
-            break;
-        }
-    }
-
-    std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
-    candidates_decoded.reserve(candidates->size);
-
-    std::vector<llama_grammar_candidate> candidates_grammar;
-    candidates_grammar.reserve(candidates->size);
-
-    for (size_t i = 0; i < candidates->size; ++i) {
-        const llama_token id      = candidates->data[i].id;
-        const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(id);
-
-        if (llama_token_is_eog(&ctx->model, id)) {
-            if (!allow_eog) {
-                candidates->data[i].logit = -INFINITY;
-            }
-        } else if (piece.empty() || piece[0] == 0) {
-            candidates->data[i].logit = -INFINITY;
-        } else {
-            candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
-            candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
-        }
-    }
-
-    const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
-    for (const auto & reject : rejects) {
-        candidates->data[reject.index].logit = -INFINITY;
-    }
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-}
-
-static void llama_log_softmax(float * array, size_t size) {
-    float max_l = *std::max_element(array, array + size);
-    float sum = 0.f;
-    for (size_t i = 0; i < size; ++i) {
-        float p = expf(array[i] - max_l);
-        sum += p;
-        array[i] = p;
-    }
-
-    for (size_t i = 0; i < size; ++i) {
-        array[i] = logf(array[i] / sum);
-    }
-}
-
-void llama_sample_apply_guidance(
-          struct llama_context * ctx,
-                         float * logits,
-                         float * logits_guidance,
-                         float   scale) {
-    GGML_ASSERT(ctx);
-
-    const auto t_start_sample_us = ggml_time_us();
-    const auto n_vocab = llama_n_vocab(llama_get_model(ctx));
-
-    llama_log_softmax(logits, n_vocab);
-    llama_log_softmax(logits_guidance, n_vocab);
-
-    for (int i = 0; i < n_vocab; ++i) {
-              auto & l = logits[i];
-        const auto & g = logits_guidance[i];
-
-        l = scale * (l - g) + g;
-    }
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-}
-
-llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
-    GGML_ASSERT(ctx);
-
-    auto N = float(llama_n_vocab(llama_get_model(ctx)));
-    int64_t t_start_sample_us;
-    t_start_sample_us = ggml_time_us();
-
-    llama_sample_softmax(nullptr, candidates);
-
-    // Estimate s_hat using the most probable m tokens
-    float s_hat = 0.0;
-    float sum_ti_bi = 0.0;
-    float sum_ti_sq = 0.0;
-    for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
-        float t_i = logf(float(i + 2) / float(i + 1));
-        float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
-        sum_ti_bi += t_i * b_i;
-        sum_ti_sq += t_i * t_i;
-    }
-    s_hat = sum_ti_bi / sum_ti_sq;
-
-    // Compute k from the estimated s_hat and target surprise value
-    float epsilon_hat = s_hat - 1;
-    float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat);
-
-    // Sample the next word X using top-k sampling
-    llama_sample_top_k(nullptr, candidates, int(k), 1);
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    llama_token X = llama_sample_token(ctx, candidates);
-    t_start_sample_us = ggml_time_us();
-
-    // Compute error as the difference between observed surprise and target surprise value
-    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
-        return candidate.id == X;
-    }));
-    float observed_surprise = -log2f(candidates->data[X_idx].p);
-    float e = observed_surprise - tau;
-
-    // Update mu using the learning rate and error
-    *mu = *mu - eta * e;
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    return X;
-}
-
-llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
-    int64_t t_start_sample_us;
-    t_start_sample_us = ggml_time_us();
-
-    llama_sample_softmax(ctx, candidates);
-
-    // Truncate the words with surprise values greater than mu
-    candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
-        return -log2f(candidate.p) > *mu;
-    }));
-
-    if (candidates->size == 0) {
-        candidates->size = 1;
-    }
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-
-    // Normalize the probabilities of the remaining words
-    llama_sample_softmax(ctx, candidates);
-
-    // Sample the next word X from the remaining words
-    llama_token X = llama_sample_token(ctx, candidates);
-    t_start_sample_us = ggml_time_us();
-
-    // Compute error as the difference between observed surprise and target surprise value
-    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
-        return candidate.id == X;
-    }));
-    float observed_surprise = -log2f(candidates->data[X_idx].p);
-    float e = observed_surprise - tau;
-
-    // Update mu using the learning rate and error
-    *mu = *mu - eta * e;
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-    return X;
-}
-
-llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    // Find max element
-    auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
-        return a.logit < b.logit;
-    });
-
-    llama_token result = max_iter->id;
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-        ctx->n_sample++;
-    }
-    return result;
-}
-
-llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
-    GGML_ASSERT(ctx);
-
-    const int64_t t_start_sample_us = ggml_time_us();
-    llama_sample_softmax(nullptr, candidates);
-
-    std::vector<float> probs;
-    probs.reserve(candidates->size);
-    for (size_t i = 0; i < candidates->size; ++i) {
-        probs.push_back(candidates->data[i].p);
-    }
-
-    std::discrete_distribution<> dist(probs.begin(), probs.end());
-    int idx = dist(rng);
-
-    llama_token result = candidates->data[idx].id;
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    ctx->n_sample++;
-    return result;
-}
-
-llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
-    return llama_sample_token_with_rng(ctx, candidates, ctx->rng);
-}
-
-void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    if (llama_token_is_eog(&ctx->model, token)) {
-        for (const auto & stack : grammar->stacks) {
-            if (stack.empty()) {
-                return;
-            }
-        }
-        GGML_ASSERT(false);
-    }
-
-    const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(token);
-
-    // Note terminating 0 in decoded string
-    const auto   decoded     = decode_utf8(piece, grammar->partial_utf8);
-    const auto & code_points = decoded.first;
-    std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
-    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
-        grammar->stacks = tmp_new_stacks;
-    }
-    grammar->partial_utf8 = decoded.second;
-    GGML_ASSERT(!grammar->stacks.empty());
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-}
-
-//
-// quantization
-//
-
-struct quantize_state_internal {
-    const llama_model                 & model;
-    const llama_model_quantize_params * params;
-
-    int n_attention_wv    = 0;
-    int n_ffn_down        = 0;
-    int n_ffn_gate        = 0;
-    int n_ffn_up          = 0;
-    int i_attention_wv    = 0;
-    int i_ffn_down        = 0;
-    int i_ffn_gate        = 0;
-    int i_ffn_up          = 0;
-
-    int n_k_quantized     = 0;
-    int n_fallback        = 0;
-
-    bool has_imatrix      = false;
-
-    // used to figure out if a model shares tok_embd with the output weight
-    bool has_output       = false;
-
-    quantize_state_internal(const llama_model & model, const llama_model_quantize_params * params)
-        : model(model)
-        , params(params)
-        {}
-};
+    quantize_state_internal(const llama_model & model, const llama_model_quantize_params * params)
+        : model(model)
+        , params(params)
+        {}
+};
 
 static void llama_tensor_dequantize_internal(
     struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
@@ -17655,7 +15246,7 @@ static void llama_tensor_dequantize_internal(
         } else if (ggml_is_quantized(tensor->type)) {
             qtype.to_float(tensor->data, f32_output, nelements);
         } else {
-            GGML_ASSERT(false); // unreachable
+            GGML_ABORT("fatal error"); // unreachable
         }
         return;
     }
@@ -17760,6 +15351,10 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
             else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
                 new_type = GGML_TYPE_IQ3_S;
             }
+            else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 ||
+                     new_type == GGML_TYPE_Q4_0_8_8) {
+                new_type = GGML_TYPE_Q4_0;
+            }
         }
     } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
                ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M    || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
@@ -17943,10 +15538,10 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
     //    if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
     //}
     bool convert_incompatible_tensor = false;
-    if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K ||
-        new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ4_XS ||
-        new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S ||
-        new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S ||
+    if (new_type == GGML_TYPE_Q2_K    || new_type == GGML_TYPE_Q3_K    || new_type == GGML_TYPE_Q4_K   ||
+        new_type == GGML_TYPE_Q5_K    || new_type == GGML_TYPE_Q6_K    || new_type == GGML_TYPE_IQ4_XS ||
+        new_type == GGML_TYPE_IQ2_XS  || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S  ||
+        new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S   || new_type == GGML_TYPE_IQ3_S  ||
         new_type == GGML_TYPE_IQ1_M) {
         int nx = tensor->ne[0];
         int ny = tensor->ne[1];
@@ -18072,6 +15667,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_IQ4_XS:  default_type = GGML_TYPE_IQ4_XS;  break;
         case LLAMA_FTYPE_MOSTLY_IQ3_S:   default_type = GGML_TYPE_IQ3_S;   break;
         case LLAMA_FTYPE_MOSTLY_IQ3_M:   default_type = GGML_TYPE_IQ3_S;   break;
+        case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: default_type = GGML_TYPE_Q4_0_4_4; break;
+        case LLAMA_FTYPE_MOSTLY_Q4_0_4_8: default_type = GGML_TYPE_Q4_0_4_8; break;
+        case LLAMA_FTYPE_MOSTLY_Q4_0_8_8: default_type = GGML_TYPE_Q4_0_8_8; break;
 
         default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
     }
@@ -18129,8 +15727,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
     // copy the KV pairs from the input file
     gguf_set_kv     (ctx_out, ml.meta);
-    gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION);
-    gguf_set_val_u32(ctx_out, "general.file_type", ftype);
+    gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
+    gguf_set_val_u32(ctx_out, "general.file_type", ftype); // TODO: use LLM_KV
+
     // Remove split metadata
     gguf_remove_key(ctx_out, ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
     gguf_remove_key(ctx_out, ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str());
@@ -18382,6 +15981,14 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                 f32_data = (float *) f32_conv_buf.data();
             }
 
+            int chunk_size_multiplier = 1;
+            if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 || new_type == GGML_TYPE_Q4_0_8_8) {
+                if ((new_type == GGML_TYPE_Q4_0_8_8) && (tensor->ne[1] % 8 != 0)) new_type = GGML_TYPE_Q4_0;
+                else if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_0;
+                if (new_type == GGML_TYPE_Q4_0_8_8) chunk_size_multiplier = 8;
+                else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8) chunk_size_multiplier = 4;
+            }
+
             LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
             fflush(stdout);
 
@@ -18394,7 +16001,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             const int64_t nrows = tensor->ne[1];
 
             static const int64_t min_chunk_size = 32 * 512;
-            const int64_t chunk_size = n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row);
+            const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)) *
+                                       chunk_size_multiplier;
 
             const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1];
             const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
@@ -18436,282 +16044,214 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     }
 }
 
-static int llama_apply_lora_from_file_internal(
-    const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads
-) {
-    LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
-
-    const int64_t t_start_lora_us = ggml_time_us();
-
-    llama_file fin(path_lora, "rb");
-
-    // verify magic and version
-    {
-        uint32_t magic = fin.read_u32();
-        if (magic != LLAMA_FILE_MAGIC_GGLA) {
-            LLAMA_LOG_ERROR("%s: bad file magic\n", __func__);
-            return 1;
-        }
-
-        uint32_t format_version = fin.read_u32();
-        if (format_version != 1) {
-            LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ );
-            return 1;
-        }
-    }
-
-    int32_t lora_r = fin.read_u32();
-    int32_t lora_alpha = fin.read_u32();
-    float scaling = scale * (float)lora_alpha / (float)lora_r;
-
-    LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
+static void llama_lora_adapter_init_internal(struct llama_model * model, const char * path_lora, struct llama_lora_adapter & adapter) {
+    LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
 
-    // load base model
-    std::unique_ptr<llama_model_loader> ml;
-    if (path_base_model) {
-        LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
-        ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr));
-        ml->init_mappings(/*prefetch*/ false); // no prefetching
-    }
-
-    struct tensor_meta {
-        std::string name;
-        ggml_type type;
-        int32_t ne[2];
-        size_t offset;
+    ggml_context * ctx = nullptr;
+    struct gguf_init_params meta_gguf_params = {
+        /* .no_alloc = */ true,
+        /* .ctx      = */ &ctx,
     };
-    std::map<std::string, tensor_meta> tensor_meta_map;
-
-    // load all tensor meta
-    while (true) {
-        if (fin.tell() == fin.size) {
-            // eof
-            break;
-        }
-
-        int32_t n_dims;
-        int32_t name_len;
-        int32_t ftype;
-
-        fin.read_raw(&n_dims, sizeof(n_dims));
-        fin.read_raw(&name_len, sizeof(name_len));
-        fin.read_raw(&ftype, sizeof(ftype));
-
-        if (n_dims != 1 && n_dims != 2) {
-            LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims);
-            return 1;
-        }
+    struct gguf_context * ctx_gguf = gguf_init_from_file(path_lora, meta_gguf_params);
+    if (!ctx_gguf) {
+        throw std::runtime_error("failed to load lora adapter file from " + std::string(path_lora));
+    }
 
-        int32_t ne[2] = { 1, 1 };
-        for (int i = 0; i < n_dims; ++i) {
-            fin.read_raw(&ne[i], sizeof(ne[i]));
-        }
+    // check metadata
+    {
+        auto get_kv_str = [&](const std::string & key) -> std::string {
+            int id = gguf_find_key(ctx_gguf, key.c_str());
+            return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf, id));
+        };
+        auto get_kv_f32 = [&](const std::string & key) -> float {
+            int id = gguf_find_key(ctx_gguf, key.c_str());
+            return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf, id);
+        };
+        LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
 
-        std::string name;
-        {
-            GGML_ASSERT(name_len < GGML_MAX_NAME);
-            char buf[GGML_MAX_NAME];
-            fin.read_raw(buf, name_len);
-            name = std::string(buf, name_len);
+        auto general_type = get_kv_str(llm_kv(LLM_KV_GENERAL_TYPE));
+        if (general_type != "adapter") {
+            gguf_free(ctx_gguf);
+            throw std::runtime_error("expect general.type to be 'adapter', but got: " + general_type);
         }
 
-        // check for lora suffix
-        std::string lora_suffix;
-        if (name.length() > 6) {
-            lora_suffix = name.substr(name.length() - 6);
-        }
-        if (lora_suffix != ".loraA" && lora_suffix != ".loraB") {
-            LLAMA_LOG_ERROR("%s: error: '%s' is not a lora tensor\n", __func__, name.c_str());
-            return 1;
+        auto general_arch_str = get_kv_str(llm_kv(LLM_KV_GENERAL_ARCHITECTURE));
+        auto general_arch = llm_arch_from_string(general_arch_str);
+        if (general_arch != model->arch) {
+            gguf_free(ctx_gguf);
+            throw std::runtime_error("model arch and LoRA arch mismatch");
         }
 
-        // tensor type
-        ggml_type wtype;
-        switch (ftype) {
-            case 0: wtype = GGML_TYPE_F32;  break;
-            case 1: wtype = GGML_TYPE_F16;  break;
-            default:
-                    {
-                        LLAMA_LOG_ERROR("%s: invalid tensor data type '%d'\n",
-                                __func__, ftype);
-                        return 1;
-                    }
+        auto adapter_type = get_kv_str(llm_kv(LLM_KV_ADAPTER_TYPE));
+        if (adapter_type != "lora") {
+            gguf_free(ctx_gguf);
+            throw std::runtime_error("expect adapter.type to be 'lora', but got: " + adapter_type);
         }
 
-        // data offset
-        size_t offset = fin.tell();
-        offset = (offset + 31) & -32;
-
-        // skip tensor data
-        fin.seek(offset + ggml_row_size(wtype, ne[0]) * ne[1], SEEK_SET);
-
-        tensor_meta_map.emplace(name, tensor_meta{ name, wtype, { ne[0], ne[1] }, offset });
-    }
-
-    bool warned = false;
-    int n_tensors = 0;
-
-    // apply
-    ggml_backend_t backend_cpu = ggml_backend_cpu_init();
-    if (backend_cpu == nullptr) {
-        LLAMA_LOG_ERROR("%s: error: failed to initialize cpu backend\n", __func__);
-        return 1;
+        adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA));
     }
-    ggml_backend_cpu_set_n_threads(backend_cpu, n_threads);
-
-    std::vector<no_init<uint8_t>> read_buf;
-    for (const auto & it : model.tensors_by_name) {
-        const std::string & base_name = it.first;
-        ggml_tensor * model_t = it.second;
-
-        if (tensor_meta_map.find(base_name + ".loraA") == tensor_meta_map.end() ||
-            tensor_meta_map.find(base_name + ".loraB") == tensor_meta_map.end()) {
-            continue;
-        }
 
-        tensor_meta & metaA = tensor_meta_map.at(base_name + ".loraA");
-        tensor_meta & metaB = tensor_meta_map.at(base_name + ".loraB");
+    int n_tensors = gguf_get_n_tensors(ctx_gguf);
 
-        ggml_init_params lora_init_params = {
-            /* .mem_size   */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
-            /* .mem_buffer */ nullptr,
-            /* .no_alloc   */ true,
+    // contexts for each buffer type
+    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+    auto get_ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
+        auto it = ctx_map.find(buft);
+        if (it == ctx_map.end()) {
+            // add a new context
+            struct ggml_init_params params = {
+                /*.mem_size   =*/ n_tensors*ggml_tensor_overhead(),
+                /*.mem_buffer =*/ NULL,
+                /*.no_alloc   =*/ true,
+            };
+            ggml_context * buft_ctx = ggml_init(params);
+            ctx_map[buft] = buft_ctx;
+            return buft_ctx;
         };
-        ggml_context * lora_ctx = ggml_init(lora_init_params);
-        if (lora_ctx == nullptr) {
-            LLAMA_LOG_ERROR("%s: error: failed to initialize lora context\n", __func__);
-            ggml_backend_free(backend_cpu);
-            return 1;
-        }
-
-        // create tensors
-        ggml_tensor * loraA = ggml_new_tensor_2d(lora_ctx, metaA.type, metaA.ne[0], metaA.ne[1]);
-        ggml_tensor * loraB = ggml_new_tensor_2d(lora_ctx, metaB.type, metaB.ne[0], metaB.ne[1]);
-        ggml_set_name(loraA, metaA.name.c_str());
-        ggml_set_name(loraB, metaB.name.c_str());
+        return it->second;
+    };
 
-        ggml_tensor * base_t;
-        if (ml) {
-            if (!ml->get_tensor_meta(base_name.c_str())) {
-                LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
-                return 1;
+    // bundle lora_a and lora_b into pairs
+    std::map<std::string, llama_lora_weight> ab_map;
+    auto str_endswith = [](const std::string & str, const std::string & suffix) {
+        return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
+    };
+    for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
+        std::string name(cur->name);
+        if (str_endswith(name, ".lora_a")) {
+            replace_all(name, ".lora_a", "");
+            if (ab_map.find(name) == ab_map.end()) {
+                ab_map[name] = llama_lora_weight(cur, nullptr);
+            } else {
+                ab_map[name].a = cur;
+            }
+        } else if (str_endswith(name, ".lora_b")) {
+            replace_all(name, ".lora_b", "");
+            if (ab_map.find(name) == ab_map.end()) {
+                ab_map[name] = llama_lora_weight(nullptr, cur);
+            } else {
+                ab_map[name].b = cur;
             }
-            base_t = ggml_dup_tensor(lora_ctx, ml->get_tensor_meta(base_name.c_str()));
         } else {
-            base_t = ggml_dup_tensor(lora_ctx, model_t);
-        }
-        ggml_set_name(base_t, base_name.c_str());
-
-        // allocate in backend buffer
-        ggml_backend_buffer_t lora_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type());
-        if (lora_buf == nullptr) {
-            LLAMA_LOG_ERROR("%s: error: failed to allocate lora tensors\n", __func__);
-            return 1;
+            gguf_free(ctx_gguf);
+            ggml_free(ctx);
+            throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix");
         }
+    }
 
-        // load tensor data
-        auto load_tensor = [&read_buf, &fin](const tensor_meta & tensor_meta, ggml_tensor * tensor) {
-            read_buf.resize(ggml_nbytes(tensor));
-            fin.seek(tensor_meta.offset, SEEK_SET);
-            fin.read_raw(read_buf.data(), ggml_nbytes(tensor));
-            ggml_backend_tensor_set(tensor, read_buf.data(), 0, read_buf.size());
-        };
-        load_tensor(metaA, loraA);
-        load_tensor(metaB, loraB);
+    // add tensors
+    for (auto & it : ab_map) {
+        const std::string & name = it.first;
+        llama_lora_weight & w = it.second;
 
-        // load base model tensor data
-        if (ml) {
-            ml->load_data_for(base_t);
-        } else {
-            ggml_backend_tensor_copy(model_t, base_t);
+        if (!w.a || !w.b) {
+            gguf_free(ctx_gguf);
+            ggml_free(ctx);
+            throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component");
         }
 
-        if (ggml_is_quantized(base_t->type) && !warned) {
-            LLAMA_LOG_WARN("%s: warning: using a lora adapter with a quantized model may result in poor quality, "
-                            "use a f16 or f32 base model with --lora-base\n", __func__);
-            warned = true;
+        // device buft and device ctx
+        auto * model_tensor = llama_get_model_tensor(model, name.c_str());
+        if (!model_tensor) {
+            gguf_free(ctx_gguf);
+            ggml_free(ctx);
+            throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model");
         }
-
-        if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) {
-            LLAMA_LOG_ERROR("%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
-                            " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]);
-            ggml_free(lora_ctx);
-            ggml_backend_buffer_free(lora_buf);
-            ggml_backend_free(backend_cpu);
-            return 1;
+        struct ggml_context * dev_ctx = get_ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
+        // validate tensor shape
+        if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) {
+            gguf_free(ctx_gguf);
+            ggml_free(ctx);
+            throw std::runtime_error("tensor '" + name + "' has incorrect shape");
         }
+        if (w.a->ne[1] != w.b->ne[0]) {
+            gguf_free(ctx_gguf);
+            ggml_free(ctx);
+            throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)");
+        }
+        // save tensor to adapter
+        struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
+        struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
+        ggml_set_name(tensor_a, w.a->name);
+        ggml_set_name(tensor_b, w.b->name);
+        adapter.ab_map[name] = llama_lora_weight(tensor_a, tensor_b);
+    }
 
-        auto build_lora_graph = [&]() {
-            // w = w + BA*s
-            ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB);
-            ggml_set_name(BA, "BA");
-
-            if (scaling != 1.0f) {
-                BA = ggml_scale(lora_ctx, BA, scaling);
-                ggml_set_name(BA, "BA_scaled");
-            }
-
-            ggml_tensor * r;
-            r = ggml_add_inplace(lora_ctx, base_t, BA);
-            ggml_set_name(r, "r_add");
-
-            if (base_t->type != model_t->type) {
-                // convert the result to the model type
-                r = ggml_cast(lora_ctx, r, model_t->type);
-                ggml_set_name(r, "r_cast");
+    // allocate tensors / buffers and zero
+    {
+        adapter.ctxs.reserve(ctx_map.size());
+        adapter.bufs.reserve(ctx_map.size());
+        for (auto it : ctx_map) {
+            ggml_backend_buffer_type_t buft = it.first;
+            ggml_context * ctx_dev = it.second;
+            ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_dev, buft);
+            if (!buf) {
+                gguf_free(ctx_gguf);
+                ggml_free(ctx);
+                throw std::runtime_error("failed to allocate buffer for lora adapter\n");
             }
+            LLAMA_LOG_INFO("%s: %10s LoRA buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
+            adapter.ctxs.push_back(ctx_dev);
+            adapter.bufs.push_back(buf);
+        }
+    }
 
-            return r;
+    // set tensor data
+    {
+        llama_file gguf_file(path_lora, "rb");
+        std::vector<uint8_t> read_buf;
+        auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) {
+            size_t offs = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, gguf_find_tensor(ctx_gguf, orig->name));
+            size_t size = ggml_nbytes(orig);
+            read_buf.resize(size);
+            gguf_file.seek(offs, SEEK_SET);
+            gguf_file.read_raw(read_buf.data(), size);
+            ggml_backend_tensor_set(dev, read_buf.data(), 0, size);
         };
-
-        ggml_cgraph * gf = ggml_new_graph(lora_ctx);
-        ggml_tensor * r = build_lora_graph();
-        ggml_build_forward_expand(gf, r);
-
-        ggml_backend_buffer_t graph_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type());
-        if (graph_buf == nullptr) {
-            LLAMA_LOG_ERROR("%s: error: failed to allocate graph tensors\n", __func__);
-            ggml_free(lora_ctx);
-            ggml_backend_buffer_free(lora_buf);
-            ggml_backend_free(backend_cpu);
-            return 1;
+        for (auto & it : adapter.ab_map) {
+            auto orig = ab_map[it.first];
+            auto dev  = it.second;
+            set_tensor(orig.a, dev.a);
+            set_tensor(orig.b, dev.b);
         }
+    }
 
-        ggml_backend_graph_compute(backend_cpu, gf);
-
-        ggml_backend_tensor_set(model_t, r->data, 0, ggml_nbytes(r));
-
-#if 0
-        // TODO: use scheduler with fallback to CPU for less copies between CPU and GPU
-        //ggml_backend_sched_t sched = ggml_backend_sched_new(backends.data(), backends.size(), GGML_DEFAULT_GRAPH_SIZE);
-
-        // sched compute
-        ggml_build_forward_expand(gf, build_graph());
-        ggml_backend_sched_init_measure(sched, gf);
-
-        // create the graph again, since the previous one was destroyed by the measure
-        ggml_graph_clear(gf);
-        ggml_build_forward_expand(gf, build_graph());
-        ggml_backend_sched_graph_compute(sched, gf);
-        ggml_backend_sched_free(sched);
-#endif
+    LLAMA_LOG_INFO("%s: loaded %ld tensors from lora file\n", __func__, adapter.ab_map.size()*2);
 
-        ggml_backend_buffer_free(lora_buf);
-        ggml_backend_buffer_free(graph_buf);
-        ggml_free(lora_ctx);
+    // free ctx for reading gguf
+    gguf_free(ctx_gguf);
+    ggml_free(ctx);
+}
 
-        n_tensors++;
-        if (n_tensors % 4 == 0) {
-            LLAMA_LOG_INFO(".");
-        }
+int32_t llama_lora_adapter_set(
+            struct llama_context * ctx,
+            struct llama_lora_adapter * adapter,
+            float scale) {
+    if (ctx->cparams.flash_attn) {
+        LLAMA_LOG_ERROR("%s: flash_attn is not compatible with LoRA\n", __func__);
+        return -1;
     }
+    ctx->lora_adapters[adapter] = scale;
+    return 0;
+}
 
-    ggml_backend_free(backend_cpu);
+int32_t llama_lora_adapter_remove(
+            struct llama_context * ctx,
+            struct llama_lora_adapter * adapter) {
+    auto pos = ctx->lora_adapters.find(adapter);
+    if (pos != ctx->lora_adapters.end()) {
+        ctx->lora_adapters.erase(pos);
+        return 0;
+    }
+    return -1;
+}
 
-    const int64_t t_lora_us = ggml_time_us() - t_start_lora_us;
-    LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0);
+void llama_lora_adapter_clear(struct llama_context * ctx) {
+    ctx->lora_adapters.clear();
+}
 
-    return 0;
+void llama_lora_adapter_free(struct llama_lora_adapter * adapter) {
+    delete adapter;
 }
 
 //
@@ -18805,6 +16345,8 @@ size_t llama_max_devices(void) {
     return GGML_SYCL_MAX_DEVICES;
 #elif defined(GGML_USE_VULKAN)
     return GGML_VK_MAX_DEVICES;
+#elif defined(GGML_USE_CANN)
+    return GGML_CANN_MAX_DEVICES;
 #else
     return 1;
 #endif
@@ -19034,8 +16576,8 @@ struct llama_context * llama_new_context_with_model(
     ctx->abort_callback      = params.abort_callback;
     ctx->abort_callback_data = params.abort_callback_data;
 
-    ctx->rng                 = std::mt19937(params.seed);
-    ctx->logits_all          = params.logits_all;
+    ctx->sampling.rng = std::mt19937(params.seed);
+    ctx->logits_all   = params.logits_all;
 
     uint32_t kv_size = cparams.n_ctx;
     ggml_type type_k = params.type_k;
@@ -19127,9 +16669,7 @@ struct llama_context * llama_new_context_with_model(
             for (int i = 0; i < ggml_backend_sycl_get_device_count(); ++i) {
                 ggml_backend_t backend = ggml_backend_sycl_init(i);
                 if (backend == nullptr) {
-                    int id_list[GGML_SYCL_MAX_DEVICES];
-                    ggml_sycl_get_gpu_list(id_list, GGML_SYCL_MAX_DEVICES);
-                    LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d (index %d) backend\n", __func__, id_list[i], i);
+                    LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d for No.%d backend\n", __func__, i, i);
                     llama_free(ctx);
                     return nullptr;
                 }
@@ -19146,6 +16686,30 @@ struct llama_context * llama_new_context_with_model(
             }
             ctx->backends.push_back(backend);
         }
+#elif defined(GGML_USE_CANN)
+    // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
+    // TODO: ggml_backend_cann is not support split tensor now, just leave code here.
+    if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
+        ggml_backend_t backend = ggml_backend_cann_init(model->main_gpu);
+        if (backend == nullptr) {
+            LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, model->main_gpu);
+            llama_free(ctx);
+            return nullptr;
+        }
+        ctx->backends.push_back(backend);
+    } else {
+        // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
+        // TODO: currently, CANN can't use multi-gpus, just leave code here for further cann version.
+        for (int32_t device = 0; device < ggml_backend_cann_get_device_count(); ++device) {
+            ggml_backend_t backend = ggml_backend_cann_init(device);
+            if (backend == nullptr) {
+                LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, device);
+                llama_free(ctx);
+                return nullptr;
+            }
+            ctx->backends.push_back(backend);
+        }
+    }
 #endif
 
 #ifdef GGML_USE_BLAS
@@ -19229,8 +16793,10 @@ struct llama_context * llama_new_context_with_model(
                 }
             }
 
+            const size_t max_nodes = llama_model_max_nodes(*model);
+
             // buffer used to store the computation graph and the tensor meta data
-            ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead_custom(LLAMA_MAX_NODES, false));
+            ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
 
             // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
             bool pipeline_parallel =
@@ -19243,7 +16809,7 @@ struct llama_context * llama_new_context_with_model(
             // currently this is only implemented in the CUDA backend
             pipeline_parallel = false;
 #endif
-            ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES, pipeline_parallel);
+            ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), max_nodes, pipeline_parallel);
 
             if (pipeline_parallel) {
                 LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched));
@@ -19287,10 +16853,14 @@ void llama_free(struct llama_context * ctx) {
     delete ctx;
 }
 
-const llama_model * llama_get_model(const struct llama_context * ctx) {
+const struct llama_model * llama_get_model(const struct llama_context * ctx) {
     return &ctx->model;
 }
 
+const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx) {
+    return &ctx->model.vocab;
+}
+
 uint32_t llama_n_ctx(const struct llama_context * ctx) {
     return ctx->cparams.n_ctx;
 }
@@ -19330,7 +16900,6 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_BAICHUAN:
         case LLM_ARCH_STARCODER:
         case LLM_ARCH_PLAMO:
-        case LLM_ARCH_CODESHELL:
         case LLM_ARCH_ORION:
         case LLM_ARCH_INTERNLM2:
         case LLM_ARCH_MINICPM:
@@ -19360,12 +16929,12 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_STARCODER2:
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
+        case LLM_ARCH_CODESHELL:
             return LLAMA_ROPE_TYPE_NEOX;
 
         // all model arches should be listed explicitly here
         case LLM_ARCH_UNKNOWN:
-            GGML_ASSERT(false && "unknown architecture");
-            break;
+            GGML_ABORT("unknown architecture");
     }
 
     return LLAMA_ROPE_TYPE_NONE;
@@ -19492,12 +17061,14 @@ uint32_t llama_model_quantize(
     }
 }
 
-int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, float scale, const char * path_base_model, int32_t n_threads) {
+struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) {
     try {
-        return llama_apply_lora_from_file_internal(*model, path_lora, scale, path_base_model, n_threads);
+        struct llama_lora_adapter * adapter = new llama_lora_adapter(model);
+        llama_lora_adapter_init_internal(model, path_lora, *adapter);
+        return adapter;
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
-        return 1;
+        return nullptr;
     }
 }
 
@@ -19745,18 +17316,18 @@ void llama_kv_cache_update(struct llama_context * ctx) {
 }
 
 // deprecated
-size_t llama_get_state_size(const struct llama_context * ctx) {
+size_t llama_get_state_size(struct llama_context * ctx) {
     return llama_state_get_size(ctx);
 }
 
 // deprecated
 size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
-    return llama_state_get_data(ctx, dst);
+    return llama_state_get_data(ctx, dst, -1);
 }
 
 // deprecated
 size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
-    return llama_state_set_data(ctx, src);
+    return llama_state_set_data(ctx, src, -1);
 }
 
 // deprecated
@@ -19769,302 +17340,284 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
     return llama_state_save_file(ctx, path_session, tokens, n_token_count);
 }
 
-// Returns the *maximum* size of the state
-size_t llama_state_get_size(const struct llama_context * ctx) {
-    const auto & cparams = ctx->cparams;
-    const auto & hparams = ctx->model.hparams;
-
-    // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
-    // for reference, std::mt19937(1337) serializes to 6701 bytes.
-    const size_t s_rng_size        = sizeof(size_t);
-    const size_t s_rng             = LLAMA_MAX_RNG_STATE;
-    const size_t s_n_outputs       = sizeof(size_t);
-    // assume worst case for outputs although only currently set ones are serialized
-    const size_t s_output_pos      = ctx->cparams.n_batch * sizeof(int32_t);
-    const size_t s_logits_size     = sizeof(size_t);
-    const size_t s_logits          = ctx->logits_size ? cparams.n_batch * hparams.n_vocab * sizeof(float) : 0;
-    const size_t s_embedding_size  = sizeof(size_t);
-    const size_t s_embedding       = ctx->embd_size   ? cparams.n_batch * hparams.n_embd  * sizeof(float) : 0;
-    const size_t s_kv_buf_size     = sizeof(size_t);
-    const size_t s_kv_head         = sizeof(uint32_t);
-    const size_t s_kv_size         = sizeof(uint32_t);
-    const size_t s_kv_used         = sizeof(uint32_t);
-    const size_t s_v_trans         = sizeof(uint32_t);
-    const size_t s_kv              = ctx->kv_self.total_size();
-    const size_t s_kv_cell         = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id);
-    const size_t s_kv_cells        = ctx->kv_self.size * s_kv_cell;
-
-    const size_t s_total = (
-        + s_rng_size
-        + s_rng
-        + s_n_outputs
-        + s_output_pos
-        + s_logits_size
-        + s_logits
-        + s_embedding_size
-        + s_embedding
-        + s_kv_buf_size
-        + s_kv_head
-        + s_kv_size
-        + s_kv_used
-        + s_v_trans
-        + s_kv
-        + s_kv_cells
-    );
-
-    // on session change it is very likely that the state size has changed - so we need to update this function
-    static_assert(LLAMA_SESSION_VERSION == 6, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?");
-
-    return s_total;
-}
-
-// llama_context_data
-struct llama_data_context {
+// TODO: replace all non-fatal assertions with returned errors or exceptions
+struct llama_data_write {
     virtual void write(const void * src, size_t size) = 0;
     virtual size_t get_size_written() = 0;
-    virtual ~llama_data_context() = default;
-};
-
-struct llama_data_buffer_context : llama_data_context {
-    uint8_t * ptr;
-    size_t size_written = 0;
+    virtual ~llama_data_write() = default;
 
-    llama_data_buffer_context(uint8_t * p) : ptr(p) {}
+    void write_string(const std::string & str) {
+        uint32_t str_size = str.size();
 
-    void write(const void * src, size_t size) override {
-        memcpy(ptr, src, size);
-        ptr += size;
-        size_written += size;
+        write(&str_size,  sizeof(str_size));
+        write(str.data(), str_size);
     }
 
-    size_t get_size_written() override {
-        return size_written;
+    void write_model_info(const struct llama_context * ctx) {
+        std::string arch_str = LLM_ARCH_NAMES.at(ctx->model.arch);
+        write_string(arch_str);
+        // TODO: add more model-specific info which should prevent loading the session file if not identical
     }
-};
 
-struct llama_data_file_context : llama_data_context {
-    llama_file * file;
-    size_t size_written = 0;
+    void write_rng(const std::mt19937 & rng) {
+        std::ostringstream rng_ss;
+        rng_ss << rng;
 
-    llama_data_file_context(llama_file * f) : file(f) {}
+        const std::string & rng_str = rng_ss.str();
 
-    void write(const void * src, size_t size) override {
-        file->write_raw(src, size);
-        size_written += size;
+        write_string(rng_str);
     }
 
-    size_t get_size_written() override {
-        return size_written;
-    }
-};
+    void write_output_ids(const struct llama_context * ctx) {
+        const uint32_t n_outputs = ctx->n_outputs;
 
-/** copy state data into either a buffer or file depending on the passed in context
- *
- * file context:
- * llama_file file("/path", "wb");
- * llama_data_file_context data_ctx(&file);
- * llama_state_get_data(ctx, &data_ctx);
- *
- * buffer context:
- * std::vector<uint8_t> buf(max_size, 0);
- * llama_data_buffer_context data_ctx(&buf.data());
- * llama_state_get_data(ctx, &data_ctx);
- *
-*/
-static void llama_state_get_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
-    llama_synchronize(ctx);
+        std::vector<int32_t> output_pos;
 
-    // copy rng
-    {
-        std::ostringstream rng_ss;
-        rng_ss << ctx->rng;
+        const size_t    n_batch = ctx->cparams.n_batch;
+        const auto & output_ids = ctx->output_ids;
+
+        GGML_ASSERT(n_outputs <= ctx->output_size);
 
-        const std::string & rng_str  = rng_ss.str();
-        const size_t        rng_size = rng_str.size();
+        output_pos.resize(n_outputs);
 
-        GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
+        // build a more compact representation of the output ids
+        for (size_t i = 0; i < n_batch; ++i) {
+            // map an output id to a position in the batch
+            int32_t pos = output_ids[i];
+            if (pos >= 0) {
+                GGML_ASSERT((uint32_t) pos < n_outputs);
+                output_pos[pos] = i;
+            }
+        }
 
-        data_ctx->write(&rng_size,      sizeof(rng_size));
-        data_ctx->write(rng_str.data(), rng_size);
+        write(&n_outputs, sizeof(n_outputs));
+
+        if (n_outputs) {
+            write(output_pos.data(), n_outputs * sizeof(int32_t));
+        }
     }
 
-    // copy outputs
-    {
-        // Can't use ctx->n_outputs because it's not for the
-        // entire last batch when n_ubatch is smaller than n_batch
-        size_t n_outputs = 0;
+    void write_logits(const struct llama_context * ctx) {
+        const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab);
 
-        // copy output ids
-        {
-            std::vector<int32_t> output_pos;
+        write(&logits_size, sizeof(logits_size));
+
+        if (logits_size) {
+            write(ctx->logits, logits_size * sizeof(float));
+        }
+    }
 
-            const size_t    n_batch = ctx->cparams.n_batch;
-            const auto & output_ids = ctx->output_ids;
+    void write_embeddings(const struct llama_context * ctx) {
+        const uint64_t embeddings_size = std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_embd);
 
-            output_pos.resize(ctx->output_size);
+        write(&embeddings_size, sizeof(embeddings_size));
 
-            // build a more compact representation of the output ids
-            for (size_t i = 0; i < n_batch; ++i) {
-                // map an output id to a position in the batch
-                int32_t pos = output_ids[i];
-                if (pos >= 0) {
-                    if ((size_t) pos >= n_outputs) {
-                        n_outputs = pos + 1;
+        if (embeddings_size) {
+            write(ctx->embd, embeddings_size * sizeof(float));
+        }
+    }
+
+    void write_kv_cache_meta(const llama_kv_cache & kv_self, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
+
+        for (const auto & range : cell_ranges) {
+            for (uint32_t i = range.first; i < range.second; ++i) {
+                const auto & cell = kv_self.cells[i];
+                const llama_pos pos      = cell.pos;
+                const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
+
+                write(&pos,      sizeof(pos));
+                write(&n_seq_id, sizeof(n_seq_id));
+
+                if (n_seq_id) {
+                    for (auto seq_id : cell.seq_id) {
+                        write(&seq_id, sizeof(seq_id));
                     }
-                    GGML_ASSERT((size_t) pos < ctx->output_size);
-                    output_pos[pos] = i;
                 }
             }
+        }
+    }
 
-            data_ctx->write(&n_outputs, sizeof(n_outputs));
+    void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
+        const struct llama_kv_cache & kv_self = ctx->kv_self;
+        const struct llama_hparams & hparams = ctx->model.hparams;
 
-            if (n_outputs) {
-                data_ctx->write(output_pos.data(), n_outputs * sizeof(int32_t));
-            }
-        }
+        const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
+        const uint32_t n_layer = hparams.n_layer;
 
-        // copy logits
-        {
-            const size_t logits_size = std::min(ctx->logits_size, n_outputs * ctx->model.hparams.n_vocab);
+        write(&v_trans, sizeof(v_trans));
+        write(&n_layer, sizeof(n_layer));
 
-            data_ctx->write(&logits_size, sizeof(logits_size));
+        std::vector<uint8_t> tmp_buf;
 
-            if (logits_size) {
-                data_ctx->write(ctx->logits, logits_size * sizeof(float));
-            }
-        }
+        // Iterate and write all the keys first, each row is a cell
+        // Get whole range at a time
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
-        // copy embeddings
-        {
-            const size_t embeddings_size = std::min(ctx->embd_size, n_outputs * ctx->model.hparams.n_embd);
+            // Write key type
+            const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
+            write(&k_type_i, sizeof(k_type_i));
 
-            data_ctx->write(&embeddings_size, sizeof(embeddings_size));
+            // Write row size of key
+            const uint64_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
+            write(&k_size_row, sizeof(k_size_row));
 
-            if (embeddings_size) {
-                data_ctx->write(ctx->embd, embeddings_size * sizeof(float));
+            // Read each range of cells of k_size length each into tmp_buf and write out
+            for (const auto & range : cell_ranges) {
+                const size_t range_size = range.second - range.first;
+                tmp_buf.resize(range_size * k_size_row);
+                ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), range.first * k_size_row, range_size * k_size_row);
+                write(tmp_buf.data(), tmp_buf.size());
             }
         }
-    }
 
-    // copy kv cache
-    {
-        const auto & kv_self = ctx->kv_self;
-        const auto & hparams = ctx->model.hparams;
-
-        const uint32_t n_layer      = hparams.n_layer;
-
-        // NOTE: kv_size and kv_buf_size are mostly used for sanity checks
-        const uint32_t kv_head     = llama_kv_cache_cell_max(kv_self);
-        const uint32_t kv_size     = kv_self.size;
-        const size_t   kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head;
-        const uint32_t kv_used     = kv_self.used;
-        const uint32_t v_trans     = kv_self.v_trans ? 1 : 0;
-
-        data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
-        data_ctx->write(&kv_head,     sizeof(kv_head));
-        data_ctx->write(&kv_size,     sizeof(kv_size));
-        data_ctx->write(&kv_used,     sizeof(kv_used));
-        data_ctx->write(&v_trans,     sizeof(v_trans));
-
-        if (kv_buf_size) {
-            const size_t pre_kv_buf_size = data_ctx->get_size_written();
-
-            std::vector<uint8_t> tmp_buf;
-            for (int il = 0; il < (int) n_layer; ++il) {
-                const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        if (!kv_self.v_trans) {
+            for (uint32_t il = 0; il < n_layer; ++il) {
                 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
-                const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
+                // Write value type
+                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+                write(&v_type_i, sizeof(v_type_i));
 
-                tmp_buf.resize(k_size);
-                ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size());
-                data_ctx->write(tmp_buf.data(), tmp_buf.size());
+                // Write row size of value
+                const uint64_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+                write(&v_size_row, sizeof(v_size_row));
 
-                if (kv_self.recurrent || !kv_self.v_trans) {
-                    // v is contiguous for recurrent models
-                    // TODO: use other tensors for state models than k and v
-                    const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
-
-                    tmp_buf.resize(v_size);
-                    ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), 0, tmp_buf.size());
-                    data_ctx->write(tmp_buf.data(), tmp_buf.size());
-                    continue;
+                // Read each range of cells of v_size length each into tmp_buf and write out
+                for (const auto & range : cell_ranges) {
+                    const size_t range_size = range.second - range.first;
+                    tmp_buf.resize(range_size * v_size_row);
+                    ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row);
+                    write(tmp_buf.data(), tmp_buf.size());
                 }
+            }
+        } else {
+            // When v is transposed, we also need the element size and get the element ranges from each row
+            const uint32_t kv_size = kv_self.size;
+            for (uint32_t il = 0; il < n_layer; ++il) {
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
+                // Write value type
+                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+                write(&v_type_i, sizeof(v_type_i));
 
-                // v is not contiguous, copy row by row
-                const size_t v_row_size   = ggml_row_size(kv_self.v_l[il]->type, kv_head);
-                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
+                // Write element size
+                const uint32_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
+                write(&v_size_el, sizeof(v_size_el));
 
-                tmp_buf.resize(v_row_size);
-                for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
-                    ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), ir*v_row_stride, tmp_buf.size());
-                    data_ctx->write(tmp_buf.data(), tmp_buf.size());
+                // Write GQA embedding size
+                write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+
+                // For each row, we get the element values of each cell
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    // Read each range of cells of v_size_el length each into tmp_buf and write out
+                    for (const auto & range : cell_ranges) {
+                        const size_t range_size = range.second - range.first;
+                        const size_t src_offset = (range.first + j * kv_size) * v_size_el;
+                        tmp_buf.resize(range_size * v_size_el);
+                        ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
+                        write(tmp_buf.data(), tmp_buf.size());
+                    }
                 }
             }
-            GGML_ASSERT(kv_buf_size == data_ctx->get_size_written() - pre_kv_buf_size);
         }
+    }
 
-        for (uint32_t i = 0; i < kv_head; ++i) {
-            const auto & cell = kv_self.cells[i];
-
-            const llama_pos pos         = cell.pos;
-            const size_t    seq_id_size = cell.seq_id.size();
-
-            data_ctx->write(&pos,         sizeof(pos));
-            data_ctx->write(&seq_id_size, sizeof(seq_id_size));
+    void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
+        const struct llama_kv_cache & kv_self = ctx->kv_self;
+        std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
+        uint32_t cell_count = 0;
 
-            for (auto seq_id : cell.seq_id) {
-                data_ctx->write(&seq_id, sizeof(seq_id));
+        // Count the number of cells with the specified seq_id
+        // Find all the ranges of cells with this seq id (or all, when -1)
+        uint32_t cell_range_begin = kv_self.size;
+        for (uint32_t i = 0; i < kv_self.size; ++i) {
+            const auto & cell = kv_self.cells[i];
+            if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
+                ++cell_count;
+                if (cell_range_begin == kv_self.size) {
+                    cell_range_begin = i;
+                }
+            } else {
+                if (cell_range_begin != kv_self.size) {
+                    cell_ranges.emplace_back(cell_range_begin, i);
+                    cell_range_begin = kv_self.size;
+                }
             }
         }
-    }
-}
+        if (cell_range_begin != kv_self.size) {
+            cell_ranges.emplace_back(cell_range_begin, kv_self.size);
+        }
 
-size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst) {
-    llama_data_buffer_context data_ctx(dst);
-    llama_state_get_data_internal(ctx, &data_ctx);
+        // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
+        uint32_t cell_count_check = 0;
+        for (const auto & range : cell_ranges) {
+            cell_count_check += range.second - range.first;
+        }
+        GGML_ASSERT(cell_count == cell_count_check);
 
-    return data_ctx.get_size_written();
-}
+        write(&cell_count, sizeof(cell_count));
 
-// Sets the state reading from the specified source address
-size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
-    llama_synchronize(ctx);
+        write_kv_cache_meta(kv_self, cell_ranges, seq_id);
+        write_kv_cache_data(ctx, cell_ranges);
+    }
+};
 
-    const uint8_t * inp = src;
+struct llama_data_read {
+    virtual const uint8_t * read(size_t size) = 0;
+    virtual void read_to(void * dst, size_t size) = 0;
+    virtual size_t get_size_read() = 0;
+    virtual ~llama_data_read() = default;
 
-    // set rng
-    {
-        size_t rng_size;
-        memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size);
+    void read_string(std::string & str) {
+        uint32_t str_size;
+        read_to(&str_size, sizeof(str_size));
+
+        str.assign((const char *) read(str_size), str_size);
+    }
 
-        GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
+    // validate model information
+    void read_model_info(const struct llama_context * ctx) {
+        std::string cur_arch_str = LLM_ARCH_NAMES.at(ctx->model.arch);
+        std::string arch_str;
+        read_string(arch_str);
+        if (cur_arch_str != arch_str) {
+            throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
+        }
+        // TODO: add more info which needs to be identical but which is not verified otherwise
+    }
 
-        std::string rng_str((const char *)inp, rng_size); inp += rng_size;
+    void read_rng(std::mt19937 & rng) {
+        std::string rng_str;
+        read_string(rng_str);
 
         std::istringstream rng_ss(rng_str);
-        rng_ss >> ctx->rng;
+        rng_ss >> rng;
 
-        GGML_ASSERT(!rng_ss.fail());
+        if (rng_ss.fail()) {
+            throw std::runtime_error("failed to load RNG state");
+        }
     }
 
-    // set output ids
-    {
-        size_t n_outputs;
+    void read_output_ids(struct llama_context * ctx) {
         std::vector<int32_t> output_pos;
 
-        memcpy(&n_outputs, inp, sizeof(n_outputs)); inp += sizeof(n_outputs);
+        uint32_t n_outputs;
+        read_to(&n_outputs, sizeof(n_outputs));
 
-        GGML_ASSERT(n_outputs <= llama_output_reserve(*ctx, n_outputs));
+        if (n_outputs > llama_output_reserve(*ctx, n_outputs)) {
+            throw std::runtime_error("could not reserve outputs");
+        }
 
         if (n_outputs) {
             output_pos.resize(n_outputs);
-            memcpy(output_pos.data(), inp, n_outputs * sizeof(int32_t));
-            inp += n_outputs * sizeof(int32_t);
+            read_to(output_pos.data(), n_outputs * sizeof(int32_t));
 
             for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
                 int32_t id = output_pos[i];
-                GGML_ASSERT((uint32_t) id < ctx->cparams.n_batch);
+                if ((uint32_t) id >= ctx->cparams.n_batch) {
+                    throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch));
+                }
                 ctx->output_ids[id] = i;
             }
 
@@ -20072,611 +17625,552 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
         }
     }
 
-    // set logits
-    {
-        size_t logits_size;
-
-        memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size);
+    void read_logits(struct llama_context * ctx) {
+        uint64_t logits_size;
+        read_to(&logits_size, sizeof(logits_size));
 
-        GGML_ASSERT(ctx->logits_size >= logits_size);
+        if (ctx->logits_size < logits_size) {
+            throw std::runtime_error("logits buffer too small");
+        }
 
         if (logits_size) {
-            memcpy(ctx->logits, inp, logits_size * sizeof(float));
-            inp += logits_size * sizeof(float);
+            read_to(ctx->logits, logits_size * sizeof(float));
         }
     }
 
-    // set embeddings
-    {
-        size_t embeddings_size;
-
-        memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
+    void read_embeddings(struct llama_context * ctx) {
+        uint64_t embeddings_size;
+        read_to(&embeddings_size, sizeof(embeddings_size));
 
-        GGML_ASSERT(ctx->embd_size >= embeddings_size);
+        if (ctx->embd_size < embeddings_size) {
+            throw std::runtime_error("embeddings buffer too small");
+        }
 
         if (embeddings_size) {
-            memcpy(ctx->embd, inp, embeddings_size * sizeof(float));
-            inp += embeddings_size * sizeof(float);
+            read_to(ctx->embd, embeddings_size * sizeof(float));
         }
     }
 
-    // set kv cache
-    {
-        const auto & kv_self = ctx->kv_self;
-        const auto & hparams = ctx->model.hparams;
+    bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
+        struct llama_kv_cache & kv_self = ctx->kv_self;
 
-        const uint32_t n_layer      = hparams.n_layer;
+        if (dest_seq_id != -1) {
+            // single sequence
 
-        size_t   kv_buf_size;
-        uint32_t kv_head;
-        uint32_t kv_size;
-        uint32_t kv_used;
-        uint32_t v_trans;
+            llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
 
-        memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
-        memcpy(&kv_head,     inp, sizeof(kv_head));     inp += sizeof(kv_head);
-        memcpy(&kv_size,     inp, sizeof(kv_size));     inp += sizeof(kv_size);
-        memcpy(&kv_used,     inp, sizeof(kv_used));     inp += sizeof(kv_used);
-        memcpy(&v_trans,     inp, sizeof(v_trans));     inp += sizeof(v_trans);
+            llama_batch batch = llama_batch_init(cell_count, 0, 1);
+            batch.n_tokens = cell_count;
+            for (uint32_t i = 0; i < cell_count; ++i) {
+                llama_pos pos;
+                uint32_t n_seq_id;
 
-        GGML_ASSERT(kv_self.v_trans == (bool) v_trans); // incompatible V transposition
+                read_to(&pos, sizeof(pos));
+                read_to(&n_seq_id, sizeof(n_seq_id));
 
-        if (kv_self.size != kv_size) {
-            // the KV cache needs to be big enough to load all the KV cells from the saved state
-            GGML_ASSERT(kv_self.size >= kv_head);
+                if (n_seq_id != 0) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
+                    return false;
+                }
 
-            LLAMA_LOG_INFO("%s: state contains %d KV cells, was saved with kv_size=%d, but is loaded with kv_size=%d (fine, but different)\n",
-                __func__, kv_head, kv_size, kv_self.size);
-        }
+                batch.pos[i] = pos;
+                batch.n_seq_id[i] = 1;
+                batch.seq_id[i][0] = dest_seq_id;
+            }
+            if (!llama_kv_cache_find_slot(kv_self, batch)) {
+                llama_batch_free(batch);
+                LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
+                return false;
+            }
 
-        llama_kv_cache_clear(ctx);
+            // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
+            // Assume that this is one contiguous block of cells
+            GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
+            GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
+            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+            GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
+            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
 
-        if (kv_buf_size) {
-            const size_t pre_kv_buf_size = inp - src;
+            // Cleanup
+            llama_batch_free(batch);
+        } else {
+            // whole KV cache restore
 
-            GGML_ASSERT(kv_self.total_size() >= kv_buf_size);
+            if (cell_count > kv_self.size) {
+                LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
+                return false;
+            }
 
-            for (int il = 0; il < (int) n_layer; ++il) {
-                const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            llama_kv_cache_clear(kv_self);
 
-                const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
+            for (uint32_t i = 0; i < cell_count; ++i) {
+                llama_kv_cell & cell = kv_self.cells[i];
 
-                ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
-                inp += k_size;
+                llama_pos pos;
+                uint32_t  n_seq_id;
 
-                if (kv_self.recurrent || !kv_self.v_trans) {
-                    // v is contiguous for recurrent models
-                    // TODO: use other tensors for state models than k and v
-                    const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
+                read_to(&pos,      sizeof(pos));
+                read_to(&n_seq_id, sizeof(n_seq_id));
 
-                    ggml_backend_tensor_set(kv_self.v_l[il], inp, 0, v_size);
-                    inp += v_size;
-                    continue;
-                }
+                cell.pos = pos;
 
-                // v is not contiguous, copy row by row
-                const size_t v_row_size   = ggml_row_size(kv_self.v_l[il]->type, kv_head);
-                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_self.size);
+                for (uint32_t j = 0; j < n_seq_id; ++j) {
+                    llama_seq_id seq_id;
+                    read_to(&seq_id, sizeof(seq_id));
 
-                for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
-                    ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size);
-                    inp += v_row_size;
+                    if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
+                        LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
+                        return false;
+                    }
+
+                    cell.seq_id.insert(seq_id);
                 }
             }
-            GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size);
+
+            kv_self.head = 0;
+            kv_self.used = cell_count;
         }
 
-        ctx->kv_self.head = kv_head;
-        ctx->kv_self.used = kv_used;
+        return true;
+    }
+
+    bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
+        const struct llama_hparams & hparams = ctx->model.hparams;
+        struct llama_kv_cache & kv_self = ctx->kv_self;
+        uint32_t v_trans;
+        uint32_t n_layer;
+        read_to(&v_trans, sizeof(v_trans));
+        read_to(&n_layer, sizeof(n_layer));
 
-        for (uint32_t i = 0; i < kv_head; ++i) {
-            llama_pos pos;
-            size_t    seq_id_size;
+        if (n_layer != hparams.n_layer) {
+            LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
+            return false;
+        }
+        if (cell_count > kv_self.size) {
+            LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
+            return false;
+        }
+        if (kv_self.v_trans != (bool) v_trans) {
+            LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
+            return false;
+        }
 
-            memcpy(&pos,         inp, sizeof(pos));         inp += sizeof(pos);
-            memcpy(&seq_id_size, inp, sizeof(seq_id_size)); inp += sizeof(seq_id_size);
+        // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
-            ctx->kv_self.cells[i].pos = pos;
+            // Read type of key
+            int32_t k_type_i_ref;
+            read_to(&k_type_i_ref, sizeof(k_type_i_ref));
+            const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
+            if (k_type_i != k_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
+                return false;
+            }
 
-            llama_seq_id seq_id;
+            // Read row size of key
+            uint64_t k_size_row_ref;
+            read_to(&k_size_row_ref, sizeof(k_size_row_ref));
+            const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
+            if (k_size_row != k_size_row_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
+                return false;
+            }
 
-            for (size_t j = 0; j < seq_id_size; ++j) {
-                memcpy(&seq_id, inp, sizeof(seq_id)); inp += sizeof(seq_id);
-                ctx->kv_self.cells[i].seq_id.insert(seq_id);
+            if (cell_count) {
+                // Read and set the keys for the whole cell range
+                ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
             }
         }
-    }
 
-    const size_t nread    = inp - src;
-    const size_t max_size = llama_state_get_size(ctx);
+        if (!kv_self.v_trans) {
+            for (uint32_t il = 0; il < n_layer; ++il) {
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
-    GGML_ASSERT(nread <= max_size);
+                // Read type of value
+                int32_t v_type_i_ref;
+                read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+                if (v_type_i != v_type_i_ref) {
+                    LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                    return false;
+                }
 
-    return nread;
-}
+                // Read row size of value
+                uint64_t v_size_row_ref;
+                read_to(&v_size_row_ref, sizeof(v_size_row_ref));
+                const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+                if (v_size_row != v_size_row_ref) {
+                    LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
+                    return false;
+                }
 
-static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    llama_file file(path_session, "rb");
+                if (cell_count) {
+                    // Read and set the values for the whole cell range
+                    ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
+                }
+            }
+        } else {
+            // For each layer, read the values for each cell (transposed)
+            for (uint32_t il = 0; il < n_layer; ++il) {
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
-    // sanity checks
-    {
-        const uint32_t magic   = file.read_u32();
-        const uint32_t version = file.read_u32();
+                // Read type of value
+                int32_t v_type_i_ref;
+                read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+                if (v_type_i != v_type_i_ref) {
+                    LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                    return false;
+                }
 
-        if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
-            LLAMA_LOG_ERROR("%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
-            return false;
-        }
+                // Read element size of value
+                uint32_t v_size_el_ref;
+                read_to(&v_size_el_ref, sizeof(v_size_el_ref));
+                const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
+                if (v_size_el != v_size_el_ref) {
+                    LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
+                    return false;
+                }
 
-        llama_hparams session_hparams;
-        file.read_raw(&session_hparams, sizeof(llama_hparams));
+                // Read GQA embedding size
+                uint32_t n_embd_v_gqa_ref;
+                read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
+                if (n_embd_v_gqa != n_embd_v_gqa_ref) {
+                    LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
+                    return false;
+                }
 
-        if (session_hparams != ctx->model.hparams) {
-            LLAMA_LOG_INFO("%s : model hparams didn't match from session file!\n", __func__);
-            return false;
+                if (cell_count) {
+                    // For each row in the transposed matrix, read the values for the whole cell range
+                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                        const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
+                        ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                    }
+                }
+            }
         }
+        return true;
     }
 
-    // load the prompt
-    {
-        const uint32_t n_token_count = file.read_u32();
+    void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
+        uint32_t cell_count;
+        read_to(&cell_count, sizeof(cell_count));
 
-        if (n_token_count > n_token_capacity) {
-            LLAMA_LOG_ERROR("%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
-            return false;
-        }
+        bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
 
-        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
-        *n_token_count_out = n_token_count;
+        if (!res) {
+            if (seq_id == -1) {
+                llama_kv_cache_clear(ctx);
+            } else {
+                llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
+            }
+            throw std::runtime_error("failed to restore kv cache");
+        }
     }
+};
 
-    // restore the context state
-    {
-        const size_t n_state_size_cur = file.size - file.tell();
-        const size_t n_state_size_max = llama_state_get_size(ctx);
+struct llama_data_write_dummy : llama_data_write {
+    size_t size_written = 0;
 
-        if (n_state_size_cur > n_state_size_max) {
-            LLAMA_LOG_ERROR("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
-            return false;
-        }
+    llama_data_write_dummy() {}
 
-        std::vector<uint8_t> state_data(n_state_size_max);
-        file.read_raw(state_data.data(), n_state_size_cur);
+    // TODO: avoid unnecessary calls to ggml_backend_tensor_get in a dummy context
 
-        llama_state_set_data(ctx, state_data.data());
+    void write(const void * /* src */, size_t size) override {
+        size_written += size;
     }
 
-    return true;
-}
-
-bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    try {
-        return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error loading session file: %s\n", err.what());
-        return false;
+    size_t get_size_written() override {
+        return size_written;
     }
-}
+};
 
-static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
-    llama_file file(path_session, "wb");
+struct llama_data_write_buffer : llama_data_write {
+    uint8_t * ptr;
+    size_t buf_size = 0;
+    size_t size_written = 0;
 
-    file.write_u32(LLAMA_SESSION_MAGIC);
-    file.write_u32(LLAMA_SESSION_VERSION);
+    llama_data_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
 
-    file.write_raw(&ctx->model.hparams, sizeof(llama_hparams));
+    void write(const void * src, size_t size) override {
+        if (size > buf_size) {
+            throw std::runtime_error("unexpectedly reached end of buffer");
+        }
+        memcpy(ptr, src, size);
+        ptr += size;
+        size_written += size;
+        buf_size -= size;
+    }
 
-    // save the prompt
-    file.write_u32((uint32_t) n_token_count);
-    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
+    size_t get_size_written() override {
+        return size_written;
+    }
+};
 
-    // save the context state using stream saving
-    llama_data_file_context data_ctx(&file);
-    llama_state_get_data_internal(ctx, &data_ctx);
+struct llama_data_read_buffer : llama_data_read {
+    const uint8_t * ptr;
+    size_t buf_size = 0;
+    size_t size_read = 0;
 
-    return true;
-}
+    llama_data_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
 
-bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
-    try {
-        return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error saving session file: %s\n", err.what());
-        return false;
+    const uint8_t * read(size_t size) override {
+        const uint8_t * base_ptr = ptr;
+        if (size > buf_size) {
+            throw std::runtime_error("unexpectedly reached end of buffer");
+        }
+        ptr += size;
+        size_read += size;
+        buf_size -= size;
+        return base_ptr;
     }
-}
 
-size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) {
-    // save the size of size_t as a uint32_t for safety check
-    const size_t size_t_size_size = sizeof(uint32_t);
+    void read_to(void * dst, size_t size) override {
+        memcpy(dst, read(size), size);
+    }
 
-    // other values
-    const size_t s_cell_count_size = sizeof(uint32_t);
-    const size_t s_layer_count_size = sizeof(uint32_t);
-    const size_t n_embd_v_gqa_size = sizeof(uint32_t);
+    size_t get_size_read() override {
+        return size_read;
+    }
+};
 
-    size_t s_cell_count = 0;
-    size_t s_cell_data_size = 0;
-    const auto & kv_self = ctx->kv_self;
-    const auto & hparams = ctx->model.hparams;
+struct llama_data_write_file : llama_data_write {
+    llama_file * file;
+    size_t size_written = 0;
 
-    const uint32_t n_layer = hparams.n_layer;
+    llama_data_write_file(llama_file * f) : file(f) {}
 
-    for (uint32_t i = 0; i < kv_self.size; ++i) {
-        const auto & cell = kv_self.cells[i];
-        if (cell.seq_id.count(seq_id) > 0) {
-            ++s_cell_count;
-            s_cell_data_size += sizeof(llama_pos);
-        }
+    void write(const void * src, size_t size) override {
+        file->write_raw(src, size);
+        size_written += size;
     }
 
-    for (int il = 0; il < (int)n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+    size_t get_size_written() override {
+        return size_written;
+    }
+};
 
-        // types of keys and values
-        s_cell_data_size += sizeof(int32_t) * 2;
-        // k_size_row and v_size_el values of layer
-        s_cell_data_size += sizeof(size_t) * 2;
+struct llama_data_read_file : llama_data_read {
+    llama_file * file;
+    size_t size_read = 0;
+    std::vector<uint8_t> temp_buffer;
 
-        // keys
-        const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
-        s_cell_data_size += k_size_row * s_cell_count;
+    llama_data_read_file(llama_file * f) : file(f) {}
 
-        // values (transposed)
-        const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
-        s_cell_data_size += v_size_el * s_cell_count * n_embd_v_gqa;
+    void read_to(void * dst, size_t size) override {
+        file->read_raw(dst, size);
+        size_read += size;
     }
 
-    const size_t s_total = (
-        size_t_size_size +
-        s_cell_count_size +
-        s_layer_count_size +
-        n_embd_v_gqa_size +
-        s_cell_data_size
-        );
+    const uint8_t * read(size_t size) override {
+        temp_buffer.resize(size);
+        read_to(temp_buffer.data(), size);
+        return temp_buffer.data();
+    }
 
-    return s_total;
-}
+    size_t get_size_read() override {
+        return size_read;
+    }
+};
 
-static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx, llama_seq_id seq_id) {
+/** copy state data into either a buffer or file depending on the passed in context
+ *
+ * file context:
+ * llama_file file("/path", "wb");
+ * llama_data_write_file data_ctx(&file);
+ * llama_state_get_data_internal(ctx, data_ctx);
+ *
+ * buffer context:
+ * std::vector<uint8_t> buf(max_size, 0);
+ * llama_data_write_buffer data_ctx(buf.data(), max_size);
+ * llama_state_get_data_internal(ctx, data_ctx);
+ *
+*/
+static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) {
     llama_synchronize(ctx);
 
-    const auto & kv_self = ctx->kv_self;
-    GGML_ASSERT(!kv_self.recurrent); // not implemented
-
-    // Save the size of size_t as a uint32_t for safety check
-    const uint32_t size_t_size = sizeof(size_t);
-    data_ctx.write(&size_t_size, sizeof(size_t_size));
-
-    std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
-    uint32_t cell_count = 0;
-
-    // Count the number of cells with the specified seq_id
-    // Find all the ranges of cells with this seq id
-    {
-        uint32_t cell_range_begin = kv_self.size;
-        for (uint32_t i = 0; i < kv_self.size; ++i) {
-            const auto & cell = kv_self.cells[i];
-            if (cell.has_seq_id(seq_id)) {
-                ++cell_count;
-                if (cell_range_begin == kv_self.size) {
-                    cell_range_begin = i;
-                }
-            }
-            else {
-                if (cell_range_begin != kv_self.size) {
-                    cell_ranges.emplace_back(cell_range_begin, i);
-                    cell_range_begin = kv_self.size;
-                }
-            }
-        }
-        if (cell_range_begin != kv_self.size) {
-            cell_ranges.emplace_back(cell_range_begin, kv_self.size);
-        }
+    data_ctx.write_model_info(ctx);
 
-        // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
-        uint32_t cell_count_check = 0;
-        for (const auto & range : cell_ranges) {
-            cell_count_check += range.second - range.first;
-        }
-        GGML_ASSERT(cell_count == cell_count_check);
-    }
+    data_ctx.write_rng(ctx->sampling.rng);
 
-    // Write the cell count
-    data_ctx.write(&cell_count, sizeof(cell_count));
+    // copy outputs
+    data_ctx.write_output_ids(ctx);
+    data_ctx.write_logits(ctx);
+    data_ctx.write_embeddings(ctx);
 
-    const auto & hparams = ctx->model.hparams;
-    const uint32_t n_layer = hparams.n_layer;
+    data_ctx.write_kv_cache(ctx);
 
-    // Write the layer count
-    data_ctx.write(&n_layer, sizeof(n_layer));
+    return data_ctx.get_size_written();
+}
 
-    // Write n_embd_v_gqa (reference value)
-    {
-        const uint32_t n_embd_v_gqa_ref = hparams.n_embd_v_gqa() + hparams.n_embd_k_s();
-        data_ctx.write(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
+size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) {
+    llama_data_write_buffer data_ctx(dst, size);
+    try {
+        return llama_state_get_data_internal(ctx, data_ctx);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
+        return 0;
     }
+}
 
-    // Iterate the ranges and write all the pos (this is the token position in the prompt)
-    for (const auto & range : cell_ranges) {
-        for (uint32_t i = range.first; i < range.second; ++i) {
-            const auto & cell = kv_self.cells[i];
-            data_ctx.write(&cell.pos, sizeof(cell.pos));
-        }
+// Returns the *actual* size of the state.
+// Intended to be used when saving to state to a buffer.
+size_t llama_state_get_size(struct llama_context * ctx) {
+    llama_data_write_dummy data_ctx;
+    try {
+        return llama_state_get_data_internal(ctx, data_ctx);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
+        return 0;
     }
+}
 
-    // Iterate and write all the keys first, each row is a cell
-    // Get whole range at a time
-    std::vector<uint8_t> tmp_buf;
-    for (int il = 0; il < (int)n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-
-        // Write key type
-        const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
-        data_ctx.write(&k_type_i, sizeof(k_type_i));
-
-        // Write row size of key
-        const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
-        data_ctx.write(&k_size_row, sizeof(k_size_row));
-
-        // Read each range of cells of k_size length each into tmp_buf and write out
-        for (const auto & range : cell_ranges) {
-            const size_t range_size = range.second - range.first;
-            tmp_buf.resize(range_size * k_size_row);
-            ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), range.first * k_size_row, range_size * k_size_row);
-            data_ctx.write(tmp_buf.data(), tmp_buf.size());
-        }
-    }
+static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) {
+    llama_synchronize(ctx);
 
-    // TODO: simplify, reduce copy-paste
-    if (!kv_self.v_trans) {
-        for (int il = 0; il < (int)n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+    data_ctx.read_model_info(ctx);
 
-            // Write value type
-            const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-            data_ctx.write(&v_type_i, sizeof(v_type_i));
+    // set rng
+    data_ctx.read_rng(ctx->sampling.rng);
 
-            // Write row size of value
-            const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
-            data_ctx.write(&v_size_row, sizeof(v_size_row));
+    // set outputs
+    data_ctx.read_output_ids(ctx);
+    data_ctx.read_logits(ctx);
+    data_ctx.read_embeddings(ctx);
 
-            // Read each range of cells of v_size length each into tmp_buf and write out
-            for (const auto & range : cell_ranges) {
-                const size_t range_size = range.second - range.first;
-                tmp_buf.resize(range_size * v_size_row);
-                ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row);
-                data_ctx.write(tmp_buf.data(), tmp_buf.size());
-            }
-        }
-    } else {
-        // For the values, they are transposed, so we also need the element size and get the element ranges from each row
-        const uint32_t kv_size = kv_self.size;
-        for (int il = 0; il < (int)n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-            // Write value type
-            const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-            data_ctx.write(&v_type_i, sizeof(v_type_i));
-
-            // Write element size
-            const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
-            data_ctx.write(&v_size_el, sizeof(v_size_el));
-
-            // For each row, we get the element values of each cell
-            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                // Read each range of cells of v_size_el length each into tmp_buf and write out
-                for (const auto & range : cell_ranges) {
-                    const size_t range_size = range.second - range.first;
-                    const size_t src_offset = (range.first + j * kv_size) * v_size_el;
-                    tmp_buf.resize(range_size * v_size_el);
-                    ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
-                    data_ctx.write(tmp_buf.data(), tmp_buf.size());
-                }
-            }
-        }
-    }
+    data_ctx.read_kv_cache(ctx);
 
-    return data_ctx.get_size_written();
+    return data_ctx.get_size_read();
 }
 
-size_t llama_state_seq_get_data(struct llama_context* ctx, uint8_t* dst, llama_seq_id seq_id) {
-    llama_data_buffer_context data_ctx(dst);
-    return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
+// Sets the state reading from the specified source address
+size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) {
+    llama_data_read_buffer data_ctx(src, size);
+    try {
+        return llama_state_set_data_internal(ctx, data_ctx);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
+        return 0;
+    }
 }
 
-size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
-    llama_synchronize(ctx);
-
-    auto & kv_self = ctx->kv_self;
-    GGML_ASSERT(!kv_self.recurrent); // not implemented
-
-    // Wipe the slot
-    llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
+static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    llama_file file(path_session, "rb");
 
-    const uint8_t * inp = src;
+    // sanity checks
+    {
+        const uint32_t magic   = file.read_u32();
+        const uint32_t version = file.read_u32();
 
-    // Read size of size_t
-    uint32_t size_t_size;
-    memcpy(&size_t_size, inp, sizeof(size_t_size));
-    inp += sizeof(size_t_size);
-    if (size_t_size != sizeof(size_t)) {
-        LLAMA_LOG_ERROR("%s: size_t size mismatch\n", __func__);
-        return 0;
+        if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
+            LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
+            return false;
+        }
     }
 
-    // Read the cell count
-    uint32_t cell_count;
-    memcpy(&cell_count, inp, sizeof(cell_count));
-    inp += sizeof(cell_count);
+    // load the prompt
+    {
+        const uint32_t n_token_count = file.read_u32();
 
-    // Read the layer count
-    uint32_t n_layer_ref;
-    memcpy(&n_layer_ref, inp, sizeof(n_layer_ref));
-    inp += sizeof(n_layer_ref);
+        if (n_token_count > n_token_capacity) {
+            LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
+            return false;
+        }
 
-    // Read n_embd_v_gqa
-    uint32_t n_embd_v_gqa_ref;
-    memcpy(&n_embd_v_gqa_ref, inp, sizeof(n_embd_v_gqa_ref));
-    inp += sizeof(n_embd_v_gqa_ref);
+        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
+        *n_token_count_out = n_token_count;
+    }
 
-    // Sanity check model compatibility
-    const auto & hparams = ctx->model.hparams;
-    const uint32_t n_layer = hparams.n_layer;
+    // restore the context state
+    {
+        const size_t n_state_size_cur = file.size - file.tell();
 
-    if (n_layer != n_layer_ref) {
-        LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref);
-        return 0;
-    }
+        llama_data_read_file data_ctx(&file);
+        const size_t n_read = llama_state_set_data_internal(ctx, data_ctx);
 
-    if (hparams.n_embd_v_gqa() != n_embd_v_gqa_ref) {
-        LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, hparams.n_embd_v_gqa(), n_embd_v_gqa_ref);
-        return 0;
+        if (n_read != n_state_size_cur) {
+            LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
+            return false;
+        }
     }
+    return true;
+}
 
-    // Allocate the new cells for the slot
-    if (cell_count) {
-        llama_batch batch = llama_batch_init(cell_count, 0, 1);
-        batch.n_tokens = cell_count;
-        for (uint32_t i = 0; i < cell_count; ++i) {
-            llama_pos pos;
-            memcpy(&pos, inp, sizeof(pos));
-            inp += sizeof(pos);
+bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    try {
+        return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
+        return false;
+    }
+}
 
-            batch.pos[i] = pos;
-            batch.n_seq_id[i] = 1;
-            batch.seq_id[i][0] = dest_seq_id;
-        }
-        if (!llama_kv_cache_find_slot(kv_self, batch)) {
-            llama_batch_free(batch);
-            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
-            return 0;
-        }
+static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
+    llama_file file(path_session, "wb");
 
-        // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
-        // Assume that this is one contiguous block of cells
-        GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
-        GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
-        GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
-        GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
-        GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
+    file.write_u32(LLAMA_SESSION_MAGIC);
+    file.write_u32(LLAMA_SESSION_VERSION);
 
-        // Cleanup
-        llama_batch_free(batch);
-    }
+    // save the prompt
+    file.write_u32((uint32_t) n_token_count);
+    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
 
-    const uint32_t kv_size = kv_self.size;
-    const uint32_t kv_head = kv_self.head;
-
-    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous blo
-    for (int il = 0; il < (int)n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-
-        // Read type of key
-        int32_t k_type_i_ref;
-        memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref));
-        inp += sizeof(k_type_i_ref);
-        const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
-        if (k_type_i != k_type_i_ref) {
-            llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
-            return 0;
-        }
+    // save the context state using stream saving
+    llama_data_write_file data_ctx(&file);
+    llama_state_get_data_internal(ctx, data_ctx);
 
-        // Read row size of key
-        size_t k_size_row_ref;
-        memcpy(&k_size_row_ref, inp, sizeof(k_size_row_ref));
-        inp += sizeof(k_size_row_ref);
-        const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
-        if (k_size_row != k_size_row_ref) {
-            llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, k_size_row_ref, il);
-            return 0;
-        }
+    return true;
+}
 
-        if (cell_count) {
-            // Read and set the keys for the whole cell range
-            ggml_backend_tensor_set(kv_self.k_l[il], inp, kv_head * k_size_row, cell_count * k_size_row);
-            inp += cell_count * k_size_row;
-        }
+bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
+    try {
+        return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
+        return false;
     }
+}
 
-    // TODO: simplify, reduce copy-paste
-    if (!kv_self.v_trans) {
-        for (int il = 0; il < (int)n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
+    llama_synchronize(ctx);
 
-            // Read type of value
-            int32_t v_type_i_ref;
-            memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
-            inp += sizeof(v_type_i_ref);
-            const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-            if (v_type_i != v_type_i_ref) {
-                llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-                return 0;
-            }
+    data_ctx.write_kv_cache(ctx, seq_id);
 
-            // Read row size of value
-            size_t v_size_row_ref;
-            memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref));
-            inp += sizeof(v_size_row_ref);
-            const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
-            if (v_size_row != v_size_row_ref) {
-                llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il);
-                return 0;
-            }
+    return data_ctx.get_size_written();
+}
 
-            if (cell_count) {
-                // Read and set the values for the whole cell range
-                ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row);
-                inp += cell_count * v_size_row;
-            }
-        }
-    } else {
-        // For each layer, read the values for each cell (transposed)
-        for (int il = 0; il < (int)n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-            // Read type of value
-            int32_t v_type_i_ref;
-            memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
-            inp += sizeof(v_type_i_ref);
-            const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-            if (v_type_i != v_type_i_ref) {
-                llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-                return 0;
-            }
-
-            // Read element size of value
-            size_t v_size_el_ref;
-            memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
-            inp += sizeof(v_size_el_ref);
-            const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
-            if (v_size_el != v_size_el_ref) {
-                llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-                LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il);
-                return 0;
-            }
+size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) {
+    llama_data_write_dummy data_ctx;
+    return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
+}
 
-            if (cell_count) {
-                // For each row in the transposed matrix, read the values for the whole cell range
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    const size_t dst_offset = (kv_head + j * kv_size) * v_size_el;
-                    ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el);
-                    inp += cell_count * v_size_el;
-                }
-            }
-        }
+size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
+    llama_data_write_buffer data_ctx(dst, size);
+    try {
+        return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
+        return 0;
     }
+}
+
+static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
+    llama_synchronize(ctx);
 
-    const size_t nread = inp - src;
+    data_ctx.read_kv_cache(ctx, dest_seq_id);
 
-    return nread;
+    return data_ctx.get_size_read();
+}
+
+size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) {
+    llama_data_read_buffer data_ctx(src, size);
+    try {
+        return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what());
+        return 0;
+    }
 }
 
 static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
@@ -20686,11 +18180,11 @@ static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, con
     file.write_u32(LLAMA_STATE_SEQ_VERSION);
 
     // save the prompt
-    file.write_u32((uint32_t)n_token_count);
+    file.write_u32((uint32_t) n_token_count);
     file.write_raw(tokens, sizeof(llama_token) * n_token_count);
 
     // save the context state using stream saving
-    llama_data_file_context data_ctx(&file);
+    llama_data_write_file data_ctx(&file);
     llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
 
     const size_t res = file.tell();
@@ -20728,9 +18222,8 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con
     // restore the context state
     {
         const size_t state_size = file.size - file.tell();
-        std::vector<uint8_t> state_data(state_size);
-        file.read_raw(state_data.data(), state_size);
-        const size_t nread = llama_state_seq_set_data(ctx, state_data.data(), dest_seq_id);
+        llama_data_read_file data_ctx(&file);
+        const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
         if (!nread) {
             LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
             return 0;
@@ -20746,7 +18239,7 @@ size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepa
     try {
         return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error saving sequence state file: %s\n", err.what());
+        LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what());
         return 0;
     }
 }
@@ -20755,7 +18248,7 @@ size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepa
     try {
         return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error loading sequence state file: %s\n", err.what());
+        LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what());
         return 0;
     }
 }
@@ -20927,7 +18420,7 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
-        GGML_ASSERT(false);
+        GGML_ABORT("fatal error");
 #endif
         return nullptr;
     }
@@ -20972,7 +18465,7 @@ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
-        GGML_ASSERT(false);
+        GGML_ABORT("fatal error");
 #endif
         return nullptr;
     }
@@ -20989,79 +18482,81 @@ float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id
     return it->second.data();
 }
 
+//
+// vocab
+//
+
 const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
-    GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return model->vocab.id_to_token[token].text.c_str();
+    return llama_token_get_text_impl(model->vocab, token);
 }
 
 float llama_token_get_score(const struct llama_model * model, llama_token token) {
-    GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return model->vocab.id_to_token[token].score;
+    return llama_token_get_score_impl(model->vocab, token);
 }
 
-llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
-    GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return model->vocab.id_to_token[token].attr;
+enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
+    return llama_token_get_attr_impl(model->vocab, token);
 }
 
 bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
-    return token != -1 && (
-        token == llama_token_eos(model) ||
-        token == llama_token_eot(model)
-    );
+    return llama_token_is_eog_impl(model->vocab, token);
 }
 
 bool llama_token_is_control(const struct llama_model * model, llama_token token) {
-    return llama_is_control_token(model->vocab, token);
+    return llama_token_is_control_impl(model->vocab, token);
 }
 
 llama_token llama_token_bos(const struct llama_model * model) {
-    return model->vocab.special_bos_id;
+    return llama_token_bos_impl(model->vocab);
 }
 
 llama_token llama_token_eos(const struct llama_model * model) {
-    return model->vocab.special_eos_id;
+    return llama_token_eos_impl(model->vocab);
 }
 
 llama_token llama_token_cls(const struct llama_model * model) {
-    return model->vocab.special_cls_id;
+    return llama_token_cls_impl(model->vocab);
 }
 
 llama_token llama_token_sep(const struct llama_model * model) {
-    return model->vocab.special_sep_id;
+    return llama_token_sep_impl(model->vocab);
+}
+
+llama_token llama_token_nl (const struct llama_model * model) {
+    return llama_token_nl_impl(model->vocab);
 }
 
-llama_token llama_token_nl(const struct llama_model * model) {
-    return model->vocab.linefeed_id;
+llama_token llama_token_pad(const struct llama_model * model) {
+    return llama_token_pad_impl(model->vocab);
 }
 
 int32_t llama_add_bos_token(const struct llama_model * model) {
-    return model->vocab.tokenizer_add_bos;
+    return llama_add_bos_token_impl(model->vocab);
 }
 
 int32_t llama_add_eos_token(const struct llama_model * model) {
-    return model->vocab.tokenizer_add_eos;
+    return llama_add_eos_token_impl(model->vocab);
 }
 
 llama_token llama_token_prefix(const struct llama_model * model) {
-    return model->vocab.special_prefix_id;
+    return llama_token_prefix_impl(model->vocab);
 }
 
 llama_token llama_token_middle(const struct llama_model * model) {
-    return model->vocab.special_middle_id;
+    return llama_token_middle_impl(model->vocab);
 }
 
 llama_token llama_token_suffix(const struct llama_model * model) {
-    return model->vocab.special_suffix_id;
+    return llama_token_suffix_impl(model->vocab);
 }
 
 llama_token llama_token_eot(const struct llama_model * model) {
-    return model->vocab.special_eot_id;
+    return llama_token_eot_impl(model->vocab);
 }
 
-llama_token llama_token_pad(const struct llama_model * model) {
-    return model->vocab.special_pad_id;
-}
+//
+// tokenization
+//
 
 int32_t llama_tokenize(
     const struct llama_model * model,
@@ -21071,229 +18566,33 @@ int32_t llama_tokenize(
                      int32_t   n_tokens_max,
                         bool   add_special,
                         bool   parse_special) {
-    auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, parse_special);
-    if (n_tokens_max < (int) res.size()) {
-        // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
-        return -((int) res.size());
-    }
-
-    for (size_t i = 0; i < res.size(); i++) {
-        tokens[i] = res[i];
-    }
-
-    return res.size();
-}
-
-static std::string llama_decode_text(const std::string & text) {
-    std::string decoded_text;
-
-    const auto cpts = unicode_cpts_from_utf8(text);
-    for (const auto cpt : cpts) {
-        const auto utf8 = unicode_cpt_to_utf8(cpt);
-        try {
-            decoded_text += unicode_utf8_to_byte(utf8);
-        } catch (const std::out_of_range & /*e*/) {
-            decoded_text += "[UNK_BYTE_0x";
-            for (const auto c : utf8) {
-                decoded_text += format("%02x", (uint8_t) c);
-            }
-            decoded_text += text + "]";
-        }
-    }
-
-    return decoded_text;
+    return llama_tokenize_impl(model->vocab, text, text_len, tokens, n_tokens_max, add_special, parse_special);
 }
 
-// does not write null-terminator to buf
-int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) {
-    // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
-    static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
-    const llama_token_attr attr = llama_token_get_attr(model, token);
-    if (!special && (attr & attr_special)) {
-        return 0;
-    }
-
-    // copy piece chars to output text buffer
-    // skip up to 'lstrip' leading spaces before copying
-    auto _try_copy = [=] (const char * token, size_t size) -> int32_t {
-        for (int32_t i = 0; i < lstrip && size && *token == ' '; ++i) {
-            token++;
-            size--;
-        }
-        if (length < (int32_t)size) {
-            return (int32_t) -size;
-        }
-        memcpy(buf, token, size);
-        return (int32_t) size;
-    };
-
-    // if we have a cache - use it
-    {
-        const auto & cache = model->vocab.cache_token_to_piece;
-
-        if (!cache.empty()) {
-            const auto & result = cache.at(token);
-            return _try_copy(result.data(), result.size());
-        }
-    }
-
-    if (0 <= token && token < llama_n_vocab(model)) {
-        const std::string & token_text = model->vocab.id_to_token[token].text;
-        switch (llama_vocab_get_type(model->vocab)) {
-            case LLAMA_VOCAB_TYPE_WPM:
-            case LLAMA_VOCAB_TYPE_SPM:
-            case LLAMA_VOCAB_TYPE_UGM: {
-                // NOTE: we accept all unsupported token types,
-                // suppressing them like CONTROL tokens.
-                if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
-                    return _try_copy(token_text.data(), token_text.size());
-                } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
-                    std::string result = token_text;
-                    llama_unescape_whitespace(result);
-                    return _try_copy(result.data(), result.size());
-                } else if (attr & LLAMA_TOKEN_ATTR_BYTE) {
-                    char byte = (char) llama_token_to_byte(model->vocab, token);
-                    return _try_copy((char*) &byte, 1);
-                }
-                break;
-            }
-            case LLAMA_VOCAB_TYPE_BPE: {
-                // NOTE: we accept all unsupported token types,
-                // suppressing them like CONTROL tokens.
-                if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
-                    return _try_copy(token_text.data(), token_text.size());
-                } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
-                    std::string result = llama_decode_text(token_text);
-                    return _try_copy(result.data(), result.size());
-                }
-                break;
-            }
-            default:
-                GGML_ASSERT(false);
-        }
-    }
-    return 0;
+int32_t llama_token_to_piece(
+    const struct llama_model * model,
+                 llama_token   token,
+                        char * buf,
+                     int32_t   length,
+                     int32_t   lstrip,
+                        bool   special) {
+    return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
 }
 
 int32_t llama_detokenize(
-        const struct llama_model * model,
-               const llama_token * tokens,
-                         int32_t   n_tokens,
-                            char * text,
-                         int32_t   text_len_max,
-                            bool   remove_special,
-                            bool   unparse_special) {
-    int32_t avail = text_len_max;
-    int32_t total = 0;
-
-    // remove the leading space
-    bool remove_space = model->vocab.tokenizer_add_space_prefix;
-
-    if (remove_special && model->vocab.tokenizer_add_bos) {
-        if (n_tokens > 0 && tokens[0] == model->vocab.special_bos_id) {
-            remove_space = false;
-            n_tokens--;
-            tokens++;
-        }
-    }
-
-    if (remove_special && model->vocab.tokenizer_add_eos) {
-        if (n_tokens > 0 && tokens[n_tokens-1] == model->vocab.special_eos_id) {
-            n_tokens--;
-        }
-    }
-
-    for (int32_t i = 0; i < n_tokens; ++i) {
-        GGML_ASSERT(avail >= 0);
-        int32_t n_chars = llama_token_to_piece(model, tokens[i], text, avail, remove_space, unparse_special);
-        remove_space = false;
-        if (n_chars < 0) {
-            avail = 0;
-            total -= n_chars;
-        } else if (n_chars > 0) {
-            avail -= n_chars;
-            text  += n_chars;
-            total += n_chars;
-        }
-    }
-
-    if (total > text_len_max) {
-        return -total;
-    }
-
-    if (model->vocab.tokenizer_clean_spaces) {
-        text -= total;  // restart text
-
-        // first pass: characters ?!.,  //TODO: where do these characters come from?
-        const int32_t total1 = total;
-        total = total ? 1 : 0;
-        for (int32_t i = 1; i < total1; ++i) {
-            const char x = text[i];
-            if (text[i - 1] == ' ') {
-                if (x == '?' || x == '!' || x == '.' || x == ',') {  // " ?", " !", " .", " ,"
-                    total--;  // remove space
-                }
-            }
-            text[total++] = x;
-        }
-
-        // second pass: strip single apostrophe between spaces
-        const int32_t total2 = total;
-        total = total ? 1 : 0;
-        for (int32_t i = 1; i < total2; ++i) {
-            const char x = text[i];
-            if (x == '\'' && i + 1 < total2 && text[i - 1] == ' ' && text[i + 1] == ' ') {  // " ' "
-                total--;           // remove prev space
-                text[++i] = '\0';  // remove next space
-            }
-            text[total++] = x;
-        }
-
-        // third pass: apostrophe contractions  //NOTE: this makes sense?
-        const int32_t total3 = total;
-        total = total ? 1 : 0;
-        for (int32_t i = 1; i < total3; ++i) {
-            const char x = text[i];
-            if (text[i - 1] == ' ') {
-                if (x == '\'' && i + 1 < total3) {
-                    const char x1 = text[i + 1];
-                    if (x1 == 't' || x1 == 'd') {  // " 't", " 'd"
-                        //total--;  // remove space
-                    } else if (x1 == 's' || x1 == 'm') {  // " 's", " 'm"
-                        total--;  // remove space
-                    } else if (i + 2 < total3) {
-                        const char x2 = text[i + 2];
-                        if ((x1 == 'l' && x2 == 'l')) {  // " 'll"
-                            //total--;  // remove space
-                        } else if ((x1 == 'r' && x2 == 'e') || (x1 == 'v' && x2 == 'e')) {  // " 're", " 've"
-                            total--;  // remove space
-                        } else {
-                            //total--;  // remove space
-                        }
-                    } else {
-                        //total--;  // remove space
-                    }
-                }
-            }
-            text[total++] = x;
-        }
-    }
-
-    return total <= text_len_max ? total : -total;
+    const struct llama_model * model,
+           const llama_token * tokens,
+                     int32_t   n_tokens,
+                        char * text,
+                     int32_t   text_len_max,
+                        bool   remove_special,
+                        bool   unparse_special) {
+    return llama_detokenize_impl(model->vocab, tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
 }
 
-// trim whitespace from the beginning and end of a string
-static std::string trim(const std::string & str) {
-    size_t start = 0;
-    size_t end = str.size();
-    while (start < end && isspace(str[start])) {
-        start += 1;
-    }
-    while (end > start && isspace(str[end - 1])) {
-        end -= 1;
-    }
-    return str.substr(start, end - start);
-}
+//
+// chat templates
+//
 
 // Simple version of "llama_apply_chat_template" that only works with strings
 // This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
@@ -21500,7 +18799,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|assistant|>";
         }
-    } else if (tmpl == "chaglm4" || tmpl_contains("[gMASK]<sop>")) {
+    } else if (tmpl == "chatglm4" || tmpl_contains("[gMASK]<sop>")) {
         ss << "[gMASK]" << "<sop>";
         for (auto message : chat) {
             std::string role(message->role);
@@ -21509,12 +18808,12 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|assistant|>";
         }
-    } else if (tmpl == "minicpm" || tmpl_contains(u8"<用户>")) {
+    } else if (tmpl == "minicpm" || tmpl_contains(LU8("<用户>"))) {
         // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
         for (auto message : chat) {
             std::string role(message->role);
             if (role == "user") {
-                ss << u8"<用户>";
+                ss << LU8("<用户>");
                 ss << trim(message->content);
                 ss << "<AI>";
             } else {
@@ -21530,7 +18829,7 @@ static int32_t llama_chat_apply_template_internal(
             } else if (role == "user") {
                 ss << "User: " << message->content << "\n\n";
             } else if (role == "assistant") {
-                ss << "Assistant: " << message->content << u8"<|end▁of▁sentence|>";
+                ss << "Assistant: " << message->content << LU8("<|end▁of▁sentence|>");
             }
         }
         if (add_ass) {
@@ -21544,7 +18843,7 @@ static int32_t llama_chat_apply_template_internal(
     return dest.size();
 }
 
-LLAMA_API int32_t llama_chat_apply_template(
+int32_t llama_chat_apply_template(
                 const struct llama_model * model,
                               const char * tmpl,
          const struct llama_chat_message * chat,
@@ -21585,7 +18884,126 @@ LLAMA_API int32_t llama_chat_apply_template(
     return res;
 }
 
-LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
+//
+// grammar
+//
+
+struct llama_grammar * llama_grammar_init(
+        const llama_grammar_element ** rules,
+        size_t    n_rules,
+        size_t    start_rule_index) {
+    return llama_grammar_init_impl(rules, n_rules, start_rule_index);
+}
+
+void llama_grammar_free(struct llama_grammar * grammar) {
+    llama_grammar_free_impl(grammar);
+}
+
+struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
+    return llama_grammar_copy_impl(grammar);
+}
+
+void llama_grammar_sample(
+      const struct llama_grammar * grammar,
+      const struct llama_context * ctx,
+          llama_token_data_array * candidates) {
+    llama_grammar_sample_impl(grammar, &ctx->model.vocab, &ctx->sampling, candidates);
+}
+
+void llama_sample_grammar(
+            struct llama_context * ctx,
+          llama_token_data_array * candidates,
+      const struct llama_grammar * grammar) {
+    llama_grammar_sample(grammar, ctx, candidates);
+}
+
+void llama_grammar_accept_token(
+            struct llama_grammar * grammar,
+            struct llama_context * ctx,
+                     llama_token   token) {
+    llama_grammar_accept_token_impl(grammar, &ctx->model.vocab, &ctx->sampling, token);
+}
+
+//
+// sampling
+//
+
+void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
+    llama_set_rng_seed_impl(&ctx->sampling, seed);
+}
+
+void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
+    llama_sample_softmax_impl(ctx ? &ctx->sampling : nullptr, candidates);
+}
+
+void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
+    llama_sample_top_k_impl(ctx ? &ctx->sampling : nullptr, candidates, k, min_keep);
+}
+
+void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
+    llama_sample_top_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
+}
+
+void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
+    llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
+}
+
+void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
+    llama_sample_tail_free_impl(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep);
+}
+
+void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
+    llama_sample_typical_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
+}
+
+void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
+    llama_sample_entropy_impl(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val);
+}
+
+void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
+    llama_sample_temp_impl(ctx ? &ctx->sampling : nullptr, candidates_p, temp);
+}
+
+void llama_sample_repetition_penalties(
+            struct llama_context * ctx,
+          llama_token_data_array * candidates,
+               const llama_token * last_tokens,
+                          size_t   penalty_last_n,
+                           float   penalty_repeat,
+                           float   penalty_freq,
+                           float   penalty_present) {
+    llama_sample_repetition_penalties_impl(ctx ? &ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
+}
+
+void llama_sample_apply_guidance(
+          struct llama_context * ctx,
+                         float * logits,
+                         float * logits_guidance,
+                         float   scale) {
+    llama_sample_apply_guidance_impl(&ctx->sampling, logits, logits_guidance, scale);
+}
+
+llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
+    return llama_sample_token_mirostat_impl(&ctx->sampling, candidates, tau, eta, m, mu);
+}
+
+llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
+    return llama_sample_token_mirostat_v2_impl(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu);
+}
+
+llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
+    return llama_sample_token_greedy_impl(ctx ? &ctx->sampling : nullptr, candidates);
+}
+
+llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
+    return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, rng);
+}
+
+llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
+    return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng);
+}
+
+int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
     static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
     if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
         return strlen(split_path);
@@ -21614,11 +19032,11 @@ struct llama_timings llama_get_timings(struct llama_context * ctx) {
         /*.t_start_ms  =*/ 1e-3 * ctx->t_start_us,
         /*.t_end_ms    =*/ 1.00 * ggml_time_ms(),
         /*.t_load_ms   =*/ 1e-3 * ctx->t_load_us,
-        /*.t_sample_ms =*/ 1e-3 * ctx->t_sample_us,
+        /*.t_sample_ms =*/ 1e-3 * ctx->sampling.t_sample_us,
         /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us,
         /*.t_eval_ms   =*/ 1e-3 * ctx->t_eval_us,
 
-        /*.n_sample =*/ std::max(1, ctx->n_sample),
+        /*.n_sample =*/ std::max(1, ctx->sampling.n_sample),
         /*.n_p_eval =*/ std::max(0, ctx->n_p_eval),
         /*.n_eval   =*/ std::max(1, ctx->n_eval),
     };
@@ -21641,10 +19059,11 @@ void llama_print_timings(struct llama_context * ctx) {
 }
 
 void llama_reset_timings(struct llama_context * ctx) {
-    ctx->t_start_us = ggml_time_us();
-    ctx->t_sample_us = ctx->n_sample = 0;
+    ctx->t_start_us  = ggml_time_us();
     ctx->t_eval_us   = ctx->n_eval   = 0;
     ctx->t_p_eval_us = ctx->n_p_eval = 0;
+
+    ctx->sampling.reset_timings();
 }
 
 const char * llama_print_system_info(void) {
@@ -21670,11 +19089,7 @@ const char * llama_print_system_info(void) {
     s += "SSSE3 = "       + std::to_string(ggml_cpu_has_ssse3())       + " | ";
     s += "VSX = "         + std::to_string(ggml_cpu_has_vsx())         + " | ";
     s += "MATMUL_INT8 = " + std::to_string(ggml_cpu_has_matmul_int8()) + " | ";
-#ifdef GGML_USE_LLAMAFILE
-    s += "LLAMAFILE = 1 | ";
-#else
-    s += "LLAMAFILE = 0 | ";
-#endif
+    s += "LLAMAFILE = "   + std::to_string(ggml_cpu_has_llamafile())   + " | ";
 
     return s.c_str();
 }
@@ -21691,20 +19106,20 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
     fprintf(stream, "mst_p_eval: %.2f  # ms / token during prompt processing\n",
             1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval);
     fprintf(stream, "mst_sample: %.2f  # ms / token during sampling\n",
-            1.0e-3 * ctx->t_sample_us / ctx->n_sample);
+            1.0e-3 * ctx->sampling.t_sample_us / ctx->sampling.n_sample);
     fprintf(stream, "n_eval: %d  # number of tokens generated (excluding the first one)\n", ctx->n_eval);
     fprintf(stream, "n_p_eval: %d  # number of tokens processed in batches at the beginning\n", ctx->n_p_eval);
-    fprintf(stream, "n_sample: %d  # number of sampled tokens\n", ctx->n_sample);
+    fprintf(stream, "n_sample: %d  # number of sampled tokens\n", ctx->sampling.n_sample);
     fprintf(stream, "t_eval_us: %" PRId64 "  # total microseconds spent generating tokens\n", ctx->t_eval_us);
     fprintf(stream, "t_load_us: %" PRId64 "  # total microseconds spent loading the model\n", ctx->t_load_us);
     fprintf(stream, "t_p_eval_us: %" PRId64 "  # total microseconds spent prompt processing\n", ctx->t_p_eval_us);
-    fprintf(stream, "t_sample_us: %" PRId64 "  # total microseconds spent sampling\n", ctx->t_sample_us);
+    fprintf(stream, "t_sample_us: %" PRId64 "  # total microseconds spent sampling\n", ctx->sampling.t_sample_us);
     fprintf(stream, "ts_eval: %.2f  # tokens / second during generation\n",
             1.0e6 * ctx->n_eval / ctx->t_eval_us);
     fprintf(stream, "ts_p_eval: %.2f  # tokens / second during prompt processing\n",
             1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us);
     fprintf(stream, "ts_sample: %.2f  # tokens / second during sampling\n",
-            1.0e6 * ctx->n_sample / ctx->t_sample_us);
+            1.0e6 * ctx->sampling.n_sample / ctx->sampling.t_sample_us);
 }
 
 // For internal test use
@@ -21721,6 +19136,8 @@ void llama_log_set(ggml_log_callback log_callback, void * user_data) {
     ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
 #elif defined(GGML_USE_CUDA)
     ggml_backend_cuda_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
+#elif defined(GGML_USE_CANN)
+    ggml_backend_cann_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
 #endif
 }
 
@@ -21741,14 +19158,14 @@ static void llama_log_internal_v(ggml_log_level level, const char * format, va_l
     va_end(args_copy);
 }
 
-static void llama_log_internal(ggml_log_level level, const char * format, ...) {
+void llama_log_internal(ggml_log_level level, const char * format, ...) {
     va_list args;
     va_start(args, format);
     llama_log_internal_v(level, format, args);
     va_end(args);
 }
 
-static void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
+void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
     (void) level;
     (void) user_data;
     fputs(text, stderr);
index bb4b05ba636711618429afaa2c65c0cd273649d5..66c266298e86f0204f6d466368abd8d86465a99c 100644 (file)
 
 #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
 
-#define LLAMA_MAX_RNG_STATE (64*1024)
-
 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 6
+#define LLAMA_SESSION_VERSION 8
 
 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
-#define LLAMA_STATE_SEQ_VERSION 1
+#define LLAMA_STATE_SEQ_VERSION 2
 
 #ifdef __cplusplus
 extern "C" {
@@ -92,6 +90,9 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_CHATGLM4       = 17,
         LLAMA_VOCAB_PRE_TYPE_VIKING         = 18,
         LLAMA_VOCAB_PRE_TYPE_JAIS           = 19,
+        LLAMA_VOCAB_PRE_TYPE_TEKKEN         = 20,
+        LLAMA_VOCAB_PRE_TYPE_SMOLLM         = 21,
+        LLAMA_VOCAB_PRE_TYPE_CODESHELL      = 22,
     };
 
     // note: these values should be synchronized with ggml_rope
@@ -133,7 +134,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_F16           = 1,  // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_0          = 2,  // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_1          = 3,  // except 1d tensors
-        LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4,  // tok_embeddings.weight and output.weight are F16
+        // LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4,  // tok_embeddings.weight and output.weight are F16
         // LLAMA_FTYPE_MOSTLY_Q4_2       = 5,  // support has been removed
         // LLAMA_FTYPE_MOSTLY_Q4_3       = 6,  // support has been removed
         LLAMA_FTYPE_MOSTLY_Q8_0          = 7,  // except 1d tensors
@@ -162,6 +163,9 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_IQ4_XS        = 30, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ1_M         = 31, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_BF16          = 32, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q4_0_4_4      = 33, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q4_0_4_8      = 34, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q4_0_8_8      = 35, // except 1d tensors
 
         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };
@@ -341,7 +345,7 @@ extern "C" {
         int32_t nthread;                     // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
         enum llama_ftype ftype;              // quantize to this llama_ftype
         enum ggml_type output_tensor_type;   // output tensor type
-        enum ggml_type token_embedding_type; // itoken embeddings tensor type
+        enum ggml_type token_embedding_type; // token embeddings tensor type
         bool allow_requantize;               // allow quantizing non-f32/f16 tensors
         bool quantize_output_tensor;         // quantize output.weight
         bool only_copy;                      // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
@@ -408,6 +412,9 @@ extern "C" {
         const char * content;
     } llama_chat_message;
 
+    // lora adapter
+    struct llama_lora_adapter;
+
     // Helpers for getting default parameters
     LLAMA_API struct llama_model_params llama_model_default_params(void);
     LLAMA_API struct llama_context_params llama_context_default_params(void);
@@ -507,18 +514,32 @@ extern "C" {
             const char * fname_out,
             const llama_model_quantize_params * params);
 
-    // Apply a LoRA adapter to a loaded model
-    // path_base_model is the path to a higher quality model to use as a base for
-    // the layers modified by the adapter. Can be NULL to use the current loaded model.
-    // The model needs to be reloaded before applying a new adapter, otherwise the adapter
-    // will be applied on top of the previous one
-    // Returns 0 on success
-    LLAMA_API int32_t llama_model_apply_lora_from_file(
-            const struct llama_model * model,
-                          const char * path_lora,
-                               float   scale,
-                          const char * path_base_model,
-                             int32_t   n_threads);
+    // Load a LoRA adapter from file
+    // The loaded adapter will be associated to the given model, and will be free when the model is deleted
+    LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init(
+            struct llama_model * model,
+            const char * path_lora);
+
+    // Add a loaded LoRA adapter to given context
+    // This will not modify model's weight
+    LLAMA_API int32_t llama_lora_adapter_set(
+            struct llama_context * ctx,
+            struct llama_lora_adapter * adapter,
+            float scale);
+
+    // Remove a specific LoRA adapter from given context
+    // Return -1 if the adapter is not present in the context
+    LLAMA_API int32_t llama_lora_adapter_remove(
+            struct llama_context * ctx,
+            struct llama_lora_adapter * adapter);
+
+    // Remove all LoRA adapters from given context
+    LLAMA_API void llama_lora_adapter_clear(
+            struct llama_context * ctx);
+
+    // Manually free a LoRA adapter
+    // Note: loaded adapters will be free when the associated model is deleted
+    LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter);
 
     // Apply a loaded control vector to a llama_context, or if data is NULL, clear
     // the currently loaded vector.
@@ -668,10 +689,11 @@ extern "C" {
     // State / sessions
     //
 
-    // Returns the maximum size in bytes of the state (rng, logits, embedding
-    // and kv_cache) - will often be smaller after compacting tokens
-    LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
-    LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx),
+    // Returns the *actual* size in bytes of the state
+    // (rng, logits, embedding and kv_cache)
+    // Only use when saving the state, not when restoring it, otherwise the size may be too small.
+    LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
+    LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
         "use llama_state_get_size instead");
 
     // Copies the state to the specified destination address.
@@ -679,7 +701,8 @@ extern "C" {
     // Returns the number of bytes copied
     LLAMA_API size_t llama_state_get_data(
             struct llama_context * ctx,
-                         uint8_t * dst);
+                         uint8_t * dst,
+                          size_t   size);
     LLAMA_API DEPRECATED(size_t llama_copy_state_data(
             struct llama_context * ctx,
                          uint8_t * dst),
@@ -689,7 +712,8 @@ extern "C" {
     // Returns the number of bytes read
     LLAMA_API size_t llama_state_set_data(
             struct llama_context * ctx,
-                   const uint8_t * src);
+                   const uint8_t * src,
+                          size_t   size);
     LLAMA_API DEPRECATED(size_t llama_set_state_data(
             struct llama_context * ctx,
                    const uint8_t * src),
@@ -731,6 +755,7 @@ extern "C" {
     LLAMA_API size_t llama_state_seq_get_data(
             struct llama_context * ctx,
                          uint8_t * dst,
+                          size_t   size,
                     llama_seq_id   seq_id);
 
     // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
@@ -740,6 +765,7 @@ extern "C" {
     LLAMA_API size_t llama_state_seq_set_data(
             struct llama_context * ctx,
                    const uint8_t * src,
+                          size_t   size,
                     llama_seq_id   dest_seq_id);
 
     LLAMA_API size_t llama_state_seq_save_file(
@@ -887,10 +913,10 @@ extern "C" {
     LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding
 
     // Returns -1 if unknown, 1 for true or 0 for false.
-    LLAMA_API int32_t         llama_add_bos_token(const struct llama_model * model);
+    LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model);
 
     // Returns -1 if unknown, 1 for true or 0 for false.
-    LLAMA_API int32_t         llama_add_eos_token(const struct llama_model * model);
+    LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
 
     // Codellama infill tokens
     LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
@@ -946,6 +972,10 @@ extern "C" {
                             bool   remove_special,
                             bool   unparse_special);
 
+    //
+    // Chat templates
+    //
+
     /// Apply chat template. Inspired by hf apply_chat_template() on python.
     /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
     /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
@@ -984,6 +1014,23 @@ extern "C" {
 
     LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
 
+    /// @details Apply constraints from grammar
+    LLAMA_API void llama_grammar_sample(
+            const struct llama_grammar * grammar,
+            const struct llama_context * ctx,
+                llama_token_data_array * candidates);
+    LLAMA_API DEPRECATED(void llama_sample_grammar(
+            struct llama_context * ctx,
+          llama_token_data_array * candidates,
+      const struct llama_grammar * grammar),
+        "use llama_grammar_sample instead");
+
+    /// @details Accepts the sampled token into the grammar
+    LLAMA_API void llama_grammar_accept_token(
+            struct llama_grammar * grammar,
+            struct llama_context * ctx,
+                     llama_token   token);
+
     //
     // Sampling functions
     //
@@ -1065,12 +1112,6 @@ extern "C" {
           llama_token_data_array * candidates,
                            float   temp);
 
-    /// @details Apply constraints from grammar
-    LLAMA_API void llama_sample_grammar(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-      const struct llama_grammar * grammar);
-
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@@ -1108,12 +1149,6 @@ extern "C" {
             struct llama_context * ctx,
           llama_token_data_array * candidates);
 
-    /// @details Accepts the sampled token into the grammar
-    LLAMA_API void llama_grammar_accept_token(
-            struct llama_context * ctx,
-            struct llama_grammar * grammar,
-                     llama_token   token);
-
     //
     // Model split
     //
@@ -1156,38 +1191,45 @@ extern "C" {
 
 struct ggml_tensor;
 
+const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
+    struct llama_context * ctx
+);
+
 struct llama_partial_utf8 {
     uint32_t value;    // bit value so far (unshifted)
     int      n_remain; // num bytes remaining; -1 indicates invalid sequence
 };
 
-struct llama_grammar {
-    const std::vector<std::vector<llama_grammar_element>>   rules;
-    std::vector<std::vector<const llama_grammar_element *>> stacks;
-
-    // buffer for partially generated UTF-8 sequence from accepted tokens
-    llama_partial_utf8                                      partial_utf8;
-};
-
 struct llama_grammar_candidate {
     size_t               index;
     const uint32_t     * code_points;
     llama_partial_utf8   partial_utf8;
 };
 
-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
-    struct llama_context * ctx
-);
+using llama_grammar_rule  = std::vector<      llama_grammar_element>;
+using llama_grammar_stack = std::vector<const llama_grammar_element *>;
+
+using llama_grammar_rules      = std::vector<llama_grammar_rule>;
+using llama_grammar_stacks     = std::vector<llama_grammar_stack>;
+using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
+
+const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
+      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
 
 void llama_grammar_accept(
-        const std::vector<std::vector<llama_grammar_element>>         & rules,
-        const std::vector<std::vector<const llama_grammar_element *>> & stacks,
-        const uint32_t                                                  chr,
-        std::vector<std::vector<const llama_grammar_element *>>       & new_stacks);
+        const llama_grammar_rules  & rules,
+        const llama_grammar_stacks & stacks,
+        const uint32_t chr,
+              llama_grammar_stacks & new_stacks);
+
+std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
+        const llama_grammar_rules & rules,
+        const llama_grammar_stack & stack,
+        const llama_grammar_candidates & candidates);
 
 std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
         const std::string & src,
-        llama_partial_utf8   partial_start);
+        llama_partial_utf8 partial_start);
 
 // Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
 // This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
index 51daa15afa66927d5cf984474b941896fc9edd9f..46650bff06d15e9146888c332e4c79913ef09fcb 100644 (file)
@@ -1,3 +1,7 @@
+#if defined(_MSC_VER)
+#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
+#endif
+
 #include "unicode.h"
 #include "unicode-data.h"
 
 #include <locale>
 #include <codecvt>
 
+size_t unicode_len_utf8(char src) {
+    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+    uint8_t highbits = static_cast<uint8_t>(src) >> 4;
+    return lookup[highbits];
+}
+
 static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
     std::string result;
     for (size_t i = 0; i < cps.size(); ++i) {
index 30b07ba7fa4935f441d0a738419be35c0476530d..008532a242ab8d44141577e65b2445df8060fee5 100644 (file)
@@ -4,6 +4,8 @@
 #include <string>
 #include <vector>
 
+// TODO: prefix all symbols with "llama_"
+
 struct codepoint_flags {
     enum {
         UNDEFINED       = 0x0001,
@@ -46,6 +48,7 @@ struct codepoint_flags {
     }
 };
 
+size_t unicode_len_utf8(char src);
 
 std::string unicode_cpt_to_utf8(uint32_t cp);
 uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
index 42eeebfde85de235eee4a8c4e9120a49366b31c0..d5450bdd5aca2aa3ac452f62fe5234f3c8f44be0 100755 (executable)
@@ -2,7 +2,8 @@
 
 cp -rpv ../llama.cpp/include/llama.h ./examples/talk-llama/llama.h
 
-cp -rpv ../llama.cpp/src/llama.cpp        ./examples/talk-llama/llama.cpp
+cp -rpv ../llama.cpp/src/llama*.cpp       ./examples/talk-llama/
+cp -rpv ../llama.cpp/src/llama*.h         ./examples/talk-llama/
 cp -rpv ../llama.cpp/src/unicode.h        ./examples/talk-llama/unicode.h
 cp -rpv ../llama.cpp/src/unicode.cpp      ./examples/talk-llama/unicode.cpp
 cp -rpv ../llama.cpp/src/unicode-data.h   ./examples/talk-llama/unicode-data.h