]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : sync llama.cpp
authorGeorgi Gerganov <redacted>
Tue, 24 Sep 2024 10:22:55 +0000 (13:22 +0300)
committerGeorgi Gerganov <redacted>
Tue, 24 Sep 2024 16:45:08 +0000 (19:45 +0300)
14 files changed:
Makefile
examples/CMakeLists.txt
examples/talk-llama/llama-grammar.cpp
examples/talk-llama/llama-grammar.h
examples/talk-llama/llama-impl.h
examples/talk-llama/llama-sampling.cpp
examples/talk-llama/llama-sampling.h
examples/talk-llama/llama-vocab.cpp
examples/talk-llama/llama-vocab.h
examples/talk-llama/llama.cpp
examples/talk-llama/llama.h
examples/talk-llama/talk-llama.cpp
examples/talk-llama/unicode.cpp
src/whisper.cpp

index 3e359f2fe1ac8a511ca77f1805feb24108a02e46..807f98f857b9d1523376e4dfd446a10b1f114047 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -1080,10 +1080,12 @@ lsp: examples/lsp/lsp.cpp \
        $(CXX) $(CXXFLAGS) $(CFLAGS_SDL) -c $< -o $(call GET_OBJ_FILE, $<)
        $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LDFLAGS_SDL)
 
-talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp \
-       $(OBJ_GGML) $(OBJ_WHISPER) $(OBJ_COMMON) $(OBJ_SDL)
-       $(CXX) $(CXXFLAGS) $(CFLAGS_SDL) -c $< -o $(call GET_OBJ_FILE, $<)
-       $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LDFLAGS_SDL)
+# TODO: disabled until update
+#       https://github.com/ggerganov/whisper.cpp/issues/1818
+#talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp \
+#      $(OBJ_GGML) $(OBJ_WHISPER) $(OBJ_COMMON) $(OBJ_SDL)
+#      $(CXX) $(CXXFLAGS) $(CFLAGS_SDL) -c $< -o $(call GET_OBJ_FILE, $<)
+#      $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LDFLAGS_SDL)
 
 talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp examples/talk-llama/llama-vocab.cpp examples/talk-llama/llama-grammar.cpp examples/talk-llama/llama-sampling.cpp examples/talk-llama/unicode.cpp examples/talk-llama/unicode-data.cpp \
        $(OBJ_GGML) $(OBJ_WHISPER) $(OBJ_COMMON) $(OBJ_SDL)
index cc091d716d1bae30ac935ec97fa59fae25cc01a5..163f425f64b67c9e21cc9cb8e687110842f5cdbc 100644 (file)
@@ -127,8 +127,10 @@ endif (WHISPER_SDL2)
     add_subdirectory(quantize)
     set_target_properties(quantize PROPERTIES FOLDER "examples")
 if (WHISPER_SDL2)
-    add_subdirectory(talk)
-    set_target_properties(talk PROPERTIES FOLDER "examples")
+    # TODO: disabled until update
+    #       https://github.com/ggerganov/whisper.cpp/issues/1818
+    #add_subdirectory(talk)
+    #set_target_properties(talk PROPERTIES FOLDER "examples")
     add_subdirectory(talk-llama)
     set_target_properties(talk-llama PROPERTIES FOLDER "examples")
     add_subdirectory(lsp)
index b123d733100ce836878170b7b2048c376bd91655..74e9f64b393b2f2e144f78b1e30830771e91099b 100644 (file)
@@ -3,11 +3,31 @@
 #include "llama-vocab.h"
 #include "llama-sampling.h"
 
+#include <cmath>
 #include <algorithm>
+#include <stdexcept>
 
-// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
-// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
-std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
+//
+// helpers
+//
+
+// NOTE: assumes valid utf8 (but checks for overrun)
+static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
+    static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+    uint8_t  first_byte = static_cast<uint8_t>(*src);
+    uint8_t  highbits   = first_byte >> 4;
+    int      len        = lookup[highbits];
+    uint8_t  mask       = (1 << (8 - len)) - 1;
+    uint32_t value      = first_byte & mask;
+    const char * end    = src + len; // may overrun!
+    const char * pos    = src + 1;
+    for ( ; pos < end && *pos; pos++) {
+        value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+    }
+    return std::make_pair(value, pos);
+}
+
+static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
         const std::string & src,
         llama_partial_utf8 partial_start) {
     static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
@@ -40,7 +60,7 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
     while (*pos != 0) {
         uint8_t first_byte = static_cast<uint8_t>(*pos);
         uint8_t highbits   = first_byte >> 4;
-                n_remain   = lookup[highbits] - 1;
+        n_remain   = lookup[highbits] - 1;
 
         if (n_remain < 0) {
             // invalid sequence, abort
@@ -50,7 +70,7 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
         }
 
         uint8_t mask  = (1 << (7 - n_remain)) - 1;
-                value = first_byte & mask;
+        value = first_byte & mask;
 
         ++pos;
         while (*pos != 0 && n_remain > 0) {
@@ -67,12 +87,510 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
     return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
 }
 
-const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
-    return grammar->rules;
+static bool is_digit_char(char c) {
+    return '0' <= c && c <= '9';
 }
 
-llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
-    return grammar->stacks;
+static bool is_word_char(char c) {
+    return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
+}
+
+static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
+    const char * pos   = src;
+    const char * end   = src + size;
+    uint32_t     value = 0;
+    for ( ; pos < end && *pos; pos++) {
+        value <<= 4;
+        char c = *pos;
+        if ('a' <= c && c <= 'f') {
+            value += c - 'a' + 10;
+        } else if ('A' <= c && c <= 'F') {
+            value += c - 'A' + 10;
+        } else if ('0' <= c && c <= '9') {
+            value += c - '0';
+        } else {
+            break;
+        }
+    }
+    if (pos != end) {
+        throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
+    }
+    return std::make_pair(value, pos);
+}
+
+static const char * parse_space(const char * src, bool newline_ok) {
+    const char * pos = src;
+    while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
+            (newline_ok && (*pos == '\r' || *pos == '\n'))) {
+        if (*pos == '#') {
+            while (*pos && *pos != '\r' && *pos != '\n') {
+                pos++;
+            }
+        } else {
+            pos++;
+        }
+    }
+    return pos;
+}
+
+static const char * parse_name(const char * src) {
+    const char * pos = src;
+    while (is_word_char(*pos)) {
+        pos++;
+    }
+    if (pos == src) {
+        throw std::runtime_error(std::string("expecting name at ") + src);
+    }
+    return pos;
+}
+
+static const char * parse_int(const char * src) {
+    const char * pos = src;
+    while (is_digit_char(*pos)) {
+        pos++;
+    }
+    if (pos == src) {
+        throw std::runtime_error(std::string("expecting integer at ") + src);
+    }
+    return pos;
+}
+
+static std::pair<uint32_t, const char *> parse_char(const char * src) {
+    if (*src == '\\') {
+        switch (src[1]) {
+            case 'x': return parse_hex(src + 2, 2);
+            case 'u': return parse_hex(src + 2, 4);
+            case 'U': return parse_hex(src + 2, 8);
+            case 't': return std::make_pair('\t', src + 2);
+            case 'r': return std::make_pair('\r', src + 2);
+            case 'n': return std::make_pair('\n', src + 2);
+            case '\\':
+            case '"':
+            case '[':
+            case ']':
+                      return std::make_pair(src[1], src + 2);
+            default:
+                      throw std::runtime_error(std::string("unknown escape at ") + src);
+        }
+    } else if (*src) {
+        return decode_utf8(src);
+    }
+    throw std::runtime_error("unexpected end of input");
+}
+
+static void print_grammar_char(FILE * file, uint32_t c) {
+    if (0x20 <= c && c <= 0x7f) {
+        fprintf(file, "%c", static_cast<char>(c));
+    } else {
+        // cop out of encoding UTF-8
+        fprintf(file, "<U+%04X>", c);
+    }
+}
+
+static bool is_char_element(llama_grammar_element elem) {
+    switch (elem.type) {
+        case LLAMA_GRETYPE_CHAR:           return true;
+        case LLAMA_GRETYPE_CHAR_NOT:       return true;
+        case LLAMA_GRETYPE_CHAR_ALT:       return true;
+        case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
+        case LLAMA_GRETYPE_CHAR_ANY:       return true;
+        default:                           return false;
+    }
+}
+
+static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
+    for (auto elem : rule) {
+        switch (elem.type) {
+            case LLAMA_GRETYPE_END:            fprintf(file, "END");            break;
+            case LLAMA_GRETYPE_ALT:            fprintf(file, "ALT");            break;
+            case LLAMA_GRETYPE_RULE_REF:       fprintf(file, "RULE_REF");       break;
+            case LLAMA_GRETYPE_CHAR:           fprintf(file, "CHAR");           break;
+            case LLAMA_GRETYPE_CHAR_NOT:       fprintf(file, "CHAR_NOT");       break;
+            case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
+            case LLAMA_GRETYPE_CHAR_ALT:       fprintf(file, "CHAR_ALT");       break;
+            case LLAMA_GRETYPE_CHAR_ANY:       fprintf(file, "CHAR_ANY");       break;
+        }
+        switch (elem.type) {
+            case LLAMA_GRETYPE_END:
+            case LLAMA_GRETYPE_ALT:
+            case LLAMA_GRETYPE_RULE_REF:
+                fprintf(file, "(%u) ", elem.value);
+                break;
+            case LLAMA_GRETYPE_CHAR:
+            case LLAMA_GRETYPE_CHAR_NOT:
+            case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+            case LLAMA_GRETYPE_CHAR_ALT:
+            case LLAMA_GRETYPE_CHAR_ANY:
+                fprintf(file, "(\"");
+                print_grammar_char(file, elem.value);
+                fprintf(file, "\") ");
+                break;
+        }
+    }
+    fprintf(file, "\n");
+}
+
+static void print_rule(
+        FILE     * file,
+        uint32_t   rule_id,
+        const llama_grammar_rule & rule,
+        const std::map<uint32_t, std::string> & symbol_id_names) {
+    if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
+        throw std::runtime_error(
+            "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
+    }
+    fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
+    for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
+        llama_grammar_element elem = rule[i];
+        switch (elem.type) {
+            case LLAMA_GRETYPE_END:
+                throw std::runtime_error(
+                    "unexpected end of rule: " + std::to_string(rule_id) + "," +
+                    std::to_string(i));
+            case LLAMA_GRETYPE_ALT:
+                fprintf(file, "| ");
+                break;
+            case LLAMA_GRETYPE_RULE_REF:
+                fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
+                break;
+            case LLAMA_GRETYPE_CHAR:
+                fprintf(file, "[");
+                print_grammar_char(file, elem.value);
+                break;
+            case LLAMA_GRETYPE_CHAR_NOT:
+                fprintf(file, "[^");
+                print_grammar_char(file, elem.value);
+                break;
+            case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+                if (i == 0 || !is_char_element(rule[i - 1])) {
+                    throw std::runtime_error(
+                        "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
+                        std::to_string(rule_id) + "," + std::to_string(i));
+                }
+                fprintf(file, "-");
+                print_grammar_char(file, elem.value);
+                break;
+            case LLAMA_GRETYPE_CHAR_ALT:
+                if (i == 0 || !is_char_element(rule[i - 1])) {
+                    throw std::runtime_error(
+                        "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
+                        std::to_string(rule_id) + "," + std::to_string(i));
+                }
+                print_grammar_char(file, elem.value);
+                break;
+            case LLAMA_GRETYPE_CHAR_ANY:
+                fprintf(file, ".");
+                break;
+        }
+        if (is_char_element(elem)) {
+            switch (rule[i + 1].type) {
+                case LLAMA_GRETYPE_CHAR_ALT:
+                case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+                case LLAMA_GRETYPE_CHAR_ANY:
+                    break;
+                default:
+                    fprintf(file, "] ");
+            }
+        }
+    }
+    fprintf(file, "\n");
+}
+
+//
+// implementation
+//
+
+uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) {
+    uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
+    auto result = symbol_ids.emplace(std::string(src, len), next_id);
+    return result.first->second;
+}
+
+uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) {
+    uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
+    symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
+    return next_id;
+}
+
+void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) {
+    if (rules.size() <= rule_id) {
+        rules.resize(rule_id + 1);
+    }
+    rules[rule_id] = rule;
+}
+
+const char * llama_grammar_parser::parse_alternates(
+        const char        * src,
+        const std::string & rule_name,
+        uint32_t            rule_id,
+        bool                is_nested) {
+    llama_grammar_rule rule;
+    const char * pos = parse_sequence(src, rule_name, rule, is_nested);
+    while (*pos == '|') {
+        rule.push_back({LLAMA_GRETYPE_ALT, 0});
+        pos = parse_space(pos + 1, true);
+        pos = parse_sequence(pos, rule_name, rule, is_nested);
+    }
+    rule.push_back({LLAMA_GRETYPE_END, 0});
+    add_rule(rule_id, rule);
+    return pos;
+}
+
+const char * llama_grammar_parser::parse_sequence(
+        const char         * src,
+        const std::string  & rule_name,
+        llama_grammar_rule & rule,
+        bool               is_nested) {
+    size_t last_sym_start = rule.size();
+    const char * pos = src;
+
+        auto handle_repetitions = [&](int min_times, int max_times) {
+
+            if (last_sym_start == rule.size()) {
+                throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
+            }
+
+            // apply transformation to previous symbol (last_sym_start to end) according to
+            // the following rewrite rules:
+            // S{m,n} --> S S S (m times) S'(n-m)
+            //            S'(x)   ::= S S'(x-1) |
+            //            (... n-m definitions of these S' rules ...)
+            //            S'(1)   ::= S |
+            // S{m,} -->  S S S (m times) S'
+            //            S'     ::= S S' |
+            // S*     --> S{0,}
+            //        --> S'     ::= S S' |
+            // S+     --> S{1,}
+            //        --> S S'
+            //            S'     ::= S S' |
+            // S?     --> S{0,1}
+            //        --> S'
+            //            S'     ::= S |
+
+            llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
+            if (min_times == 0) {
+                rule.resize(last_sym_start);
+            } else {
+                // Repeat the previous elements (min_times - 1) times
+                for (int i = 1; i < min_times; i++) {
+                    rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
+                }
+            }
+
+            uint32_t last_rec_rule_id = 0;
+            auto n_opt = max_times < 0 ? 1 : max_times - min_times;
+
+            llama_grammar_rule rec_rule(prev_rule);
+            for (int i = 0; i < n_opt; i++) {
+                rec_rule.resize(prev_rule.size());
+                uint32_t rec_rule_id = generate_symbol_id( rule_name);
+                if (i > 0 || max_times < 0) {
+                    rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
+                }
+                rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
+                rec_rule.push_back({LLAMA_GRETYPE_END, 0});
+                add_rule( rec_rule_id, rec_rule);
+                last_rec_rule_id = rec_rule_id;
+            }
+            if (n_opt > 0) {
+                rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
+            }
+        };
+
+        while (*pos) {
+            if (*pos == '"') { // literal string
+                pos++;
+                last_sym_start = rule.size();
+                while (*pos != '"') {
+                    if (!*pos) {
+                        throw std::runtime_error("unexpected end of input");
+                    }
+                    auto char_pair = parse_char(pos);
+                         pos       = char_pair.second;
+                    rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
+                }
+                pos = parse_space(pos + 1, is_nested);
+            } else if (*pos == '[') { // char range(s)
+                pos++;
+                enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
+                if (*pos == '^') {
+                    pos++;
+                    start_type = LLAMA_GRETYPE_CHAR_NOT;
+                }
+                last_sym_start = rule.size();
+                while (*pos != ']') {
+                    if (!*pos) {
+                        throw std::runtime_error("unexpected end of input");
+                    }
+                    auto char_pair = parse_char(pos);
+                         pos       = char_pair.second;
+                    enum llama_gretype type = last_sym_start < rule.size()
+                        ? LLAMA_GRETYPE_CHAR_ALT
+                        : start_type;
+
+                    rule.push_back({type, char_pair.first});
+                    if (pos[0] == '-' && pos[1] != ']') {
+                        if (!pos[1]) {
+                            throw std::runtime_error("unexpected end of input");
+                        }
+                        auto endchar_pair = parse_char(pos + 1);
+                             pos          = endchar_pair.second;
+                        rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
+                    }
+                }
+                pos = parse_space(pos + 1, is_nested);
+            } else if (is_word_char(*pos)) { // rule reference
+                const char * name_end    = parse_name(pos);
+                uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
+                pos = parse_space(name_end, is_nested);
+                last_sym_start = rule.size();
+                rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
+            } else if (*pos == '(') { // grouping
+                // parse nested alternates into synthesized rule
+                pos = parse_space(pos + 1, true);
+                uint32_t sub_rule_id = generate_symbol_id(rule_name);
+                pos = parse_alternates(pos, rule_name, sub_rule_id, true);
+                last_sym_start = rule.size();
+                // output reference to synthesized rule
+                rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
+                if (*pos != ')') {
+                    throw std::runtime_error(std::string("expecting ')' at ") + pos);
+                }
+                pos = parse_space(pos + 1, is_nested);
+            } else if (*pos == '.') { // any char
+                last_sym_start = rule.size();
+                rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
+                pos = parse_space(pos + 1, is_nested);
+            } else if (*pos == '*') {
+                pos = parse_space(pos + 1, is_nested);
+                handle_repetitions(0, -1);
+            } else if (*pos == '+') {
+                pos = parse_space(pos + 1, is_nested);
+                handle_repetitions(1, -1);
+            } else if (*pos == '?') {
+                pos = parse_space(pos + 1, is_nested);
+                handle_repetitions(0, 1);
+            } else if (*pos == '{') {
+                pos = parse_space(pos + 1, is_nested);
+
+                if (!is_digit_char(*pos)) {
+                    throw std::runtime_error(std::string("expecting an int at ") + pos);
+                }
+                const char * int_end = parse_int(pos);
+                int min_times = std::stoul(std::string(pos, int_end - pos));
+                pos = parse_space(int_end, is_nested);
+
+                int max_times = -1;
+
+                if (*pos == '}') {
+                    max_times = min_times;
+                    pos = parse_space(pos + 1, is_nested);
+                } else if (*pos == ',') {
+                    pos = parse_space(pos + 1, is_nested);
+
+                    if (is_digit_char(*pos)) {
+                        const char * int_end = parse_int(pos);
+                        max_times = std::stoul(std::string(pos, int_end - pos));
+                        pos = parse_space(int_end, is_nested);
+                    }
+
+                    if (*pos != '}') {
+                        throw std::runtime_error(std::string("expecting '}' at ") + pos);
+                    }
+                    pos = parse_space(pos + 1, is_nested);
+                } else {
+                    throw std::runtime_error(std::string("expecting ',' at ") + pos);
+                }
+                handle_repetitions(min_times, max_times);
+            } else {
+                break;
+            }
+        }
+        return pos;
+    }
+
+const char * llama_grammar_parser::parse_rule(const char * src) {
+        const char * name_end = parse_name(src);
+        const char * pos      = parse_space(name_end, false);
+        size_t       name_len = name_end - src;
+        uint32_t     rule_id  = get_symbol_id(src, name_len);
+        const std::string name(src, name_len);
+
+        if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
+            throw std::runtime_error(std::string("expecting ::= at ") + pos);
+        }
+        pos = parse_space(pos + 3, true);
+
+        pos = parse_alternates(pos, name, rule_id, false);
+
+        if (*pos == '\r') {
+            pos += pos[1] == '\n' ? 2 : 1;
+        } else if (*pos == '\n') {
+            pos++;
+        } else if (*pos) {
+            throw std::runtime_error(std::string("expecting newline or end at ") + pos);
+        }
+        return parse_space(pos, true);
+    }
+
+bool llama_grammar_parser::parse(const char * src) {
+    try {
+        const char * pos = parse_space(src, true);
+        while (*pos) {
+            pos = parse_rule(pos);
+        }
+        // Validate the state to ensure that all rules are defined
+        for (const auto & rule : rules) {
+            if (rule.empty()) {
+                throw std::runtime_error("Undefined rule");
+            }
+            for (const auto & elem : rule) {
+                if (elem.type == LLAMA_GRETYPE_RULE_REF) {
+                    // Ensure that the rule at that location exists
+                    if (elem.value >= rules.size() || rules[elem.value].empty()) {
+                        // Get the name of the rule that is missing
+                        for (const auto & kv : symbol_ids) {
+                            if (kv.second == elem.value) {
+                                throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } catch (const std::exception & err) {
+        fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
+        rules.clear();
+        return false;
+    }
+
+    return true;
+}
+
+void llama_grammar_parser::print(FILE * file) {
+    try {
+        std::map<uint32_t, std::string> symbol_id_names;
+        for (const auto & kv : symbol_ids) {
+            symbol_id_names[kv.second] = kv.first;
+        }
+        for (size_t i = 0, end = rules.size(); i < end; i++) {
+            // fprintf(file, "%zu: ", i);
+            // print_rule_binary(file, rules[i]);
+            print_rule(file, uint32_t(i), rules[i], symbol_id_names);
+            // fprintf(file, "\n");
+        }
+    } catch (const std::exception & err) {
+        fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
+    }
+}
+
+llama_grammar_stack llama_grammar_parser::c_rules() const {
+    llama_grammar_stack ret;
+    ret.reserve(rules.size());
+    for (const auto & rule : rules) {
+        ret.push_back(rule.data());
+    }
+    return ret;
 }
 
 // returns true iff pos points to the end of one of the definitions of a rule
@@ -89,7 +607,6 @@ static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos)
 static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
         const llama_grammar_element * pos,
         const uint32_t                chr) {
-
     bool found            = false;
     bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
 
@@ -225,16 +742,93 @@ static void llama_grammar_advance_stack(
     }
 }
 
-// takes a set of possible pushdown stacks on a grammar, which are required to
-// be positioned at a character range (see `llama_grammar_advance_stack`), and
-// produces the N possible stacks if the given char is accepted at those
-// positions
+static llama_grammar_candidates llama_grammar_reject_candidates(
+        const llama_grammar_rules      & rules,
+        const llama_grammar_stacks     & stacks,
+        const llama_grammar_candidates & candidates) {
+    GGML_ASSERT(!stacks.empty()); // REVIEW
+
+    if (candidates.empty()) {
+        return {};
+    }
+
+    auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
+
+    for (size_t i = 1, size = stacks.size(); i < size; ++i) {
+        rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
+    }
+
+    return rejects;
+}
+
+static bool llama_grammar_detect_left_recursion(
+        const llama_grammar_rules & rules,
+        size_t rule_index,
+        std::vector<bool> * rules_visited,
+        std::vector<bool> * rules_in_progress,
+        std::vector<bool> * rules_may_be_empty) {
+    if ((*rules_in_progress)[rule_index]) {
+        return true;
+    }
+
+    (*rules_in_progress)[rule_index] = true;
+
+    const llama_grammar_rule & rule = rules[rule_index];
+
+    // First check if the rule might produce the empty string. This could be done combined with the second
+    // step but it's more readable as two steps.
+    bool at_rule_start = true;
+    for (size_t i = 0; i < rule.size(); i++) {
+        if (llama_grammar_is_end_of_sequence(&rule[i])) {
+            if (at_rule_start) {
+                (*rules_may_be_empty)[rule_index] = true;
+                break;
+            }
+            at_rule_start = true;
+        } else {
+            at_rule_start = false;
+        }
+    }
+
+    // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
+    // be empty)
+    bool recurse_into_nonterminal = true;
+    for (size_t i = 0; i < rule.size(); i++) {
+        if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
+            if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
+                return true;
+            }
+            if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
+                recurse_into_nonterminal = false;
+            }
+        } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
+            recurse_into_nonterminal = true;
+        } else {
+            recurse_into_nonterminal = false;
+        }
+    }
+
+    (*rules_in_progress)[rule_index] = false;
+    (*rules_visited)[rule_index] = true;
+
+    return false;
+}
+
+const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
+    return grammar->rules;
+}
+
+llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
+    return grammar->stacks;
+}
+
 void llama_grammar_accept(
         const llama_grammar_rules  & rules,
         const llama_grammar_stacks & stacks,
         const uint32_t               chr,
-              llama_grammar_stacks & new_stacks) {
-    new_stacks.clear();
+              llama_grammar_stacks & stacks_new) {
+    stacks_new.clear();
+    stacks_new.reserve(stacks.size());
 
     for (const auto & stack : stacks) {
         if (stack.empty()) {
@@ -250,29 +844,11 @@ void llama_grammar_accept(
             if (!llama_grammar_is_end_of_sequence(pos)) {
                 new_stack.push_back(pos);
             }
-            llama_grammar_advance_stack(rules, new_stack, new_stacks);
+            llama_grammar_advance_stack(rules, new_stack, stacks_new);
         }
     }
 }
 
-static llama_grammar_candidates llama_grammar_reject_candidates(
-        const llama_grammar_rules  & rules,
-        const llama_grammar_stacks & stacks,
-        const llama_grammar_candidates & candidates) {
-    GGML_ASSERT(!stacks.empty()); // REVIEW
-
-    if (candidates.empty()) {
-        return {};
-    }
-
-    auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
-
-    for (size_t i = 1, size = stacks.size(); i < size; ++i) {
-        rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
-    }
-    return rejects;
-}
-
 llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
         const llama_grammar_rules      & rules,
         const llama_grammar_stack      & stack,
@@ -328,72 +904,97 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
     return rejects;
 }
 
-static bool llama_grammar_detect_left_recursion(
-        const llama_grammar_rules & rules,
-        size_t rule_index,
-        std::vector<bool> * rules_visited,
-        std::vector<bool> * rules_in_progress,
-        std::vector<bool> * rules_may_be_empty) {
-    if ((*rules_in_progress)[rule_index]) {
-        return true;
-    }
+////////////////////
 
-    (*rules_in_progress)[rule_index] = true;
+struct llama_grammar * llama_grammar_init_impl(
+        const struct llama_vocab * vocab,
+        const llama_grammar_element ** rules,
+        size_t n_rules,
+        size_t start_rule_index) {
+    const llama_grammar_element * pos;
 
-    const llama_grammar_rule & rule = rules[rule_index];
+    // copy rule definitions into vectors
+    llama_grammar_rules vec_rules(n_rules);
+    for (size_t i = 0; i < n_rules; i++) {
+        for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
+            vec_rules[i].push_back(*pos);
+        }
+        vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
+    }
 
-    // First check if the rule might produce the empty string. This could be done combined with the second
-    // step but it's more readable as two steps.
-    bool at_rule_start = true;
-    for (size_t i = 0; i < rule.size(); i++) {
-        if (llama_grammar_is_end_of_sequence(&rule[i])) {
-            if (at_rule_start) {
-                (*rules_may_be_empty)[rule_index] = true;
-                break;
-            }
-            at_rule_start = true;
-        } else {
-            at_rule_start = false;
+    // Check for left recursion
+    std::vector<bool> rules_visited(n_rules);
+    std::vector<bool> rules_in_progress(n_rules);
+    std::vector<bool> rules_may_be_empty(n_rules);
+    for (size_t i = 0; i < n_rules; i++) {
+        if (rules_visited[i]) {
+            continue;
+        }
+        if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
+            LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
+            return nullptr;
         }
     }
 
-    // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
-    // be empty)
-    bool recurse_into_nonterminal = true;
-    for (size_t i = 0; i < rule.size(); i++) {
-        if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
-            if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
-                return true;
-            }
-            if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
-                recurse_into_nonterminal = false;
-            }
-        } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
-            recurse_into_nonterminal = true;
+    // loop over alternates of start rule to build initial stacks
+    llama_grammar_stacks stacks;
+    pos = vec_rules[start_rule_index].data();
+    do {
+        llama_grammar_stack stack;
+        if (!llama_grammar_is_end_of_sequence(pos)) {
+            // if alternate is nonempty, add to stack
+            stack.push_back(pos);
+        }
+        llama_grammar_advance_stack(vec_rules, stack, stacks);
+        while (!llama_grammar_is_end_of_sequence(pos)) {
+            // scan to end of alternate def
+            pos++;
+        }
+        if (pos->type == LLAMA_GRETYPE_ALT) {
+            // there's another alternate def of this rule to process
+            pos++;
         } else {
-            recurse_into_nonterminal = false;
+            break;
         }
-    }
+    } while (true);
 
-    (*rules_in_progress)[rule_index] = false;
-    (*rules_visited)[rule_index] = true;
-    return false;
+    // Important: vec_rules has to be moved here, not copied, because stacks contains
+    // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
+    // then the pointers would be invalidated when the local vec_rules goes out of scope.
+    return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
 }
 
-//
-// grammar - external
-//
+struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
+    llama_grammar_parser parser;
+
+    // if there is a grammar, parse it
+    if (!parser.parse(grammar_str)) {
+        return nullptr;
+    }
+
+    // will be empty (default) if there are parse errors
+    if (parser.rules.empty()) {
+        fprintf(stderr, "%s: failed to parse grammar\n", __func__);
+        return nullptr;
+    }
+
+    // Ensure that there is a "root" node.
+    if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) {
+        fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
+        return nullptr;
+    }
+
+    std::vector<const llama_grammar_element *> grammar_rules(parser.c_rules());
+
+    const size_t n_rules = grammar_rules.size();
+    const size_t start_rule_index = parser.symbol_ids.at(grammar_root);
 
-struct llama_grammar * llama_grammar_init_impl(
-            const llama_grammar_element ** rules,
-                                 size_t    n_rules,
-                                 size_t    start_rule_index) {
     const llama_grammar_element * pos;
 
     // copy rule definitions into vectors
     llama_grammar_rules vec_rules(n_rules);
     for (size_t i = 0; i < n_rules; i++) {
-        for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
+        for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
             vec_rules[i].push_back(*pos);
         }
         vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
@@ -438,22 +1039,26 @@ struct llama_grammar * llama_grammar_init_impl(
     // Important: vec_rules has to be moved here, not copied, because stacks contains
     // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
     // then the pointers would be invalidated when the local vec_rules goes out of scope.
-    return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
+    return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
 }
 
 void llama_grammar_free_impl(struct llama_grammar * grammar) {
+    if (grammar == nullptr) {
+        return;
+    }
+
     delete grammar;
 }
 
-struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) {
-    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
+struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
+    llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
 
     // redirect elements in stacks to point to new rules
     for (size_t is = 0; is < result->stacks.size(); is++) {
         for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
-            for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
-                for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
-                    if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
+            for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
+                for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
+                    if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
                          result->stacks[is][ie]  =  &result->rules[ir0][ir1];
                     }
                 }
@@ -464,14 +1069,11 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * gram
     return result;
 }
 
-void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    GGML_ASSERT(grammar);
-    GGML_ASSERT(vocab);
-
-    int64_t t_start_sample_us = ggml_time_us();
+void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
+    GGML_ASSERT(grammar.vocab != nullptr);
 
     bool allow_eog = false;
-    for (const auto & stack : grammar->stacks) {
+    for (const auto & stack : grammar.stacks) {
         if (stack.empty()) {
             allow_eog = true;
             break;
@@ -479,40 +1081,38 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc
     }
 
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
-    candidates_decoded.reserve(candidates->size);
+    candidates_decoded.reserve(cur_p->size);
 
     llama_grammar_candidates candidates_grammar;
-    candidates_grammar.reserve(candidates->size);
+    candidates_grammar.reserve(cur_p->size);
 
-    for (size_t i = 0; i < candidates->size; ++i) {
-        const llama_token id      = candidates->data[i].id;
-        const std::string & piece = vocab->cache_token_to_piece.at(id);
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        const llama_token id      = cur_p->data[i].id;
+        const std::string & piece = grammar.vocab->cache_token_to_piece.at(id);
 
-        if (llama_token_is_eog_impl(*vocab, id)) {
+        if (llama_token_is_eog_impl(*grammar.vocab, id)) {
             if (!allow_eog) {
-                candidates->data[i].logit = -INFINITY;
+                cur_p->data[i].logit = -INFINITY;
             }
         } else if (piece.empty() || piece[0] == 0) {
-            candidates->data[i].logit = -INFINITY;
+            cur_p->data[i].logit = -INFINITY;
         } else {
-            candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
+            candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
             candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
         }
     }
 
-    const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
+    const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
     for (const auto & reject : rejects) {
-        candidates->data[reject.index].logit = -INFINITY;
+        cur_p->data[reject.index].logit = -INFINITY;
     }
-
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
 }
 
-void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
-    const int64_t t_start_sample_us = ggml_time_us();
+void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
+    GGML_ASSERT(grammar.vocab != nullptr);
 
-    if (llama_token_is_eog_impl(*vocab, token)) {
-        for (const auto & stack : grammar->stacks) {
+    if (llama_token_is_eog_impl(*grammar.vocab, token)) {
+        for (const auto & stack : grammar.stacks) {
             if (stack.empty()) {
                 return;
             }
@@ -520,20 +1120,19 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc
         GGML_ABORT("fatal error");
     }
 
-    const std::string & piece = vocab->cache_token_to_piece.at(token);
+    const std::string & piece = grammar.vocab->cache_token_to_piece.at(token);
 
     // Note terminating 0 in decoded string
-    const auto   decoded     = decode_utf8(piece, grammar->partial_utf8);
+    const auto   decoded     = decode_utf8(piece, grammar.partial_utf8);
     const auto & code_points = decoded.first;
 
-    llama_grammar_stacks tmp_new_stacks;
+    llama_grammar_stacks stacks_new;
+
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
-        grammar->stacks = tmp_new_stacks;
+        llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
+        grammar.stacks = std::move(stacks_new);
     }
 
-    grammar->partial_utf8 = decoded.second;
-    GGML_ASSERT(!grammar->stacks.empty());
-
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    grammar.partial_utf8 = decoded.second;
+    GGML_ASSERT(!grammar.stacks.empty());
 }
index 695ea0632bb84c698db84e833a615d32199821c1..f529ce351e4167d03cbeb538018047a6287c1c02 100644 (file)
 
 #include "llama-impl.h"
 
+#include <map>
+
 struct llama_vocab;
-struct llama_sampling;
+
+// grammar element type
+enum llama_gretype {
+    // end of rule definition
+    LLAMA_GRETYPE_END            = 0,
+
+    // start of alternate definition for rule
+    LLAMA_GRETYPE_ALT            = 1,
+
+    // non-terminal element: reference to rule
+    LLAMA_GRETYPE_RULE_REF       = 2,
+
+    // terminal element: character (code point)
+    LLAMA_GRETYPE_CHAR           = 3,
+
+    // inverse char(s) ([^a], [^a-b] [^abc])
+    LLAMA_GRETYPE_CHAR_NOT       = 4,
+
+    // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
+    // be an inclusive range ([a-z])
+    LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
+
+    // modifies a preceding LLAMA_GRETYPE_CHAR or
+    // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
+    LLAMA_GRETYPE_CHAR_ALT       = 6,
+
+    // any character (.)
+    LLAMA_GRETYPE_CHAR_ANY       = 7,
+};
+
+typedef struct llama_grammar_element {
+    enum llama_gretype type;
+    uint32_t           value; // Unicode code point or rule ID
+} llama_grammar_element;
+
+struct llama_partial_utf8 {
+    uint32_t value;    // bit value so far (unshifted)
+    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
+};
+
+struct llama_grammar_candidate {
+    size_t               index;
+    const uint32_t     * code_points;
+    llama_partial_utf8   partial_utf8;
+};
+
+using llama_grammar_rule  = std::vector<      llama_grammar_element>;
+using llama_grammar_stack = std::vector<const llama_grammar_element *>;
+
+using llama_grammar_rules      = std::vector<llama_grammar_rule>;
+using llama_grammar_stacks     = std::vector<llama_grammar_stack>;
+using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
+
+const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
+      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
+
+// takes a set of possible pushdown stacks on a grammar, which are required to
+// be positioned at a character range (see `llama_grammar_advance_stack`), and
+// produces the N possible stacks if the given char is accepted at those
+// positions
+void llama_grammar_accept(
+        const llama_grammar_rules  & rules,
+        const llama_grammar_stacks & stacks,
+                          uint32_t   chr,
+              llama_grammar_stacks & stacks_new);
+
+std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
+        const llama_grammar_rules      & rules,
+        const llama_grammar_stack      & stack,
+        const llama_grammar_candidates & candidates);
+
+struct llama_grammar_parser {
+    std::map<std::string, uint32_t> symbol_ids;
+
+    llama_grammar_rules rules;
+
+    llama_grammar_stack c_rules() const;
+
+    uint32_t get_symbol_id(const char * src, size_t len);
+    uint32_t generate_symbol_id(const std::string & base_name);
+
+    void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);
+
+    const char * parse_alternates(
+            const char        * src,
+            const std::string & rule_name,
+            uint32_t            rule_id,
+            bool                is_nested);
+
+    const char * parse_sequence(
+            const char         * src,
+            const std::string  & rule_name,
+            llama_grammar_rule & rule,
+            bool               is_nested);
+
+    const char * parse_rule(const char * src);
+
+    bool parse(const char * src);
+    void print(FILE * file);
+};
 
 struct llama_grammar {
-    const llama_grammar_rules  rules;
+    // note: allow null vocab for testing (not great)
+    const llama_vocab * vocab;
+
+    const llama_grammar_rules  rules;  // TODO: shared ptr
           llama_grammar_stacks stacks;
 
     // buffer for partially generated UTF-8 sequence from accepted tokens
@@ -17,23 +121,24 @@ struct llama_grammar {
 // internal API
 //
 
+// note: needed for tests (not great)
 struct llama_grammar * llama_grammar_init_impl(
-            const llama_grammar_element ** rules,
-                                 size_t    n_rules,
-                                 size_t    start_rule_index);
+        const struct llama_vocab * vocab,
+        const llama_grammar_element ** rules,
+        size_t n_rules,
+        size_t start_rule_index);
+
+struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
 
 void llama_grammar_free_impl(struct llama_grammar * grammar);
 
-struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar);
+struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);
 
-void llama_grammar_sample_impl(
-        const struct llama_grammar * grammar,
-          const struct llama_vocab * vocab,
-       const struct llama_sampling * smpl,
-            llama_token_data_array * candidates);
+// TODO: move the API below as member functions of llama_grammar
+void llama_grammar_apply_impl(
+        const struct llama_grammar & grammar,
+            llama_token_data_array * cur_p);
 
-void llama_grammar_accept_token_impl(
-              struct llama_grammar * grammar,
-          const struct llama_vocab * vocab,
-       const struct llama_sampling * smpl,
+void llama_grammar_accept_impl(
+              struct llama_grammar & grammar,
                        llama_token   token);
index 9527740961da652657d15b898d973ad8aac6d33c..70f16b61c12e07f6e5735030639b894dd9c110a6 100644 (file)
@@ -1,8 +1,11 @@
 #pragma once
 
-#define LLAMA_API_INTERNAL
 #include "llama.h"
 
+#include <string>
+#include <vector>
+#include <stdexcept>
+
 #ifdef __GNUC__
 #ifdef __MINGW32__
 #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
@@ -21,14 +24,31 @@ LLAMA_ATTRIBUTE_FORMAT(2, 3)
 void llama_log_internal        (ggml_log_level level, const char * format, ...);
 void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data);
 
+#define LLAMA_LOG(...)       llama_log_internal(GGML_LOG_LEVEL_NONE , __VA_ARGS__)
 #define LLAMA_LOG_INFO(...)  llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
 #define LLAMA_LOG_WARN(...)  llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
 #define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+#define LLAMA_LOG_DEBUG(...) llama_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
+#define LLAMA_LOG_CONT(...)  llama_log_internal(GGML_LOG_LEVEL_CONT , __VA_ARGS__)
 
 //
 // helpers
 //
 
+struct time_meas {
+    time_meas(int64_t & t_acc, bool disable = false) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
+
+    ~time_meas() {
+        if (t_start_us >= 0) {
+            t_acc += ggml_time_us() - t_start_us;
+        }
+    }
+
+    const int64_t t_start_us;
+
+    int64_t & t_acc;
+};
+
 static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
     if (search.empty()) {
         return;
@@ -45,3 +65,117 @@ static void replace_all(std::string & s, const std::string & search, const std::
     builder.append(s, last_pos, std::string::npos);
     s = std::move(builder);
 }
+
+const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
+    struct llama_context * ctx
+);
+
+// the ring buffer works similarly to std::deque, but with a fixed capacity
+template<typename T>
+struct ring_buffer {
+    ring_buffer(size_t cap) : capacity(cap), data(cap) {}
+
+    T & front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[first];
+    }
+
+    const T & front() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[first];
+    }
+
+    T & back() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
+    }
+
+    const T & back() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
+    }
+
+    void push_back(const T & value) {
+        if (capacity == 0) {
+            throw std::runtime_error("ring buffer: capacity is zero");
+        }
+
+        if (sz == capacity) {
+            // advance the start when buffer is full
+            first = (first + 1) % capacity;
+        } else {
+            sz++;
+        }
+        data[pos] = value;
+        pos = (pos + 1) % capacity;
+    }
+
+    T pop_front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        T value = data[first];
+        first = (first + 1) % capacity;
+        sz--;
+        return value;
+    }
+
+    //T & operator[](size_t i) {
+    //    if (i >= sz) {
+    //        throw std::runtime_error("ring buffer: index out of bounds");
+    //    }
+    //    return data[(first + i) % capacity];
+    //}
+
+    //const T & at(size_t i) const {
+    //    if (i >= sz) {
+    //        throw std::runtime_error("ring buffer: index out of bounds");
+    //    }
+    //    return data[(first + i) % capacity];
+    //}
+
+    const T & rat(size_t i) const {
+        if (i >= sz) {
+            throw std::runtime_error("ring buffer: index out of bounds");
+        }
+        return data[(first + sz - i - 1) % capacity];
+    }
+
+    std::vector<T> to_vector() const {
+        std::vector<T> result;
+        result.reserve(sz);
+        for (size_t i = 0; i < sz; i++) {
+            result.push_back(data[(first + i) % capacity]);
+        }
+        return result;
+    }
+
+    void clear() {
+        // here only reset the status of the buffer
+        sz = 0;
+        first = 0;
+        pos = 0;
+    }
+
+    bool empty() const {
+        return sz == 0;
+    }
+
+    size_t size() const {
+        return sz;
+    }
+
+    size_t capacity = 0;
+    size_t sz = 0;
+    size_t first = 0;
+    size_t pos = 0;
+    std::vector<T> data;
+};
index 8f4841d9daf7b90f681eca1d0c989e761d205181..e255a8fc4fd548e135bb00465698b1b61a5043f2 100644 (file)
@@ -1,12 +1,53 @@
 #include "llama-sampling.h"
 
+#include "llama-vocab.h"
+#include "llama-grammar.h"
+
 #include <algorithm>
+#include <cassert>
+#include <cfloat>
+#include <chrono>
+#include <cmath>
+#include <cstdlib>
 #include <cstring>
 #include <ctime>
-#include <cfloat>
 #include <numeric>
+#include <random>
 #include <unordered_map>
 
+static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
+    // iterator for the probabilities
+#ifdef __GNUC__
+    #pragma GCC diagnostic push
+    #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
+#endif
+
+    struct probs_iterator {
+        typedef std::input_iterator_tag iterator_category;
+        typedef float value_type;
+        typedef float * pointer;
+        typedef float & reference;
+        typedef ptrdiff_t difference_type;
+
+        const llama_token_data * data;
+
+        bool operator==(const probs_iterator & other) const { return data == other.data; }
+        bool operator!=(const probs_iterator & other) const { return data != other.data; }
+        const float & operator*() const { return data->p; }
+        probs_iterator & operator++() { ++data; return *this; }
+        probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
+    };
+
+#ifdef __GNUC__
+    #pragma GCC diagnostic pop
+#endif
+
+    std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size});
+
+    return dist(rng);
+}
+
+/*
 static void llama_log_softmax(float * array, size_t size) {
     float max_l = *std::max_element(array, array + size);
     float sum = 0.f;
@@ -20,66 +61,52 @@ static void llama_log_softmax(float * array, size_t size) {
         array[i] = logf(array[i] / sum);
     }
 }
+*/
 
-void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
-    if (seed == LLAMA_DEFAULT_SEED) {
-        seed = time(NULL);
-    }
-
-    smpl->rng.seed(seed);
-}
-
-void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    GGML_ASSERT(candidates->size > 0);
-
-    const int64_t t_start_sample_us = ggml_time_us();
+static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
+    GGML_ASSERT(cur_p->size > 0);
 
     // Sort the logits in descending order
-    if (!candidates->sorted) {
-        std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+    if (!cur_p->sorted) {
+        std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
             return a.logit > b.logit;
         });
-        candidates->sorted = true;
+        cur_p->sorted = true;
     }
 
-    float max_l = candidates->data[0].logit;
+    float max_l = cur_p->data[0].logit;
     float cum_sum = 0.0f;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        float p = expf(candidates->data[i].logit - max_l);
-        candidates->data[i].p = p;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float p = expf(cur_p->data[i].logit - max_l);
+        cur_p->data[i].p = p;
         cum_sum += p;
     }
-    for (size_t i = 0; i < candidates->size; ++i) {
-        candidates->data[i].p /= cum_sum;
-    }
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= cum_sum;
     }
 }
 
-void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
+static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
     // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
-    // if (k >= (int32_t)candidates->size) {
+    // if (k >= (int32_t)cur_p->size) {
     //     return;
     // }
 
-    const int64_t t_start_sample_us = ggml_time_us();
-
     if (k <= 0) {
-        k = candidates->size;
+        k = cur_p->size;
     }
 
-    k = std::max(k, (int) min_keep);
-    k = std::min(k, (int) candidates->size);
+    k = std::min(k, (int) cur_p->size);
 
     // Sort scores in descending order
-    if (!candidates->sorted) {
+    if (!cur_p->sorted) {
         auto comp = [](const llama_token_data & a, const llama_token_data & b) {
             return a.logit > b.logit;
         };
         if (k <= 128) {
-            std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
+            std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
         } else {
             constexpr int   nbuckets     = 128;
             constexpr float bucket_low   = -10.0f;
@@ -87,11 +114,11 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
             constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
             constexpr float bucket_inter = -bucket_low * bucket_scale;
 
-            std::vector<int> bucket_idx(candidates->size);
+            std::vector<int> bucket_idx(cur_p->size);
             std::vector<int> histo(nbuckets, 0);
 
-            for (int i = 0; i < (int)candidates->size; ++i) {
-                const float val = candidates->data[i].logit;
+            for (int i = 0; i < (int)cur_p->size; ++i) {
+                const float val = cur_p->data[i].logit;
                 int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
                 ib = std::max(0, std::min(nbuckets-1, ib));
                 bucket_idx[i] = ib;
@@ -101,20 +128,22 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
             int ib = nbuckets - 1;
             for ( ; ib >= 0; --ib) {
                 nhave += histo[ib];
-                if (nhave >= k) break;
+                if (nhave >= k) {
+                    break;
+                }
             }
             std::vector<llama_token_data> tmp_tokens(nhave);
-            auto ptr = tmp_tokens.data();
+            auto ptr = tmp_tokens.data();
             std::vector<llama_token_data*> bucket_ptrs;
             bucket_ptrs.reserve(nbuckets - ib);
             for (int j = nbuckets - 1; j >= ib; --j) {
                 bucket_ptrs.push_back(ptr);
                 ptr += histo[j];
             }
-            for (int i = 0; i < (int)candidates->size; ++i) {
+            for (int i = 0; i < (int)cur_p->size; ++i) {
                 int j = bucket_idx[i];
                 if (j >= ib) {
-                    *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
+                    *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
                 }
             }
 
@@ -127,125 +156,582 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
             }
             std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
 
-            std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
+            std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
 
         }
-        candidates->sorted = true;
+        cur_p->sorted = true;
     }
-    candidates->size = k;
+    cur_p->size = k;
+}
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+static uint32_t get_rng_seed(uint32_t seed) {
+    if (seed == LLAMA_DEFAULT_SEED) {
+        // use system clock if std::random_device is not a true RNG
+        static bool is_rd_prng = std::random_device().entropy() == 0;
+        if (is_rd_prng) {
+            return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
+        }
+        std::random_device rd;
+        return rd();
     }
+    return seed;
 }
 
-void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
-    if (p >= 1.0f) {
+// llama_sampler API
+
+const char * llama_sampler_name(const struct llama_sampler * smpl) {
+    if (!smpl->iface) {
+        return "(null)";
+    }
+
+    return smpl->iface->name(smpl);
+}
+
+void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
+    if (smpl->iface->accept) {
+        smpl->iface->accept(smpl, token);
+    }
+}
+
+void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
+    GGML_ASSERT(smpl->iface->apply);
+    smpl->iface->apply(smpl, cur_p);
+}
+
+void llama_sampler_reset(struct llama_sampler * smpl) {
+    if (smpl->iface->reset) {
+        smpl->iface->reset(smpl);
+    }
+}
+
+struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
+    if (smpl->iface->clone) {
+        return smpl->iface->clone(smpl);
+    }
+
+    if (smpl->ctx == nullptr) {
+        return new llama_sampler {
+            /* .iface = */ smpl->iface,
+            /* .ctx   = */ nullptr,
+        };
+    }
+
+    GGML_ABORT("the sampler does not support cloning");
+}
+
+void llama_sampler_free(struct llama_sampler * smpl) {
+    if (smpl == nullptr) {
         return;
     }
 
-    llama_sample_softmax_impl(smpl, candidates);
+    if (smpl->iface->free) {
+        smpl->iface->free(smpl);
+    }
+
+    delete smpl;
+}
+
+llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
+    const auto * logits = llama_get_logits_ith(ctx, idx);
+
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+    // TODO: do not allocate each time
+    std::vector<llama_token_data> cur;
+    cur.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array cur_p = {
+        /* .data       = */ cur.data(),
+        /* .size       = */ cur.size(),
+        /* .selected   = */ -1,
+        /* .sorted     = */ false,
+    };
+
+    llama_sampler_apply(smpl, &cur_p);
+
+    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
+
+    auto token = cur_p.data[cur_p.selected].id;
+
+    llama_sampler_accept(smpl, token);
+
+    return token;
+}
+
+// sampler chain
+
+static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
+    return "chain";
+}
+
+static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+    time_meas tm(chain->t_sample_us, chain->params.no_perf);
+
+    for (auto * smpl : chain->samplers) {
+        llama_sampler_accept(smpl, token);
+    }
+
+    chain->n_sample++;
+}
+
+static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+    time_meas tm(chain->t_sample_us, chain->params.no_perf);
+
+    for (auto * smpl : chain->samplers) {
+        llama_sampler_apply(smpl, cur_p);
+    }
+}
+
+static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+    for (auto * smpl : chain->samplers) {
+        llama_sampler_reset(smpl);
+    }
+
+    chain->t_sample_us = 0;
+    chain->n_sample    = 0;
+}
+
+static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
+    const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
+
+    auto * result = llama_sampler_chain_init(chain_src->params);
+
+    for (auto * smpl : chain_src->samplers) {
+        llama_sampler_chain_add(result, llama_sampler_clone(smpl));
+    }
+
+    return result;
+}
+
+static void llama_sampler_chain_free(struct llama_sampler * smpl) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+    for (auto * smpl : chain->samplers) {
+        llama_sampler_free(smpl);
+    }
+
+    delete chain;
+}
+
+static struct llama_sampler_i llama_sampler_chain_i = {
+    /* .name   = */ llama_sampler_chain_name,
+    /* .accept = */ llama_sampler_chain_accept,
+    /* .apply  = */ llama_sampler_chain_apply,
+    /* .reset  = */ llama_sampler_chain_reset,
+    /* .clone  = */ llama_sampler_chain_clone,
+    /* .free   = */ llama_sampler_chain_free,
+};
+
+struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_chain_i,
+        /* .ctx   = */ new llama_sampler_chain {
+            /* .params      = */ params,
+            /* .samplers    = */ {},
+            /* .t_sample_us = */ 0,
+            /* .n_sample    = */ 0,
+        },
+    };
+}
+
+void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
+    auto * p = (llama_sampler_chain *) chain->ctx;
+    p->samplers.push_back(smpl);
+}
+
+struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
+    const auto * p = (const llama_sampler_chain *) chain->ctx;
+
+    if (i < 0 || (size_t) i >= p->samplers.size()) {
+        return nullptr;
+    }
+
+    return p->samplers[i];
+}
+
+struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
+    auto * p = (llama_sampler_chain *) chain->ctx;
+
+    if (i < 0 || (size_t) i >= p->samplers.size()) {
+        return nullptr;
+    }
+
+    auto * result = p->samplers[i];
+    p->samplers.erase(p->samplers.begin() + i);
+
+    return result;
+}
+
+int llama_sampler_chain_n(const struct llama_sampler * chain) {
+    const auto * p = (const llama_sampler_chain *) chain->ctx;
+
+    return p->samplers.size();
+}
+
+//
+// samplers
+//
+
+// greedy
+
+static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
+    return "greedy";
+}
+
+static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
+    cur_p->selected = 0;
+    for (size_t i = 1; i < cur_p->size; ++i) {
+        if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
+            cur_p->selected = i;
+        }
+    }
+}
+
+static struct llama_sampler_i llama_sampler_greedy_i = {
+    /* .name   = */ llama_sampler_greedy_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_greedy_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ nullptr,
+    /* .free   = */ nullptr,
+};
+
+struct llama_sampler * llama_sampler_init_greedy() {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_greedy_i,
+        /* .ctx   = */ nullptr,
+    };
+}
+
+// dist
+
+struct llama_sampler_dist {
+    const uint32_t seed;
+          uint32_t seed_cur;
+
+    std::mt19937 rng;
+};
+
+static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
+    return "dist";
+}
+
+static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_dist *) smpl->ctx;
+    cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
+}
+
+static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
+    auto * result = llama_sampler_init_dist(ctx->seed);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_dist *) result->ctx;
+
+        result_ctx->rng = ctx->rng;
+    }
+
+    return result;
+}
+
+static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_dist *) smpl->ctx;
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
+}
+
+static void llama_sampler_dist_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_dist *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_dist_i = {
+    /* .name   = */ llama_sampler_dist_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_dist_apply,
+    /* .reset  = */ llama_sampler_dist_reset,
+    /* .clone  = */ llama_sampler_dist_clone,
+    /* .free   = */ llama_sampler_dist_free,
+};
+
+struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
+    auto seed_cur = get_rng_seed(seed);
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_dist_i,
+        /* .ctx   = */ new llama_sampler_dist {
+            /* .seed     = */ seed,
+            /* .seed_cur = */ seed_cur,
+            /* .rng      = */ std::mt19937(seed_cur),
+        },
+    };
+}
+
+// softmax
+
+static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
+    return "softmax";
+}
+
+static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
+    llama_sampler_softmax_impl(cur_p);
+}
+
+static struct llama_sampler_i llama_sampler_softmax_i = {
+    /* .name   = */ llama_sampler_softmax_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_softmax_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ nullptr,
+    /* .free   = */ nullptr,
+};
+
+struct llama_sampler * llama_sampler_init_softmax() {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_softmax_i,
+        /* .ctx   = */ nullptr,
+    };
+}
+
+// top-k
+
+struct llama_sampler_top_k {
+    const int32_t k;
+};
+
+static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
+    return "top-k";
+}
+
+static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
+    llama_sampler_top_k_impl(cur_p, ctx->k);
+}
+
+static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
+    return llama_sampler_init_top_k(ctx->k);
+}
+
+static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_top_k *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_top_k_i = {
+    /* .name   = */ llama_sampler_top_k_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_top_k_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_top_k_clone,
+    /* .free   = */ llama_sampler_top_k_free,
+};
+
+struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_top_k_i,
+        /* .ctx   = */ new llama_sampler_top_k {
+            /* .k = */ k,
+        },
+    };
+}
+
+// top-p
+
+struct llama_sampler_top_p {
+    const float  p;
+    const size_t min_keep;
+};
+
+static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
+    return "top-p";
+}
 
-    const int64_t t_start_sample_us = ggml_time_us();
+static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
+
+    if (ctx->p >= 1.0f) {
+        return;
+    }
+
+    llama_sampler_softmax_impl(cur_p);
 
     // Compute the cumulative probabilities
     float cum_sum = 0.0f;
-    size_t last_idx = candidates->size;
+    size_t last_idx = cur_p->size;
 
-    for (size_t i = 0; i < candidates->size; ++i) {
-        cum_sum += candidates->data[i].p;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cum_sum += cur_p->data[i].p;
 
         // Check if the running sum is at least p or if we have kept at least min_keep tokens
         // we set the last index to i+1 to indicate that the current iterate should be included in the set
-        if (cum_sum >= p && i + 1 >= min_keep) {
+        if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
             last_idx = i + 1;
             break;
         }
     }
 
     // Resize the output vector to keep only the top-p tokens
-    candidates->size = last_idx;
+    cur_p->size = last_idx;
+}
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
+static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
+    return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
+}
+
+static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_top_p *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_top_p_i = {
+    /* .name   = */ llama_sampler_top_p_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_top_p_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_top_p_clone,
+    /* .free   = */ llama_sampler_top_p_free,
+};
+
+struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_top_p_i,
+        /* .ctx   = */ new llama_sampler_top_p {
+            /* .p        = */ p,
+            /* .min_keep = */ min_keep,
+        },
+    };
+}
+
+// min-p
+
+struct llama_sampler_min_p {
+    const float  p;
+    const size_t min_keep;
+};
+
+static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
+    return "min-p";
 }
 
-void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
-    if (p <= 0.0f || !candidates->size) {
+static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
+
+    if (ctx->p <= 0.0f || !cur_p->size) {
         return;
     }
 
-    const int64_t t_start_sample_us = ggml_time_us();
-
     bool min_p_applied = false;
 
-    // if the candidates aren't sorted, try the unsorted implementation first
-    if (!candidates->sorted) {
+    // if the cur_p aren't sorted, try the unsorted implementation first
+    if (!cur_p->sorted) {
         std::vector<llama_token_data> filtered_tokens;
 
         float max_logit = -FLT_MAX;
-        for (size_t i = 0; i < candidates->size; ++i) {
-            max_logit = std::max(max_logit, candidates->data[i].logit);
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            max_logit = std::max(max_logit, cur_p->data[i].logit);
         }
-        const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
+        const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max
 
-        for (size_t i = 0; i < candidates->size; ++i) {
-            if (candidates->data[i].logit >= min_logit) {
-                filtered_tokens.push_back(candidates->data[i]);
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            if (cur_p->data[i].logit >= min_logit) {
+                filtered_tokens.push_back(cur_p->data[i]);
             }
         }
 
         // if we have enough values the operation was a success
-        if (filtered_tokens.size() >= min_keep) {
-            memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
-            candidates->size = filtered_tokens.size();
+        if (filtered_tokens.size() >= ctx->min_keep) {
+            memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
+            cur_p->size = filtered_tokens.size();
             min_p_applied = true;
         }
     }
 
-    // if the candidates are sorted or the unsorted implementation failed, use this implementation
+    // if the cur_p are sorted or the unsorted implementation failed, use this implementation
     if (!min_p_applied) {
         // Sort the logits in descending order
-        if (!candidates->sorted) {
-            std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+        if (!cur_p->sorted) {
+            std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
                 return a.logit > b.logit;
             });
-            candidates->sorted = true;
+            cur_p->sorted = true;
         }
 
-        const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
+        const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
         size_t i = 1; // first token always matches
 
-        for (; i < candidates->size; ++i) {
-            if (candidates->data[i].logit < min_logit && i >= min_keep) {
+        for (; i < cur_p->size; ++i) {
+            if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) {
                 break; // prob too small
             }
         }
 
         // Resize the output vector to keep only the matching tokens
-        candidates->size = i;
+        cur_p->size = i;
     }
+}
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
+static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
+    return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
 }
 
-void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
-    if (z >= 1.0f || candidates->size <= 2) {
+static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_min_p *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_min_p_i = {
+    /* .name   = */ llama_sampler_min_p_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_min_p_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_min_p_clone,
+    /* .free   = */ llama_sampler_min_p_free,
+};
+
+struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_min_p_i,
+        /* .ctx   = */ new llama_sampler_min_p {
+            /* .p        = */ p,
+            /* .min_keep = */ min_keep,
+        },
+    };
+}
+
+// tail-free
+
+struct llama_sampler_tail_free {
+    const float  z;
+    const size_t min_keep;
+};
+
+static const char * llama_sampler_tail_free_name(const struct llama_sampler * /*smpl*/) {
+    return "tail-free";
+}
+
+static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
+
+    if (ctx->z >= 1.0f || cur_p->size <= 2) {
         return;
     }
 
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
-    const int64_t t_start_sample_us = ggml_time_us();
+    llama_sampler_softmax_impl(cur_p);
 
     // Compute the first and second derivatives
-    std::vector<float> first_derivatives(candidates->size - 1);
-    std::vector<float> second_derivatives(candidates->size - 2);
+    std::vector<float> first_derivatives(cur_p->size - 1);
+    std::vector<float> second_derivatives(cur_p->size - 2);
 
     for (size_t i = 0; i < first_derivatives.size(); ++i) {
-        first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
+        first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
     }
     for (size_t i = 0; i < second_derivatives.size(); ++i) {
         second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
@@ -272,51 +758,86 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_
     }
 
     float cum_sum = 0.0f;
-    size_t last_idx = candidates->size;
+    size_t last_idx = cur_p->size;
     for (size_t i = 0; i < second_derivatives.size(); ++i) {
         cum_sum += second_derivatives[i];
 
         // Check if the running sum is greater than z or if we have kept at least min_keep tokens
-        if (cum_sum > z && i >= min_keep) {
+        if (cum_sum > ctx->z && i >= ctx->min_keep) {
             last_idx = i;
             break;
         }
     }
 
     // Resize the output vector to keep only the tokens above the tail location
-    candidates->size = last_idx;
+    cur_p->size = last_idx;
+}
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
+static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
+    return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
+}
+
+static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_tail_free *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_tail_free_i = {
+    /* .name   = */ llama_sampler_tail_free_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_tail_free_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_tail_free_clone,
+    /* .free   = */ llama_sampler_tail_free_free,
+};
+
+struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_tail_free_i,
+        /* .ctx   = */ new llama_sampler_tail_free {
+            /* .z        = */ z,
+            /*. min_keep = */ min_keep,
+        },
+    };
+}
+
+// typical
+
+struct llama_sampler_typical {
+    const float  p;
+    const size_t min_keep;
+};
+
+static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) {
+    return "typical";
 }
 
-void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
+static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_typical *) smpl->ctx;
+
     // Reference implementation:
     // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
-    if (p >= 1.0f) {
+    if (ctx->p >= 1.0f) {
         return;
     }
 
     // Compute the softmax of logits and calculate entropy
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
-
-    const int64_t t_start_sample_us = ggml_time_us();
+    llama_sampler_softmax_impl(cur_p);
 
     float entropy = 0.0f;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        entropy += -candidates->data[i].p * logf(candidates->data[i].p);
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
     }
 
     // Compute the absolute difference between negative log probability and entropy for each candidate
     std::vector<float> shifted_scores;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
         shifted_scores.push_back(shifted_score);
     }
 
     // Sort tokens based on the shifted_scores and their corresponding indices
-    std::vector<size_t> indices(candidates->size);
+    std::vector<size_t> indices(cur_p->size);
     std::iota(indices.begin(), indices.end(), 0);
 
     std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
@@ -329,134 +850,618 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
 
     for (size_t i = 0; i < indices.size(); ++i) {
         size_t idx = indices[i];
-        cum_sum += candidates->data[idx].p;
+        cum_sum += cur_p->data[idx].p;
 
         // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
-        if (cum_sum > p && i >= min_keep - 1) {
+        if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
             last_idx = i + 1;
             break;
         }
     }
 
     // Resize the output vector to keep only the locally typical tokens
-    std::vector<llama_token_data> new_candidates;
+    std::vector<llama_token_data> cur_p_new;
     for (size_t i = 0; i < last_idx; ++i) {
         size_t idx = indices[i];
-        new_candidates.push_back(candidates->data[idx]);
+        cur_p_new.push_back(cur_p->data[idx]);
     }
 
-    // Replace the data in candidates with the new_candidates data
-    std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
-    candidates->size = new_candidates.size();
-    candidates->sorted = false;
+    // Replace the data in cur_p with the cur_p_new data
+    std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
+    cur_p->size = cur_p_new.size();
+    cur_p->sorted = false;
+}
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
+static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
+    return llama_sampler_init_typical(ctx->p, ctx->min_keep);
 }
 
-void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
-    const int64_t t_start_sample_us = ggml_time_us();
+static void llama_sampler_typical_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_typical *) smpl->ctx;
+}
 
-    // no need to do anything if there is only one (or zero) candidates
-    if(candidates->size <= 1) {
-        return;
+static struct llama_sampler_i llama_sampler_typical_i = {
+    /* .name   = */ llama_sampler_typical_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_typical_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_typical_clone,
+    /* .free   = */ llama_sampler_typical_free,
+};
+
+struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_typical_i,
+        /* .ctx   = */ new llama_sampler_typical {
+            /* .p        = */ p,
+            /* .min_keep = */ min_keep,
+        },
+    };
+}
+
+// temp
+
+struct llama_sampler_temp {
+    const float temp;
+};
+
+static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
+    return "temp";
+}
+
+static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_temp *) smpl->ctx;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].logit /= ctx->temp;
     }
+}
 
-    // Calculate maximum possible entropy
-    float max_entropy = -logf(1.0f / candidates->size);
+static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
+    return llama_sampler_init_temp(ctx->temp);
+}
 
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+static void llama_sampler_temp_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_temp *) smpl->ctx;
+}
 
-    // Calculate entropy of the softmax probabilities
-    float entropy = 0.0f;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        float prob = candidates->data[i].p;
-        if (prob > 0.0f) { // Ensure no log(0)
-            entropy -= prob * logf(prob);
+static struct llama_sampler_i llama_sampler_temp_i = {
+    /* .name   = */ llama_sampler_temp_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_temp_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_temp_clone,
+    /* .free   = */ llama_sampler_temp_free,
+};
+
+struct llama_sampler * llama_sampler_init_temp(float temp) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_temp_i,
+        /* .ctx   = */ new llama_sampler_temp {
+            /*.temp = */ temp,
+        },
+    };
+}
+
+// temp-ext
+
+struct llama_sampler_temp_ext {
+    const float temp;
+    const float delta;
+    const float exponent;
+};
+
+static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
+    return "temp-ext";
+}
+
+static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
+    if (ctx->delta > 0) {
+        const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
+        const float max_temp = ctx->temp + ctx->delta;
+        float exponent_val = ctx->exponent;
+
+        // no need to do anything if there is only one (or zero) candidates
+        if (cur_p->size <= 1) {
+            return;
+        }
+
+        // Calculate maximum possible entropy
+        float max_entropy = -logf(1.0f / cur_p->size);
+
+        llama_sampler_softmax_impl(cur_p);
+
+        // Calculate entropy of the softmax probabilities
+        float entropy = 0.0f;
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            float prob = cur_p->data[i].p;
+            if (prob > 0.0f) { // Ensure no log(0)
+                entropy -= prob * logf(prob);
+            }
+        }
+
+        // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
+        float normalized_entropy = entropy / max_entropy;
+
+        // Map the normalized entropy to the desired temperature range using the power function
+        float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
+
+    #ifdef DEBUG
+        LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
+        LLAMA_LOG_INFO("Entropy: %f\n", entropy);
+        LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
+        LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
+        LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
+        LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
+    #endif
+
+        // Apply the dynamically calculated temperature scaling
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].logit /= dyn_temp;
+        }
+
+        // Re-compute softmax probabilities after scaling logits with dynamic temperature
+        const double max_l_double = cur_p->data[0].logit;
+
+        double cum_sum_double = 0.0;
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            double p = exp(cur_p->data[i].logit - max_l_double);
+            cur_p->data[i].p = p; // Store the scaled probability
+            cum_sum_double += p;
+        }
+
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
+        }
+
+    #ifdef DEBUG
+        // Print the updated top 25 probabilities after temperature scaling
+        LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
+        for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
+            LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
+        }
+    #endif
+    } else {
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].logit /= ctx->temp;
         }
     }
+}
 
-    // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above)
-    float normalized_entropy = entropy / max_entropy;
+static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
+    return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
+}
 
-    // Map the normalized entropy to the desired temperature range using the power function
-    float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
+static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_temp_ext *) smpl->ctx;
+}
 
-#ifdef DEBUG
-    LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
-    LLAMA_LOG_INFO("Entropy: %f\n", entropy);
-    LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
-    LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
-    LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
-    LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
-#endif
+static struct llama_sampler_i llama_sampler_temp_ext_i = {
+    /* .name   = */ llama_sampler_temp_ext_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_temp_ext_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_temp_ext_clone,
+    /* .free   = */ llama_sampler_temp_ext_free,
+};
+
+struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_temp_ext_i,
+        /* .ctx   = */ new llama_sampler_temp_ext {
+            /* .temp     = */ temp,
+            /* .delta    = */ delta,
+            /* .exponent = */ exponent,
+        },
+    };
+}
+
+// mirostat
+
+struct llama_sampler_mirostat {
+    const int32_t n_vocab;
+
+    const uint32_t seed;
+          uint32_t seed_cur;
+
+    const float tau;
+    const float eta;
+
+    const int32_t m;
+
+    float mu;
 
-    // Apply the dynamically calculated temperature scaling
-    for (size_t i = 0; i < candidates->size; ++i) {
-        candidates->data[i].logit /= dyn_temp;
+    std::mt19937 rng;
+};
+
+static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
+    return "mirostat";
+}
+
+static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
+
+    llama_sampler_softmax_impl(cur_p);
+
+    // Estimate s_hat using the most probable m tokens
+    float s_hat = 0.0;
+    float sum_ti_bi = 0.0;
+    float sum_ti_sq = 0.0;
+    for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
+        float t_i = logf(float(i + 2) / float(i + 1));
+        float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
+        sum_ti_bi += t_i * b_i;
+        sum_ti_sq += t_i * t_i;
     }
+    s_hat = sum_ti_bi / sum_ti_sq;
+
+    // Compute k from the estimated s_hat and target surprise value
+    float epsilon_hat = s_hat - 1;
+    float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
+
+    llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
+    llama_sampler_softmax_impl(cur_p);
+
+    const int idx = llama_sample_dist(cur_p, ctx->rng);
+
+    cur_p->selected = idx;
+
+    float observed_surprise = -log2f(cur_p->data[idx].p);
+    float e = observed_surprise - ctx->tau;
 
-    // Re-compute softmax probabilities after scaling logits with dynamic temperature
-    double max_l_double = candidates->data[0].logit;
-    double cum_sum_double = 0.0;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        double p = exp(candidates->data[i].logit - max_l_double);
-        candidates->data[i].p = p; // Store the scaled probability
-        cum_sum_double += p;
+    // Update mu using the learning rate and error
+    ctx->mu = ctx->mu - ctx->eta * e;
+}
+
+static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
+    auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;
+
+        result_ctx->mu  = ctx->mu;
+        result_ctx->rng = ctx->rng;
     }
-    for (size_t i = 0; i < candidates->size; ++i) {
-        candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
+
+    return result;
+}
+
+static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
+    ctx->mu = 2.0f*ctx->tau;
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
+}
+
+static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_mirostat *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_mirostat_i = {
+    /* .name   = */ llama_sampler_mirostat_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_mirostat_apply,
+    /* .reset  = */ llama_sampler_mirostat_reset,
+    /* .clone  = */ llama_sampler_mirostat_clone,
+    /* .free   = */ llama_sampler_mirostat_free,
+};
+
+struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
+    auto seed_cur = get_rng_seed(seed);
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_mirostat_i,
+        /* .ctx   = */ new llama_sampler_mirostat {
+            /* .n_vocab  = */ n_vocab,
+            /* .seed     = */ seed,
+            /* .seed_cur = */ seed_cur,
+            /* .tau      = */ tau,
+            /* .eta      = */ eta,
+            /* .m        = */ m,
+            /* .mu       = */ 2.0f*tau,
+            /* .rng      = */ std::mt19937(seed_cur),
+        },
+    };
+}
+
+// mirostat v2
+
+struct llama_sampler_mirostat_v2 {
+    const uint32_t seed;
+          uint32_t seed_cur;
+
+    const float tau;
+    const float eta;
+
+    float mu;
+
+    std::mt19937 rng;
+};
+
+static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
+    return "mirostat-v2";
+}
+
+static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
+
+    llama_sampler_softmax_impl(cur_p);
+
+    // Truncate the words with surprise values greater than mu
+    cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
+        return -log2f(candidate.p) > ctx->mu;
+    }));
+
+    if (cur_p->size == 0) {
+        cur_p->size = 1;
     }
 
-#ifdef DEBUG
-    // Print the updated top 25 probabilities after temperature scaling
-    LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
-    for (size_t i = 0; i < 25 && i < candidates->size; ++i) {
-        LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
+    // Normalize the probabilities of the remaining words
+    llama_sampler_softmax_impl(cur_p);
+
+    const int idx = llama_sample_dist(cur_p, ctx->rng);
+
+    cur_p->selected = idx;
+
+    float observed_surprise = -log2f(cur_p->data[idx].p);
+    float e = observed_surprise - ctx->tau;
+
+    // Update mu using the learning rate and error
+    ctx->mu = ctx->mu - ctx->eta * e;
+}
+
+static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
+    ctx->mu = 2.0f*ctx->tau;
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
+}
+
+static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
+
+    auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
+
+        result_ctx->mu  = ctx->mu;
+        result_ctx->rng = ctx->rng;
     }
-#endif
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    return result;
+}
+
+static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_mirostat_v2 *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
+    /* .name   = */ llama_sampler_mirostat_v2_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_mirostat_v2_apply,
+    /* .reset  = */ llama_sampler_mirostat_v2_reset,
+    /* .clone  = */ llama_sampler_mirostat_v2_clone,
+    /* .free   = */ llama_sampler_mirostat_v2_free,
+};
+
+struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
+    auto seed_cur = get_rng_seed(seed);
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_mirostat_v2_i,
+        /* .ctx   = */ new llama_sampler_mirostat_v2 {
+            /* .seed     = */ seed,
+            /* .seed_cur = */ seed_cur,
+            /* .tau      = */ tau,
+            /* .eta      = */ eta,
+            /* .mu       = */ 2.0f*tau,
+            /* .rng      = */ std::mt19937(seed_cur),
+        },
+    };
+}
+
+// grammar
+
+struct llama_sampler_grammar {
+    const struct llama_vocab * vocab;
+
+    std::string grammar_str;
+    std::string grammar_root;
+
+    struct llama_grammar * grammar;
+};
+
+static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) {
+    return "grammar";
+}
+
+static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) {
+    auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+    if (ctx->grammar) {
+        llama_grammar_accept_impl(*ctx->grammar, token);
+    }
+}
+
+static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+    if (ctx->grammar) {
+        llama_grammar_apply_impl(*ctx->grammar, cur_p);
+    }
+}
+
+static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+    if (!ctx->grammar) {
+        return;
     }
+
+    auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
+
+    llama_grammar_free_impl(ctx->grammar);
+    ctx->grammar = grammar_new;
 }
 
-void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
-    const int64_t t_start_sample_us = ggml_time_us();
+static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
+
+    auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_grammar *) result->ctx;
 
-    for (size_t i = 0; i < candidates->size; ++i) {
-        candidates->data[i].logit /= temp;
+        if (ctx->grammar) {
+            result_ctx->grammar_str  = ctx->grammar_str;
+            result_ctx->grammar_root = ctx->grammar_root;
+
+            result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
+        }
+    }
+
+    return result;
+}
+
+static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
+    const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+
+    if (ctx->grammar) {
+        llama_grammar_free_impl(ctx->grammar);
+    }
+
+    delete ctx;
+}
+
+static struct llama_sampler_i llama_sampler_grammar_i = {
+    /* .name   = */ llama_sampler_grammar_name,
+    /* .accept = */ llama_sampler_grammar_accept_impl,
+    /* .apply  = */ llama_sampler_grammar_apply,
+    /* .reset  = */ llama_sampler_grammar_reset,
+    /* .clone  = */ llama_sampler_grammar_clone,
+    /* .free   = */ llama_sampler_grammar_free,
+};
+
+struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
+    auto * ctx = new llama_sampler_grammar;
+
+    if (grammar_str != nullptr && grammar_str[0] != '\0') {
+        *ctx = {
+            /* .vocab        = */ &vocab,
+            /* .grammar_str  = */ grammar_str,
+            /* .grammar_root = */ grammar_root,
+            /* .grammar      = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
+        };
+    } else {
+        *ctx = {
+            /* .vocab        = */ &vocab,
+            /* .grammar_str  = */ {},
+            /* .grammar_root = */ {},
+            /* .grammar      = */ nullptr,
+        };
     }
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_grammar_i,
+        /* .ctx   = */ ctx,
+    };
+}
+
+// penalties
+
+struct llama_sampler_penalties {
+    const int32_t     n_vocab;
+    const llama_token special_eos_id;
+    const llama_token linefeed_id;
+
+    const int32_t penalty_last_n;
+    const float   penalty_repeat;
+    const float   penalty_freq;
+    const float   penalty_present;
+
+    const bool    penalize_nl;
+    const bool    ignore_eos;
+
+    ring_buffer<llama_token> prev;
+};
+
+static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
+    return "penalties";
+}
+
+static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
+    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+    if (ctx->penalty_last_n == 0) {
+        return;
     }
+
+    ctx->prev.push_back(token);
 }
 
-void llama_sample_repetition_penalties_impl(
-        struct llama_sampling * smpl,
-       llama_token_data_array * candidates,
-            const llama_token * last_tokens,
-                       size_t   penalty_last_n,
-                       float   penalty_repeat,
-                       float   penalty_freq,
-                       float   penalty_present) {
-    if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
+static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+
+    if (ctx->ignore_eos) {
+        assert(ctx->special_eos_id >= 0);
+
+        // optimistically check if the candidates are not yet sorted/shuffled/truncated
+        if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
+            cur_p->data[ctx->special_eos_id].logit = -INFINITY;
+        } else {
+            // else, search for the special EOS token
+            for (size_t i = 0; i < cur_p->size; ++i) {
+                if (cur_p->data[i].id == ctx->special_eos_id) {
+                    cur_p->data[i].logit = -INFINITY;
+                    break;
+                }
+            }
+        }
+    }
+
+    if ((ctx->penalty_last_n == 0) ||
+        (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
         return;
     }
 
-    const int64_t t_start_sample_us = ggml_time_us();
+    bool nl_found = false;
+    size_t nl_idx = 0;
+    float nl_logit = -INFINITY;
+    if (!ctx->penalize_nl) {
+        assert(ctx->linefeed_id >= 0);
+
+        // optimistically check if the candidates are not yet sorted/shuffled/truncated
+        if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
+            nl_found = true;
+            nl_idx = ctx->linefeed_id;
+            nl_logit = cur_p->data[ctx->linefeed_id].logit;
+        } else {
+            // else, search for the linefeed token
+            for (size_t i = 0; i < cur_p->size; ++i) {
+                if (cur_p->data[i].id == ctx->linefeed_id) {
+                    nl_found = true;
+                    nl_idx = i;
+                    nl_logit = cur_p->data[i].logit;
+                    break;
+                }
+            }
+        }
+    }
 
     // Create a frequency map to count occurrences of each token in last_tokens
-    std::unordered_map<llama_token, int> token_count;
-    for (size_t i = 0; i < penalty_last_n; ++i) {
-        token_count[last_tokens[i]]++;
+    // TODO: optimize this by maintaining the token count in the sampler context
+    using llama_token_cnt = std::unordered_map<llama_token, int>;
+    llama_token_cnt token_count;
+
+    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
+        token_count[ctx->prev.rat(i)]++;
     }
 
-    // Apply frequency and presence penalties to the candidates
-    for (size_t i = 0; i < candidates->size; ++i) {
-        const auto token_iter = token_count.find(candidates->data[i].id);
+    // Apply frequency and presence penalties to the cur_p
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        const auto token_iter = token_count.find(cur_p->data[i].id);
         if (token_iter == token_count.end()) {
             continue;
         }
@@ -465,171 +1470,238 @@ void llama_sample_repetition_penalties_impl(
 
         // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
-        if (candidates->data[i].logit <= 0) {
-            candidates->data[i].logit *= penalty_repeat;
+        if (cur_p->data[i].logit <= 0) {
+            cur_p->data[i].logit *= ctx->penalty_repeat;
         } else {
-            candidates->data[i].logit /= penalty_repeat;
+            cur_p->data[i].logit /= ctx->penalty_repeat;
         }
 
-        candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
+        cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
     }
 
-    candidates->sorted = false;
+    cur_p->sorted = false;
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    if (!ctx->penalize_nl && nl_found) {
+        // restore the logit of the newline token if it was penalized
+        cur_p->data[nl_idx].logit = nl_logit;
     }
 }
 
-void llama_sample_apply_guidance_impl(
-        struct llama_sampling * smpl,
-                        float * logits,
-                        float * logits_guidance,
-                        float   scale) {
-    GGML_ASSERT(smpl);
-
-    const auto t_start_sample_us = ggml_time_us();
-    const auto n_vocab = smpl->n_vocab;
-
-    llama_log_softmax(logits, n_vocab);
-    llama_log_softmax(logits_guidance, n_vocab);
+static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+    ctx->prev.clear();
+}
 
-    for (int i = 0; i < n_vocab; ++i) {
-              auto & l = logits[i];
-        const auto & g = logits_guidance[i];
+static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
+    auto * result = llama_sampler_init_penalties(
+            ctx->n_vocab,
+            ctx->special_eos_id,
+            ctx->linefeed_id,
+            ctx->penalty_last_n,
+            ctx->penalty_repeat,
+            ctx->penalty_freq,
+            ctx->penalty_present,
+            ctx->penalize_nl,
+            ctx->ignore_eos);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_penalties *) result->ctx;
 
-        l = scale * (l - g) + g;
+        result_ctx->prev = ctx->prev;
     }
 
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    return result;
 }
 
-llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
-    GGML_ASSERT(smpl);
-
-    const int32_t n_vocab = float(smpl->n_vocab);
-
-    int64_t t_start_sample_us = ggml_time_us();
+static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_penalties *) smpl->ctx;
+}
 
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+static struct llama_sampler_i llama_sampler_penalties_i = {
+    /* .name   = */ llama_sampler_penalties_name,
+    /* .accept = */ llama_sampler_penalties_accept,
+    /* .apply  = */ llama_sampler_penalties_apply,
+    /* .reset  = */ llama_sampler_penalties_reset,
+    /* .clone  = */ llama_sampler_penalties_clone,
+    /* .free   = */ llama_sampler_penalties_free,
+};
+
+struct llama_sampler * llama_sampler_init_penalties(
+        int32_t n_vocab,
+        llama_token special_eos_id,
+        llama_token linefeed_id,
+        int32_t penalty_last_n,
+        float penalty_repeat,
+        float penalty_freq,
+        float penalty_present,
+        bool penalize_nl,
+        bool ignore_eos) {
+    if (linefeed_id == LLAMA_TOKEN_NULL) {
+        penalize_nl = true;
+    }
 
-    // Estimate s_hat using the most probable m tokens
-    float s_hat = 0.0;
-    float sum_ti_bi = 0.0;
-    float sum_ti_sq = 0.0;
-    for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
-        float t_i = logf(float(i + 2) / float(i + 1));
-        float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
-        sum_ti_bi += t_i * b_i;
-        sum_ti_sq += t_i * t_i;
+    if (special_eos_id == LLAMA_TOKEN_NULL) {
+        ignore_eos = false;
     }
-    s_hat = sum_ti_bi / sum_ti_sq;
 
-    // Compute k from the estimated s_hat and target surprise value
-    float epsilon_hat = s_hat - 1;
-    float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
+    penalty_last_n = std::max(penalty_last_n, 0);
+
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_penalties_i,
+        /* .ctx   = */ new llama_sampler_penalties {
+            /* .n_vocab         = */ n_vocab,
+            /* .special_eos_id  = */ special_eos_id,
+            /* .linefeed_id     = */ linefeed_id,
+            /* .penalty_last_n  = */ penalty_last_n,
+            /* .penalty_repeat  = */ penalty_repeat,
+            /* .penalty_freq    = */ penalty_freq,
+            /* .penalty_present = */ penalty_present,
+            /* .penalize_nl     = */ penalize_nl,
+            /* .ignore_eos      = */ ignore_eos,
+            /* .prev            = */ ring_buffer<llama_token>(penalty_last_n),
+        },
+    };
+}
 
-    // Sample the next word X using top-k sampling
-    llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    llama_token X = llama_sample_token_impl(smpl, candidates);
-    t_start_sample_us = ggml_time_us();
+// logit-bias
 
-    // Compute error as the difference between observed surprise and target surprise value
-    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
-        return candidate.id == X;
-    }));
-    float observed_surprise = -log2f(candidates->data[X_idx].p);
-    float e = observed_surprise - tau;
+struct llama_sampler_logit_bias {
+    const int32_t n_vocab;
 
-    // Update mu using the learning rate and error
-    *mu = *mu - eta * e;
+    const std::vector<llama_logit_bias> logit_bias;
 
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    return X;
+    std::vector<llama_logit_bias> to_search;
+};
+
+static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
+    return "logit-bias";
 }
 
-llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
-    int64_t t_start_sample_us;
-    t_start_sample_us = ggml_time_us();
+static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
+
+    if (ctx->logit_bias.empty()) {
+        return;
+    }
 
-    llama_sample_softmax_impl(smpl, candidates);
+    ctx->to_search.clear();
 
-    // Truncate the words with surprise values greater than mu
-    candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
-        return -log2f(candidate.p) > *mu;
-    }));
+    // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
+    for (const auto & lb : ctx->logit_bias) {
+        if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
+            cur_p->data[lb.token].logit += lb.bias;
+        } else {
+            ctx->to_search.push_back(lb);
+        }
+    }
 
-    if (candidates->size == 0) {
-        candidates->size = 1;
+    if (ctx->to_search.empty()) {
+        return;
     }
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    // search for the remaining candidates that were not found in the previous step
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        for (const auto & lb : ctx->to_search) {
+            if (cur_p->data[i].id == lb.token) {
+                cur_p->data[i].logit += lb.bias;
+                break;
+            }
+        }
     }
+}
 
-    // Normalize the probabilities of the remaining words
-    llama_sample_softmax_impl(smpl, candidates);
+static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
+    return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
+}
 
-    // Sample the next word X from the remaining words
-    llama_token X = llama_sample_token_impl(smpl, candidates);
-    t_start_sample_us = ggml_time_us();
+static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_logit_bias *) smpl->ctx;
+}
 
-    // Compute error as the difference between observed surprise and target surprise value
-    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
-        return candidate.id == X;
-    }));
-    float observed_surprise = -log2f(candidates->data[X_idx].p);
-    float e = observed_surprise - tau;
+static struct llama_sampler_i llama_sampler_logit_bias_i = {
+    /* .name   = */ llama_sampler_logit_bias_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_logit_bias_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_logit_bias_clone,
+    /* .free   = */ llama_sampler_logit_bias_free,
+};
+
+struct llama_sampler * llama_sampler_init_logit_bias(
+                         int32_t   n_vocab,
+                         int32_t   n_logit_bias,
+          const llama_logit_bias * logit_bias) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_logit_bias_i,
+        /* .ctx   = */ new llama_sampler_logit_bias {
+            /* .n_vocab    = */ n_vocab,
+            /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
+            /* .to_search  = */ {},
+        },
+    };
+}
 
-    // Update mu using the learning rate and error
-    *mu = *mu - eta * e;
+// utils
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
+    if (smpl->iface == &llama_sampler_dist_i) {
+        return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
     }
-    return X;
-}
 
-llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    const int64_t t_start_sample_us = ggml_time_us();
+    if (smpl->iface == &llama_sampler_mirostat_i) {
+        return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
+    }
 
-    // Find max element
-    auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
-        return a.logit < b.logit;
-    });
+    if (smpl->iface == &llama_sampler_mirostat_v2_i) {
+        return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
+    }
 
-    llama_token result = max_iter->id;
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-        smpl->n_sample++;
+    if (smpl->iface == &llama_sampler_chain_i) {
+        const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
+        for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
+            const uint32_t seed = llama_sampler_get_seed(*it);
+            if (seed != LLAMA_DEFAULT_SEED) {
+                return seed;
+            }
+        }
     }
-    return result;
+
+    return LLAMA_DEFAULT_SEED;
 }
 
-llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
-    GGML_ASSERT(smpl);
+// perf
 
-    const int64_t t_start_sample_us = ggml_time_us();
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) {
+    struct llama_perf_sampler_data data = {};
 
-    std::vector<float> probs;
-    probs.reserve(candidates->size);
-    for (size_t i = 0; i < candidates->size; ++i) {
-        probs.push_back(candidates->data[i].p);
+    if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
+        GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
     }
 
-    std::discrete_distribution<> dist(probs.begin(), probs.end());
-    int idx = dist(rng);
+    const auto * ctx = (const struct llama_sampler_chain *) chain->ctx;
 
-    llama_token result = candidates->data[idx].id;
+    data.t_sample_ms = 1e-3 * ctx->t_sample_us;
+    data.n_sample    = std::max(0, ctx->n_sample);
 
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    smpl->n_sample++;
+    return data;
+}
 
-    return result;
+void llama_perf_sampler_print(const struct llama_sampler * chain) {
+    const auto data = llama_perf_sampler(chain);
+
+    LLAMA_LOG_INFO("%s:    sampling time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
 }
 
-llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
+void llama_perf_sampler_reset(struct llama_sampler * chain) {
+    if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
+        GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
+    }
+
+    auto * ctx = (struct llama_sampler_chain *) chain->ctx;
+
+    ctx->t_sample_us = ctx->n_sample = 0;
 }
index f7f8e3ef706bc8d3dde7db5953811dd72fcad105..d90b147130e4b10d691d797dc99137569ca22a6e 100644 (file)
@@ -1,56 +1,29 @@
 #pragma once
 
-#include "llama-impl.h"
+// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
 
-struct llama_sampling {
-    llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
+#include "llama-grammar.h"
 
-    std::mt19937 rng;
+#include <unordered_map>
 
-    int32_t n_vocab = 0;
+struct llama_vocab;
+struct llama_grammar;
 
-    mutable int64_t t_sample_us = 0;
-    mutable int32_t n_sample = 0;
+// sampler chain
 
-    void reset_timings() const {
-        t_sample_us = 0;
-        n_sample = 0;
-    }
-};
+struct llama_sampler_chain {
+    llama_sampler_chain_params params;
+
+    std::vector<struct llama_sampler *> samplers;
+
+    // timing
 
-//
-// internal API
-//
-
-void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed);
-
-void llama_sample_softmax_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates);
-void llama_sample_top_k_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
-void llama_sample_top_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_min_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
-void llama_sample_typical_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_entropy_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
-void llama_sample_temp_impl     (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
-
-void llama_sample_repetition_penalties_impl(
-        struct llama_sampling * smpl,
-       llama_token_data_array * candidates,
-            const llama_token * last_tokens,
-                       size_t   penalty_last_n,
-                        float   penalty_repeat,
-                        float   penalty_freq,
-                        float   penalty_present);
-
-void llama_sample_apply_guidance_impl(
-        struct llama_sampling * smpl,
-                        float * logits,
-                        float * logits_guidance,
-                        float   scale);
-
-llama_token llama_sample_token_mirostat_impl   (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
-llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
-llama_token llama_sample_token_greedy_impl     (struct llama_sampling * smpl, llama_token_data_array * candidates);
-llama_token llama_sample_token_with_rng_impl   (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
-llama_token llama_sample_token_impl            (struct llama_sampling * smpl, llama_token_data_array * candidates);
+    mutable int64_t t_sample_us;
+
+    mutable int32_t n_sample;
+};
 
+struct llama_sampler * llama_sampler_init_grammar_impl(
+        const struct llama_vocab & vocab,
+                      const char * grammar_str,
+                      const char * grammar_root);
index 323660ef54cb07a7c90f15d86620583375754fc3..a771eccda30172a3f2a9b23e4e96d4472524534b 100644 (file)
@@ -58,17 +58,17 @@ struct naive_trie {
         auto res = children.find(c);
         if (res != children.end()) {
             return res->second.get_longest_prefix(key, len, offset + 1);
-        } else {
-            return std::make_pair(key, offset);
         }
+
+        return std::make_pair(key, offset);
     }
-    struct naive_trie * traverse(const char c) {
+    const struct naive_trie * traverse(const char c) const {
         auto res = children.find(c);
         if (res != children.end()) {
             return &res->second;
-        } else {
-            return NULL;
         }
+
+        return NULL;
     }
     std::map<char, struct naive_trie> children;
     bool has_value;
@@ -843,7 +843,7 @@ struct llm_tokenizer_ugm {
             // traverse the token matcher trie to find a matching token
             bool single_codepoint_token_found = false;
             const struct best_tokenization & current_best = tokenization_results[input_offset];
-            struct naive_trie * node  = token_matcher.traverse(normalized[prefix_offset++]);
+            const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
 
             while (prefix_offset <= input_len && node != NULL) {
                 // check if we found valid token in prefix
@@ -963,7 +963,7 @@ private:
     /*
      * This structure is a view wrapper for XOR-compressed double array (XCDA)
      * See Shunsuke Kanda (2018). Space- and Time-Efficient String Dictionaries.
-     * Eeach bit-packed entry contains:
+     * Each bit-packed entry contains:
      * - BASE array value in bits 10-30
      * - LCHECK array value in bits 0-7
      * - LEAF array value in bit 9
@@ -1097,6 +1097,111 @@ private:
     struct naive_trie token_matcher;
 };
 
+//
+// RWKV tokenizer
+//
+
+static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escaped) {
+    std::vector<uint8_t> output;
+    output.reserve(escaped.size());
+
+    // Parser state
+    bool escaping = false;
+    uint8_t hex_remaining = 0;
+    uint8_t hex_acc = 0;
+
+    // Step through characters, performing parsing
+    for (const char & c : escaped) {
+        // If we're parsing a hex code, interpret the next character
+        if (hex_remaining != 0) {
+            uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0');
+            hex_acc = (hex_acc << 4) + value;
+
+            hex_remaining -= 1;
+            if (hex_remaining == 0) {
+                output.push_back(hex_acc);
+                hex_acc = 0;
+            }
+
+            continue;
+        }
+
+        // If we got an escape character, interpret it
+        if (escaping) {
+            if (c == 't') {
+                output.push_back('\t');
+            } else if (c == 'n') {
+                output.push_back('\n');
+            } else if (c == 'r') {
+                output.push_back('\r');
+            } else if (c == 'x') {
+                hex_remaining = 2;
+            } else {
+                output.push_back(c);
+            }
+
+            escaping = false;
+            continue;
+        }
+
+        if (c == '\\') {
+            escaping = true;
+            continue;
+        }
+
+        output.push_back(c);
+    }
+
+    return output;
+}
+
+struct llm_tokenizer_rwkv {
+    llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
+        // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
+        // For now, we decode the vocab here into the lookup we'll use for tokenization.
+
+        // build trie
+        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
+            const auto & token = vocab.id_to_token[id];
+            const auto data = llama_unescape_rwkv_token(token.text);
+            token_matcher.insert((const char *) data.data(), data.size(), id);
+        }
+    }
+
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        uint32_t position = 0;
+
+        while (position < text.size()) {
+            const struct naive_trie * node = token_matcher.traverse(text[position]);
+            if (node == NULL) {
+                // no matching token found, add unknown token
+                output.push_back(vocab.special_unk_id);
+                position += 1;
+                continue;
+            }
+
+            // traverse the trie to find the longest matching token
+            uint32_t token_id = 0;
+            uint32_t token_length = 0;
+            while (node != NULL) {
+                if (node->has_value) {
+                    token_id = node->value;
+                    token_length = position + 1;
+                }
+                node = node->traverse(text[++position]);
+            }
+
+            // add the longest matching token
+            output.push_back(token_id);
+            position = token_length;
+        }
+    }
+
+    const llama_vocab & vocab;
+
+    struct naive_trie token_matcher;
+};
+
 //
 // (de-) tokenize
 //
@@ -1401,6 +1506,23 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
                     output.push_back(vocab.special_eos_id);
                 }
             } break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            {
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+
+                        llm_tokenizer_rwkv tokenizer(vocab);
+                        tokenizer.tokenize(raw_text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
         case LLAMA_VOCAB_TYPE_NONE:
             GGML_ABORT("fatal error");
     }
@@ -1448,11 +1570,7 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, lla
 }
 
 bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
-    return token != -1 && (
-        token == llama_token_eos_impl(vocab) ||
-        token == llama_token_eot_impl(vocab) ||
-        token == llama_token_eom_impl(vocab)
-    );
+    return token != -1 && vocab.special_eog_ids.count(token) > 0;
 }
 
 bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
@@ -1616,6 +1734,17 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
                 }
                 break;
             }
+            case LLAMA_VOCAB_TYPE_RWKV: {
+                std::vector<uint8_t> result = llama_unescape_rwkv_token(token_text);
+
+                // If we don't have enough space, return an error
+                if (result.size() > (size_t)length) {
+                    return -(int)result.size();
+                }
+
+                memcpy(buf, result.data(), result.size());
+                return (int)result.size();
+            }
             default:
                 GGML_ABORT("fatal error");
         }
index 6e8f30be43ba1cb2f93cd07d8189546998de49d1..cc46f642bf1ae371ce5fa2b23aa9ed3b44f5895d 100644 (file)
@@ -6,6 +6,7 @@
 #include <vector>
 #include <unordered_map>
 #include <map>
+#include <set>
 
 struct llama_vocab {
     using id    = llama_token;
@@ -18,6 +19,8 @@ struct llama_vocab {
         tattr attr;
     };
 
+    uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
+
     enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
     enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
 
@@ -47,12 +50,15 @@ struct llama_vocab {
     id special_eot_id    = -1; // TODO: move above after "eos_id", and here add "file separator" token
     id special_eom_id    = -1;
 
+    // set of all tokens that cause "end of generation"
+    std::set<id> special_eog_ids;
+
     // tokenizer flags
-    bool tokenizer_add_space_prefix = false;
-    bool tokenizer_add_bos          = false;
-    bool tokenizer_add_eos          = false;
-    bool tokenizer_ignore_merges    = false;
-    bool tokenizer_clean_spaces     = false;  // clean_up_tokenization_spaces
+    bool tokenizer_add_space_prefix           = false;
+    bool tokenizer_add_bos                    = false;
+    bool tokenizer_add_eos                    = false;
+    bool tokenizer_ignore_merges              = false;
+    bool tokenizer_clean_spaces               = false;  // clean_up_tokenization_spaces
     bool tokenizer_remove_extra_whitespaces   = false;
     bool tokenizer_escape_whitespaces         = true;
     bool tokenizer_treat_whitespace_as_suffix = false;
@@ -62,8 +68,6 @@ struct llama_vocab {
     int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
 };
 
-const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx);
-
 //
 // internal API
 //
@@ -76,6 +80,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
         bool add_special,
         bool parse_special = false);
 
+// TODO: move the API below as member functions of llama_vocab
 llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
 
 const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
index 8d5f24783d6aba1d5fad15718f74c4586aaf1410..a718de054f934697aca987145cf21e26d29199e2 100644 (file)
@@ -1,6 +1,5 @@
 #include "llama-impl.h"
 #include "llama-vocab.h"
-#include "llama-grammar.h"
 #include "llama-sampling.h"
 
 #include "unicode.h"
@@ -194,6 +193,7 @@ enum llm_arch {
     LLM_ARCH_ORION,
     LLM_ARCH_INTERNLM2,
     LLM_ARCH_MINICPM,
+    LLM_ARCH_MINICPM3,
     LLM_ARCH_GEMMA,
     LLM_ARCH_GEMMA2,
     LLM_ARCH_STARCODER2,
@@ -202,6 +202,7 @@ enum llm_arch {
     LLM_ARCH_COMMAND_R,
     LLM_ARCH_DBRX,
     LLM_ARCH_OLMO,
+    LLM_ARCH_OLMOE,
     LLM_ARCH_OPENELM,
     LLM_ARCH_ARCTIC,
     LLM_ARCH_DEEPSEEK2,
@@ -212,6 +213,8 @@ enum llm_arch {
     LLM_ARCH_JAIS,
     LLM_ARCH_NEMOTRON,
     LLM_ARCH_EXAONE,
+    LLM_ARCH_RWKV6,
+    LLM_ARCH_GRANITE,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -241,6 +244,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_ORION,           "orion"        },
     { LLM_ARCH_INTERNLM2,       "internlm2"    },
     { LLM_ARCH_MINICPM,         "minicpm"      },
+    { LLM_ARCH_MINICPM3,        "minicpm3"     },
     { LLM_ARCH_GEMMA,           "gemma"        },
     { LLM_ARCH_GEMMA2,          "gemma2"       },
     { LLM_ARCH_STARCODER2,      "starcoder2"   },
@@ -249,6 +253,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_COMMAND_R,       "command-r"    },
     { LLM_ARCH_DBRX,            "dbrx"         },
     { LLM_ARCH_OLMO,            "olmo"         },
+    { LLM_ARCH_OLMOE,           "olmoe"        },
     { LLM_ARCH_OPENELM,         "openelm"      },
     { LLM_ARCH_ARCTIC,          "arctic"       },
     { LLM_ARCH_DEEPSEEK2,       "deepseek2"    },
@@ -259,6 +264,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_JAIS,            "jais"         },
     { LLM_ARCH_NEMOTRON,        "nemotron"     },
     { LLM_ARCH_EXAONE,          "exaone"       },
+    { LLM_ARCH_RWKV6,           "rwkv6"        },
+    { LLM_ARCH_GRANITE,         "granite"      },
     { LLM_ARCH_UNKNOWN,         "(unknown)"    },
 };
 
@@ -295,6 +302,11 @@ enum llm_kv {
     LLM_KV_DECODER_START_TOKEN_ID,
     LLM_KV_ATTN_LOGIT_SOFTCAPPING,
     LLM_KV_FINAL_LOGIT_SOFTCAPPING,
+    LLM_KV_RESCALE_EVERY_N_LAYERS,
+    LLM_KV_TIME_MIX_EXTRA_DIM,
+    LLM_KV_TIME_DECAY_EXTRA_DIM,
+    LLM_KV_RESIDUAL_SCALE,
+    LLM_KV_EMBEDDING_SCALE,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -309,6 +321,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_KV_LORA_RANK,
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
     LLM_KV_ATTENTION_SLIDING_WINDOW,
+    LLM_KV_ATTENTION_SCALE,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_FREQ_BASE,
@@ -330,6 +343,8 @@ enum llm_kv {
     LLM_KV_SSM_TIME_STEP_RANK,
     LLM_KV_SSM_DT_B_C_RMS,
 
+    LLM_KV_WKV_HEAD_SIZE,
+
     LLM_KV_TOKENIZER_MODEL,
     LLM_KV_TOKENIZER_PRE,
     LLM_KV_TOKENIZER_LIST,
@@ -389,11 +404,16 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_USED_COUNT,                 "%s.expert_used_count"                 },
     { LLM_KV_EXPERT_SHARED_COUNT,               "%s.expert_shared_count"               },
     { LLM_KV_EXPERT_WEIGHTS_SCALE,              "%s.expert_weights_scale"              },
-    { LLM_KV_POOLING_TYPE ,                     "%s.pooling_type"                      },
+    { LLM_KV_POOLING_TYPE                     "%s.pooling_type"                      },
     { LLM_KV_LOGIT_SCALE,                       "%s.logit_scale"                       },
     { LLM_KV_DECODER_START_TOKEN_ID,            "%s.decoder_start_token_id"            },
     { LLM_KV_ATTN_LOGIT_SOFTCAPPING,            "%s.attn_logit_softcapping"            },
     { LLM_KV_FINAL_LOGIT_SOFTCAPPING,           "%s.final_logit_softcapping"           },
+    { LLM_KV_RESCALE_EVERY_N_LAYERS,            "%s.rescale_every_n_layers"            },
+    { LLM_KV_TIME_MIX_EXTRA_DIM,                "%s.time_mix_extra_dim"                },
+    { LLM_KV_TIME_DECAY_EXTRA_DIM,              "%s.time_decay_extra_dim"              },
+    { LLM_KV_RESIDUAL_SCALE,                    "%s.residual_scale"                    },
+    { LLM_KV_EMBEDDING_SCALE,                   "%s.embedding_scale"                   },
 
     { LLM_KV_ATTENTION_HEAD_COUNT,             "%s.attention.head_count"             },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV,          "%s.attention.head_count_kv"          },
@@ -408,6 +428,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_KV_LORA_RANK,           "%s.attention.kv_lora_rank"           },
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
     { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
+    { LLM_KV_ATTENTION_SCALE,                  "%s.attention.scale"                  },
 
     { LLM_KV_ROPE_DIMENSION_COUNT,          "%s.rope.dimension_count"                 },
     { LLM_KV_ROPE_FREQ_BASE,                "%s.rope.freq_base"                       },
@@ -429,6 +450,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_SSM_TIME_STEP_RANK,            "%s.ssm.time_step_rank" },
     { LLM_KV_SSM_DT_B_C_RMS,                "%s.ssm.dt_b_c_rms" },
 
+    { LLM_KV_WKV_HEAD_SIZE,                 "%s.wkv.head_size" },
+
     { LLM_KV_TOKENIZER_MODEL,                "tokenizer.ggml.model"                    },
     { LLM_KV_TOKENIZER_PRE,                  "tokenizer.ggml.pre"                      },
     { LLM_KV_TOKENIZER_LIST,                 "tokenizer.ggml.tokens"                   },
@@ -518,6 +541,29 @@ enum llm_tensor {
     LLM_TENSOR_SSM_A,
     LLM_TENSOR_SSM_D,
     LLM_TENSOR_SSM_OUT,
+    LLM_TENSOR_TIME_MIX_W1,
+    LLM_TENSOR_TIME_MIX_W2,
+    LLM_TENSOR_TIME_MIX_LERP_X,
+    LLM_TENSOR_TIME_MIX_LERP_W,
+    LLM_TENSOR_TIME_MIX_LERP_K,
+    LLM_TENSOR_TIME_MIX_LERP_V,
+    LLM_TENSOR_TIME_MIX_LERP_R,
+    LLM_TENSOR_TIME_MIX_LERP_G,
+    LLM_TENSOR_TIME_MIX_FIRST,
+    LLM_TENSOR_TIME_MIX_DECAY,
+    LLM_TENSOR_TIME_MIX_DECAY_W1,
+    LLM_TENSOR_TIME_MIX_DECAY_W2,
+    LLM_TENSOR_TIME_MIX_KEY,
+    LLM_TENSOR_TIME_MIX_VALUE,
+    LLM_TENSOR_TIME_MIX_RECEPTANCE,
+    LLM_TENSOR_TIME_MIX_GATE,
+    LLM_TENSOR_TIME_MIX_LN,
+    LLM_TENSOR_TIME_MIX_OUTPUT,
+    LLM_TENSOR_CHANNEL_MIX_LERP_K,
+    LLM_TENSOR_CHANNEL_MIX_LERP_R,
+    LLM_TENSOR_CHANNEL_MIX_KEY,
+    LLM_TENSOR_CHANNEL_MIX_RECEPTANCE,
+    LLM_TENSOR_CHANNEL_MIX_VALUE,
     LLM_TENSOR_ATTN_Q_A,
     LLM_TENSOR_ATTN_Q_B,
     LLM_TENSOR_ATTN_KV_A_MQA,
@@ -1000,6 +1046,29 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
         },
     },
+    {
+        LLM_ARCH_MINICPM3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ROPE_FACTORS_LONG,  "rope_factors_long" },
+            { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q_A_NORM,      "blk.%d.attn_q_a_norm" },
+            { LLM_TENSOR_ATTN_KV_A_NORM,     "blk.%d.attn_kv_a_norm" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_A,           "blk.%d.attn_q_a" },
+            { LLM_TENSOR_ATTN_Q_B,           "blk.%d.attn_q_b" },
+            { LLM_TENSOR_ATTN_KV_A_MQA,      "blk.%d.attn_kv_a_mqa" },
+            { LLM_TENSOR_ATTN_KV_B,          "blk.%d.attn_kv_b" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,           "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
+        },
+    },
     {
         LLM_ARCH_GEMMA,
         {
@@ -1134,6 +1203,26 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_OLMOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_Q_NORM,        "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM,        "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_OPENELM,
         {
@@ -1339,6 +1428,56 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_RWKV6,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,                "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM,           "token_embd_norm" },
+            { LLM_TENSOR_OUTPUT_NORM,               "output_norm" },
+            { LLM_TENSOR_OUTPUT,                    "output" },
+            { LLM_TENSOR_ATTN_NORM,                 "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_NORM_2,               "blk.%d.attn_norm_2" },
+            { LLM_TENSOR_TIME_MIX_W1,               "blk.%d.time_mix_w1" },
+            { LLM_TENSOR_TIME_MIX_W2,               "blk.%d.time_mix_w2" },
+            { LLM_TENSOR_TIME_MIX_LERP_X,           "blk.%d.time_mix_lerp_x" },
+            { LLM_TENSOR_TIME_MIX_LERP_W,           "blk.%d.time_mix_lerp_w" },
+            { LLM_TENSOR_TIME_MIX_LERP_K,           "blk.%d.time_mix_lerp_k" },
+            { LLM_TENSOR_TIME_MIX_LERP_V,           "blk.%d.time_mix_lerp_v" },
+            { LLM_TENSOR_TIME_MIX_LERP_R,           "blk.%d.time_mix_lerp_r" },
+            { LLM_TENSOR_TIME_MIX_LERP_G,           "blk.%d.time_mix_lerp_g" },
+            { LLM_TENSOR_TIME_MIX_FIRST,            "blk.%d.time_mix_first" },
+            { LLM_TENSOR_TIME_MIX_DECAY,            "blk.%d.time_mix_decay" },
+            { LLM_TENSOR_TIME_MIX_DECAY_W1,         "blk.%d.time_mix_decay_w1" },
+            { LLM_TENSOR_TIME_MIX_DECAY_W2,         "blk.%d.time_mix_decay_w2" },
+            { LLM_TENSOR_TIME_MIX_KEY,              "blk.%d.time_mix_key" },
+            { LLM_TENSOR_TIME_MIX_VALUE,            "blk.%d.time_mix_value" },
+            { LLM_TENSOR_TIME_MIX_RECEPTANCE,       "blk.%d.time_mix_receptance" },
+            { LLM_TENSOR_TIME_MIX_GATE,             "blk.%d.time_mix_gate" },
+            { LLM_TENSOR_TIME_MIX_LN,               "blk.%d.time_mix_ln" },
+            { LLM_TENSOR_TIME_MIX_OUTPUT,           "blk.%d.time_mix_output" },
+            { LLM_TENSOR_CHANNEL_MIX_LERP_K,        "blk.%d.channel_mix_lerp_k" },
+            { LLM_TENSOR_CHANNEL_MIX_LERP_R,        "blk.%d.channel_mix_lerp_r" },
+            { LLM_TENSOR_CHANNEL_MIX_KEY,           "blk.%d.channel_mix_key" },
+            { LLM_TENSOR_CHANNEL_MIX_VALUE,         "blk.%d.channel_mix_value" },
+            { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE,    "blk.%d.channel_mix_receptance" },
+        },
+    },
+    {
+        LLM_ARCH_GRANITE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -2088,6 +2227,10 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer
     if (host_buffer) {
         buft = ggml_backend_sycl_host_buffer_type();
     }
+#elif defined(GGML_USE_CANN)
+    if (host_buffer) {
+        buft = ggml_backend_cann_host_buffer_type();
+    }
 #elif defined(GGML_USE_CPU_HBM)
     buft = ggml_backend_cpu_hbm_buffer_type();
 #elif defined(GGML_USE_VULKAN)
@@ -2151,6 +2294,7 @@ enum e_model {
     MODEL_1B,
     MODEL_1_3B,
     MODEL_1_4B,
+    MODEL_1_6B,
     MODEL_2B,
     MODEL_2_8B,
     MODEL_3B,
@@ -2179,6 +2323,7 @@ enum e_model {
     MODEL_MEDIUM,
     MODEL_LARGE,
     MODEL_XL,
+    MODEL_A1_7B,
     MODEL_A2_7B,
     MODEL_8x7B,
     MODEL_8x22B,
@@ -2228,6 +2373,12 @@ struct llama_hparams {
     float f_attn_logit_softcapping = 50.0f;
     float f_final_logit_softcapping = 30.0f;
 
+    // for RWKV
+    uint32_t rescale_every_n_layers = 0;
+    uint32_t time_mix_extra_dim = 0;
+    uint32_t time_decay_extra_dim = 0;
+    uint32_t wkv_head_size = 0;
+
     float    rope_attn_factor = 1.0f;
     float    rope_freq_base_train;
     float    rope_freq_scale_train;
@@ -2245,6 +2396,11 @@ struct llama_hparams {
     float f_max_alibi_bias = 0.0f;
     float f_logit_scale    = 0.0f;
 
+    // Additional scale factors (Granite)
+    float f_residual_scale  = 0.0f;
+    float f_embedding_scale = 0.0f;
+    float f_attention_scale = 0.0f;
+
     bool causal_attn   = true;
     bool use_alibi     = false;
     bool attn_soft_cap = false;
@@ -2291,6 +2447,11 @@ struct llama_hparams {
         if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
         if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;
 
+        if (this->rescale_every_n_layers != other.rescale_every_n_layers) return true;
+        if (this->time_mix_extra_dim     != other.time_mix_extra_dim)     return true;
+        if (this->time_decay_extra_dim   != other.time_decay_extra_dim)   return true;
+        if (this->wkv_head_size          != other.wkv_head_size)          return true;
+
         if (this->dec_start_token_id != other.dec_start_token_id) return true;
 
         const float EPSILON = 1e-9f;
@@ -2302,6 +2463,9 @@ struct llama_hparams {
         if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true;
         if (!is_float_close(this->expert_weights_scale,  other.expert_weights_scale,  EPSILON)) return true;
         if (!is_float_close(this->rope_yarn_log_mul,     other.rope_yarn_log_mul,     EPSILON)) return true;
+        if (!is_float_close(this->f_residual_scale,      other.f_residual_scale,      EPSILON)) return true;
+        if (!is_float_close(this->f_embedding_scale,     other.f_embedding_scale,     EPSILON)) return true;
+        if (!is_float_close(this->f_attention_scale,     other.f_attention_scale,     EPSILON)) return true;
 
         return false;
     }
@@ -2354,15 +2518,25 @@ struct llama_hparams {
     }
 
     uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
-        // corresponds to Mamba's conv_states size
-        // TODO: maybe support other convolution strides than 1
-        // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+        // corresponds to Mamba's conv_states size or RWKV's token_shift states size
+        if (wkv_head_size != 0) {
+            // for RWKV models
+            return 2 * n_embd;
+        } else {
+            // TODO: maybe support other convolution strides than 1
+            // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+            return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+        }
     }
 
     uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
-        // corresponds to Mamba's ssm_states size
-        return ssm_d_state * ssm_d_inner;
+        if (wkv_head_size != 0) {
+            // corresponds to RWKV's wkv_states size
+            return n_embd * wkv_head_size;
+        } else {
+            // corresponds to Mamba's ssm_states size
+            return ssm_d_state * ssm_d_inner;
+        }
     }
 };
 
@@ -2373,8 +2547,8 @@ struct llama_cparams {
     uint32_t n_batch;
     uint32_t n_ubatch;
     uint32_t n_seq_max;
-    uint32_t n_threads;       // number of threads to use for generation
-    uint32_t n_threads_batch; // number of threads to use for batch processing
+    int      n_threads;       // number of threads to use for generation
+    int      n_threads_batch; // number of threads to use for batch processing
 
     float rope_freq_base;
     float rope_freq_scale;
@@ -2392,6 +2566,7 @@ struct llama_cparams {
     bool causal_attn;
     bool offload_kqv;
     bool flash_attn;
+    bool no_perf;
 
     enum llama_pooling_type pooling_type;
 
@@ -2501,6 +2676,36 @@ struct llama_layer {
     struct ggml_tensor * ssm_conv1d_b;
     struct ggml_tensor * ssm_dt_b;
 
+    // rwkv
+    struct ggml_tensor * time_mix_w1;
+    struct ggml_tensor * time_mix_w2;
+    struct ggml_tensor * time_mix_lerp_x;
+    struct ggml_tensor * time_mix_lerp_w;
+    struct ggml_tensor * time_mix_lerp_k;
+    struct ggml_tensor * time_mix_lerp_v;
+    struct ggml_tensor * time_mix_lerp_r;
+    struct ggml_tensor * time_mix_lerp_g;
+
+    struct ggml_tensor * time_mix_first;
+    struct ggml_tensor * time_mix_decay;
+    struct ggml_tensor * time_mix_decay_w1;
+    struct ggml_tensor * time_mix_decay_w2;
+    struct ggml_tensor * time_mix_key;
+    struct ggml_tensor * time_mix_value;
+    struct ggml_tensor * time_mix_receptance;
+    struct ggml_tensor * time_mix_gate;
+
+    struct ggml_tensor * time_mix_ln;
+    struct ggml_tensor * time_mix_ln_b;
+    struct ggml_tensor * time_mix_output;
+
+    struct ggml_tensor * channel_mix_lerp_k;
+    struct ggml_tensor * channel_mix_lerp_r;
+
+    struct ggml_tensor * channel_mix_key;
+    struct ggml_tensor * channel_mix_receptance;
+    struct ggml_tensor * channel_mix_value;
+
     // long rope factors
     struct ggml_tensor * rope_long  = nullptr;
     struct ggml_tensor * rope_short = nullptr;
@@ -2851,18 +3056,14 @@ struct llama_sbatch {
         } else {
             // simple split
             if (batch->n_seq_id) {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.n_seq_id = batch->n_seq_id + seq.offset;
-                }
+                ubatch.n_seq_id = batch->n_seq_id + seq.offset;
             } else {
                 for (size_t i = 0; i < length; ++i) {
                     ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
                 }
             }
             if (batch->seq_id) {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.seq_id = batch->seq_id + seq.offset;
-                }
+                ubatch.seq_id = batch->seq_id + seq.offset;
             } else {
                 for (size_t i = 0; i < length; ++i) {
                     ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
@@ -3058,7 +3259,6 @@ struct llama_sbatch {
 struct llama_context {
     llama_context(const llama_model & model)
         : model(model)
-        , sampling(llama_n_vocab(&model))
         , t_start_us(model.t_start_us)
         , t_load_us(model.t_load_us) {}
 
@@ -3075,7 +3275,6 @@ struct llama_context {
     const struct llama_model & model;
 
     struct llama_cparams        cparams;
-    struct llama_sampling       sampling;
     struct llama_sbatch         sbatch;
     struct llama_kv_cache       kv_self;
     struct llama_control_vector cvec;
@@ -3091,18 +3290,21 @@ struct llama_context {
 #endif
     ggml_backend_t backend_cpu = nullptr;
 
+    ggml_threadpool_t threadpool       = nullptr;
+    ggml_threadpool_t threadpool_batch = nullptr;
+
     bool has_evaluated_once = false;
 
-    int64_t t_start_us;
-    int64_t t_load_us;
-    int64_t t_p_eval_us = 0;
-    int64_t t_eval_us   = 0;
+    mutable int64_t t_start_us;
+    mutable int64_t t_load_us;
+    mutable int64_t t_p_eval_us = 0;
+    mutable int64_t t_eval_us   = 0;
 
-    int64_t t_compute_start_us = 0;
-    int64_t n_queued_tokens = 0;
+    mutable int64_t t_compute_start_us = 0;
+    mutable int64_t n_queued_tokens = 0;
 
-    int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
-    int32_t n_eval   = 0; // number of eval calls
+    mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
+    mutable int32_t n_eval   = 0; // number of eval calls
 
     // host buffer for the model output (logits and embeddings)
     ggml_backend_buffer_t buf_output = nullptr;
@@ -3222,29 +3424,33 @@ static size_t llama_get_device_count(const llama_model & model) {
 static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_model & model, int gpu) {
     ggml_backend_buffer_type_t buft = nullptr;
 
-#if defined(GGML_USE_RPC)
-    int dev_count = (int)llama_get_device_count(model);
+#ifdef GGML_USE_RPC
     int rpc_count = (int)model.rpc_servers.size();
-    if (gpu >= dev_count - rpc_count) {
-        const char * endpoint = model.rpc_servers[gpu - dev_count + rpc_count].c_str();
+#else
+    int rpc_count = 0;
+#endif
+    int local_gpu = gpu - rpc_count;
+#if defined(GGML_USE_RPC)
+    if (gpu < rpc_count) {
+        const char * endpoint = model.rpc_servers[gpu].c_str();
         return ggml_backend_rpc_buffer_type(endpoint);
     }
 #endif
 #if defined(GGML_USE_METAL)
     buft = ggml_backend_metal_buffer_type();
 #elif defined(GGML_USE_CUDA)
-    buft = ggml_backend_cuda_buffer_type(gpu);
+    buft = ggml_backend_cuda_buffer_type(local_gpu);
 #elif defined(GGML_USE_VULKAN)
-    buft = ggml_backend_vk_buffer_type(gpu);
+    buft = ggml_backend_vk_buffer_type(local_gpu);
 #elif defined(GGML_USE_SYCL)
-    buft = ggml_backend_sycl_buffer_type(gpu);
+    buft = ggml_backend_sycl_buffer_type(local_gpu);
 #elif defined(GGML_USE_KOMPUTE)
-    buft = ggml_backend_kompute_buffer_type(gpu);
+    buft = ggml_backend_kompute_buffer_type(local_gpu);
     if (buft == nullptr) {
-        LLAMA_LOG_WARN("%s: cannot use GPU %d, check `vulkaninfo --summary`\n", __func__, gpu);
+        LLAMA_LOG_WARN("%s: cannot use GPU %d, check `vulkaninfo --summary`\n", __func__, local_gpu);
     }
 #elif defined(GGML_USE_CANN)
-    buft = ggml_backend_cann_buffer_type(gpu);
+    buft = ggml_backend_cann_buffer_type(local_gpu);
 #endif
 
     if (buft == nullptr) {
@@ -3252,7 +3458,7 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_
     }
     return buft;
     GGML_UNUSED(model);
-    GGML_UNUSED(gpu);
+    GGML_UNUSED(local_gpu);
 }
 
 static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_model & model, int fallback_gpu, const float * tensor_split) {
@@ -3279,13 +3485,17 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_mo
 }
 
 static size_t llama_get_device_memory(const llama_model & model, int device) {
-#if defined(GGML_USE_RPC)
-    int dev_count = (int)llama_get_device_count(model);
+#ifdef GGML_USE_RPC
     int rpc_count = (int)model.rpc_servers.size();
-    if (device >= dev_count - rpc_count) {
+#else
+    int rpc_count = 0;
+#endif
+    int local_device = device - rpc_count;
+#if defined(GGML_USE_RPC)
+    if (device < rpc_count) {
         size_t total;
         size_t free;
-        const char * endpoint = model.rpc_servers[device - dev_count + rpc_count].c_str();
+        const char * endpoint = model.rpc_servers[device].c_str();
         ggml_backend_rpc_get_device_memory(endpoint, &free, &total);
         return free;
     }
@@ -3293,28 +3503,28 @@ static size_t llama_get_device_memory(const llama_model & model, int device) {
 #if defined(GGML_USE_CUDA)
     size_t total;
     size_t free;
-    ggml_backend_cuda_get_device_memory(device, &free, &total);
+    ggml_backend_cuda_get_device_memory(local_device, &free, &total);
     return free;
 #elif defined(GGML_USE_SYCL)
     size_t total;
     size_t free;
-    ggml_backend_sycl_get_device_memory(device, &free, &total);
+    ggml_backend_sycl_get_device_memory(local_device, &free, &total);
     return free;
 #elif defined(GGML_USE_VULKAN)
     size_t total;
     size_t free;
-    ggml_backend_vk_get_device_memory(device, &free, &total);
+    ggml_backend_vk_get_device_memory(local_device, &free, &total);
     return free;
 #elif defined(GGML_USE_CANN)
     size_t total;
     size_t free;
-    ggml_backend_cann_get_device_memory(device, &free, &total);
+    ggml_backend_cann_get_device_memory(local_device, &free, &total);
     return free;
 #else
     return 1;
 #endif
     GGML_UNUSED(model);
-    GGML_UNUSED(device);
+    GGML_UNUSED(local_device);
 }
 
 //
@@ -3423,7 +3633,7 @@ static bool llama_kv_cache_find_slot(
     const uint32_t n_seq_tokens = batch.n_seq_tokens;
 
     if (cache.recurrent) {
-        // For recurrent state architectures (like Mamba),
+        // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
         // A slot should be always be contiguous.
 
@@ -3672,7 +3882,7 @@ static bool llama_kv_cache_seq_rm(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
-    // models like Mamba can't have a state partially erased
+    // models like Mamba or RWKV can't have a state partially erased
     if (cache.recurrent) {
         if (seq_id >= (int64_t) cache.size) {
             // could be fatal
@@ -3686,7 +3896,8 @@ static bool llama_kv_cache_seq_rm(
                 if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
                     return false;
                 }
-                if (p0 <= cell.pos && p1 < cell.pos) {
+                // invalidate tails which will be cleared
+                if (p0 <= cell.pos && cell.pos < p1) {
                     tail_id = -1;
                 }
             }
@@ -3808,7 +4019,7 @@ static void llama_kv_cache_seq_add(
     if (p0 == p1) return;
 
     if (cache.recurrent) {
-        // for Mamba-like models, only the pos needs to be shifted
+        // for Mamba-like or RWKV models, only the pos needs to be shifted
         if (0 <= seq_id && seq_id < (int64_t) cache.size) {
             const int32_t tail_id = cache.cells[seq_id].tail;
             if (tail_id >= 0) {
@@ -3857,7 +4068,7 @@ static void llama_kv_cache_seq_div(
     if (p0 == p1) return;
 
     if (cache.recurrent) {
-        // for Mamba-like models, only the pos needs to be changed
+        // for Mamba-like or RWKV models, only the pos needs to be changed
         if (0 <= seq_id && seq_id < (int64_t) cache.size) {
             const int32_t tail_id = cache.cells[seq_id].tail;
             if (tail_id >= 0) {
@@ -4311,6 +4522,8 @@ struct llama_model_loader {
                 case GGML_TYPE_Q4_K:    ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M;  break;
                 case GGML_TYPE_Q5_K:    ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M;  break;
                 case GGML_TYPE_Q6_K:    ftype = LLAMA_FTYPE_MOSTLY_Q6_K;    break;
+                case GGML_TYPE_TQ1_0:   ftype = LLAMA_FTYPE_MOSTLY_TQ1_0;   break;
+                case GGML_TYPE_TQ2_0:   ftype = LLAMA_FTYPE_MOSTLY_TQ2_0;   break;
                 case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break;
                 case GGML_TYPE_IQ2_XS:  ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS;  break;
                 case GGML_TYPE_IQ2_S:   ftype = LLAMA_FTYPE_MOSTLY_IQ2_S;   break;
@@ -5004,6 +5217,8 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
         case LLAMA_FTYPE_MOSTLY_Q5_K_S:   return "Q5_K - Small";
         case LLAMA_FTYPE_MOSTLY_Q5_K_M:   return "Q5_K - Medium";
         case LLAMA_FTYPE_MOSTLY_Q6_K:     return "Q6_K";
+        case LLAMA_FTYPE_MOSTLY_TQ1_0:    return "TQ1_0 - 1.69 bpw ternary";
+        case LLAMA_FTYPE_MOSTLY_TQ2_0:    return "TQ2_0 - 2.06 bpw ternary";
         case LLAMA_FTYPE_MOSTLY_IQ2_XXS:  return "IQ2_XXS - 2.0625 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ2_XS:   return "IQ2_XS - 2.3125 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ2_S:    return "IQ2_S - 2.5 bpw";
@@ -5048,6 +5263,7 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_1B:            return "1B";
         case MODEL_1_3B:          return "1.3B";
         case MODEL_1_4B:          return "1.4B";
+        case MODEL_1_6B:          return "1.6B";
         case MODEL_2B:            return "2B";
         case MODEL_2_8B:          return "2.8B";
         case MODEL_3B:            return "3B";
@@ -5076,6 +5292,7 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_MEDIUM:        return "0.4B";
         case MODEL_LARGE:         return "0.8B";
         case MODEL_XL:            return "1.5B";
+        case MODEL_A1_7B:         return "A1.7B";
         case MODEL_A2_7B:         return "A2.7B";
         case MODEL_8x7B:          return "8x7B";
         case MODEL_8x22B:         return "8x22B";
@@ -5094,6 +5311,7 @@ static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
         case LLAMA_VOCAB_TYPE_BPE:  return "BPE";
         case LLAMA_VOCAB_TYPE_WPM:  return "WPM";
         case LLAMA_VOCAB_TYPE_UGM:  return "UGM";
+        case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
         default:                    return "unknown";
     }
 }
@@ -5249,6 +5467,17 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_MINICPM3:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q);
+                ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
+
+                switch (hparams.n_layer) {
+                    case 62: model.type = e_model::MODEL_4B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_GROK:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -5614,6 +5843,14 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_OLMOE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                switch (hparams.n_layer) {
+                    case 16: model.type = e_model::MODEL_A1_7B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_OPENELM:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -5790,6 +6027,40 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_RWKV6:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
+                ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim);
+                ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim);
+                ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false);
+
+                switch (hparams.n_layer) {
+                    case 24: model.type = e_model::MODEL_1_6B; break;
+                    case 32:
+                        switch (hparams.n_embd) {
+                            case 2560: model.type = e_model::MODEL_3B; break;
+                            case 4096: model.type = e_model::MODEL_7B; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 61: model.type = e_model::MODEL_14B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
+        case LLM_ARCH_GRANITE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
+                ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale);
+                ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale);
+                ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);
+
+                switch (hparams.n_layer) {
+                    case 40: model.type = e_model::MODEL_3B; break;
+                    // Add additional layer/vocab/etc checks here for other model sizes
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
 
@@ -5832,8 +6103,15 @@ static void llm_load_vocab(
             vocab.special_mask_id = -1;
             vocab.linefeed_id     = -1;
 
+            // read vocab size from metadata
+            if (!ml.get_key(LLM_KV_VOCAB_SIZE, vocab.n_vocab, false)) {
+                vocab.n_vocab = 0;
+                LLAMA_LOG_WARN("%s: there is no vocab_size in metadata, vocab.n_vocab will be set to %u\n", __func__, vocab.n_vocab);
+            }
             return;
-        } else if (tokenizer_model == "llama") {
+        }
+
+        if (tokenizer_model == "llama") {
             vocab.type = LLAMA_VOCAB_TYPE_SPM;
 
             // default special tokens
@@ -5919,6 +6197,15 @@ static void llm_load_vocab(
                 }
 #endif
             }
+        } else if (tokenizer_model == "rwkv") {
+            vocab.type = LLAMA_VOCAB_TYPE_RWKV;
+
+            // default special tokens
+            vocab.special_bos_id = -1;
+            vocab.special_eos_id = -1;
+            vocab.special_unk_id = -1;
+            vocab.special_sep_id = -1;
+            vocab.special_pad_id = -1;
         } else {
             throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
         }
@@ -6050,6 +6337,12 @@ static void llm_load_vocab(
             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
             vocab.tokenizer_add_bos = false;
             vocab.tokenizer_add_eos = true;
+        } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) {
+            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            vocab.tokenizer_add_space_prefix = false;
+            vocab.tokenizer_clean_spaces = false;
+            vocab.tokenizer_add_bos = false;
+            vocab.tokenizer_add_eos = false;
         } else {
             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
         }
@@ -6077,6 +6370,7 @@ static void llm_load_vocab(
 
     const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
 
+    vocab.n_vocab = n_vocab;
     vocab.id_to_token.resize(n_vocab);
 
     for (uint32_t i = 0; i < n_vocab; i++) {
@@ -6154,6 +6448,10 @@ static void llm_load_vocab(
         }
     } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
         vocab.linefeed_id = vocab.special_pad_id;
+    } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) {
+        const std::vector<int> ids = llama_tokenize_internal(vocab, "\n", false);
+        GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
+        vocab.linefeed_id = ids[0];
     } else {
         const std::vector<int> ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A
         GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
@@ -6211,18 +6509,23 @@ static void llm_load_vocab(
         //       for now, we apply this workaround to find the EOT token based on its text
         if (vocab.special_eot_id == -1) {
             for (const auto & t : vocab.token_to_id) {
-                if (
+                if (false
                         // TODO: gemma "<end_of_turn>" is exported as a normal token, so the following check does not work
                         //       need to fix convert script
                         //vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
-                        (t.first == "<|eot_id|>" ||
-                         t.first == "<|im_end|>" ||
-                         t.first == "<|end|>" ||
-                         t.first == "<end_of_turn>" ||
-                         t.first == "<|endoftext|>"
-                        )
+                        || t.first == "<|eot_id|>"
+                        || t.first == "<|im_end|>"
+                        || t.first == "<|end|>"
+                        || t.first == "<end_of_turn>"
+                        || t.first == "<|endoftext|>"
+                        || t.first == "<EOT>"
                    ) {
                     vocab.special_eot_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
                     break;
                 }
             }
@@ -6236,8 +6539,51 @@ static void llm_load_vocab(
             const auto & t = vocab.token_to_id.find("<|eom_id|>");
             if (t != vocab.token_to_id.end()) {
                 vocab.special_eom_id = t->second;
+                if ((vocab.id_to_token[t->second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                        __func__, t->first.c_str());
+                    vocab.id_to_token[t->second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                }
+            }
+        }
+
+        // maintain a list of tokens that cause end-of-generation
+        // this is currently determined based on the token text, which is obviously not ideal
+        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
+        vocab.special_eog_ids.clear();
+        for (const auto & t : vocab.token_to_id) {
+            if (false
+                    || t.first == "<|eot_id|>"
+                    || t.first == "<|im_end|>"
+                    || t.first == "<|end|>"
+                    || t.first == "<end_of_turn>"
+                    || t.first == "<|endoftext|>"
+                    || t.first == "<|eom_id|>"
+                    || t.first == "<EOT>"
+               ) {
+                vocab.special_eog_ids.insert(t.second);
+                if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.first.c_str());
+                    vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                }
             }
         }
+
+        if (vocab.special_eos_id != -1 && vocab.special_eog_ids.count(vocab.special_eos_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_eos_id);
+            LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (vocab.special_eot_id != -1 && vocab.special_eog_ids.count(vocab.special_eot_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_eot_id);
+            LLAMA_LOG_WARN("%s: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (vocab.special_eom_id != -1 && vocab.special_eog_ids.count(vocab.special_eom_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_eom_id);
+            LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
     }
 
     // build special tokens cache
@@ -6441,6 +6787,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token        = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); }
     if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token        = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
     if (vocab.special_eot_id    != -1) { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,    vocab.id_to_token[vocab.special_eot_id].text.c_str() );    }
+    if (vocab.special_eom_id    != -1) { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, vocab.special_eom_id,    vocab.id_to_token[vocab.special_eom_id].text.c_str() );    }
+
+    for (const auto & id : vocab.special_eog_ids) {
+        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, vocab.id_to_token[id].text.c_str() );
+    }
 
     LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
 
@@ -6458,6 +6809,12 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: n_ff_exp         = %d\n",     __func__, hparams.n_ff_exp);
         LLAMA_LOG_INFO("%s: n_ff_shexp       = %d\n",     __func__, hparams.n_ff_shexp);
     }
+
+    if (model.arch == LLM_ARCH_GRANITE) {
+        LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
+        LLAMA_LOG_INFO("%s: f_residual_scale  = %f\n", __func__, hparams.f_residual_scale);
+        LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
+    }
 }
 
 // Returns false if cancelled by progress_callback
@@ -6471,8 +6828,6 @@ static bool llm_load_tensors(
         bool use_mlock,
         llama_progress_callback progress_callback,
         void * progress_callback_user_data) {
-    model.t_start_us = ggml_time_us();
-
     auto & hparams = model.hparams;
 
     model.split_mode   = split_mode;
@@ -6628,6 +6983,7 @@ static bool llm_load_tensors(
             case LLM_ARCH_LLAMA:
             case LLM_ARCH_REFACT:
             case LLM_ARCH_MINICPM:
+            case LLM_ARCH_GRANITE:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
 
@@ -6708,6 +7064,54 @@ static bool llm_load_tensors(
                         }
                     }
                 } break;
+            case LLM_ARCH_MINICPM3:
+                {
+                    const int64_t n_embd_head_qk_rope = hparams.n_rot;
+                    const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+
+                    const int64_t q_lora_rank  = hparams.n_lora_q;
+                    const int64_t kv_lora_rank = hparams.n_lora_kv;
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        // if output is NULL, init from the input tok embed
+                        if (model.output == NULL) {
+                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+                        }
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_q_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank});
+
+                        layer.attn_kv_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank});
+
+                        layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank});
+                        layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k});
+
+                        layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)});
+                        layer.wkv_b     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)});
+                        layer.wo        = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {              n_head * (                      n_embd_head_v), n_embd});
+
+                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+
+                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+
+                        layer.rope_long  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight"), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                        layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                    }
+                } break;
             case LLM_ARCH_GROK:
                 {
                     if (n_expert == 0) {
@@ -7745,6 +8149,44 @@ static bool llm_load_tensors(
                         layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
                     }
                 } break;
+            case LLM_ARCH_OLMOE:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd});
+                        layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd});
+
+                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+
+                        layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
+
+                        GGML_ASSERT(n_expert      > 0);
+                        GGML_ASSERT(n_expert_used > 0);
+
+                        // MoE branch
+                        layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert});
+                        layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff,   n_embd, n_expert});
+                        layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff,   n_expert});
+                    }
+                } break;
             case LLM_ARCH_OPENELM:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -7944,23 +8386,23 @@ static bool llm_load_tensors(
                         layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd});
 
                         layer.wq       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wq_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "scale",  i), {1});
+                        layer.wq_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
                         layer.wk       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wk_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "scale",  i), {1});
+                        layer.wk_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
                         layer.wv       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "scale",  i), {1});
+                        layer.wv_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
                         layer.wo       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.wo_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "scale",  i), {1});
+                        layer.wo_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                         layer.ffn_norm     = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM,     "weight", i), {n_embd});
                         layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff});
 
                         layer.ffn_gate       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
-                        layer.ffn_gate_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "scale",  i), {1});
+                        layer.ffn_gate_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
                         layer.ffn_down       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "scale",  i), {1});
+                        layer.ffn_down_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
                         layer.ffn_up         = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
-                        layer.ffn_up_scale   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "scale",  i), {1});
+                        layer.ffn_up_scale   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
                     }
                 } break;
             case LLM_ARCH_T5:
@@ -8200,6 +8642,68 @@ static bool llm_load_tensors(
                         layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
                     }
                 } break;
+            case LLM_ARCH_RWKV6:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // Block 0, LN0
+                    model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
+                    model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd});
+
+                    // output
+                    model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                    model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
+                    model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+
+                    const int time_mix_extra_dim = hparams.time_mix_extra_dim;
+                    const int time_decay_extra_dim = hparams.time_decay_extra_dim;
+                    const int head_size = hparams.wkv_head_size;
+                    const int attn_hidden_size = n_embd;
+                    const int ffn_size = hparams.n_ff_arr[0];
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+
+                        layer.attn_norm_2   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd});
+                        layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd});
+
+                        layer.time_mix_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5});
+                        layer.time_mix_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5});
+
+                        layer.time_mix_lerp_x = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1});
+                        layer.time_mix_lerp_w = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1});
+                        layer.time_mix_lerp_k = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1});
+                        layer.time_mix_lerp_v = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1});
+                        layer.time_mix_lerp_r = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1});
+                        layer.time_mix_lerp_g = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1});
+
+                        layer.time_mix_first = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size});
+                        layer.time_mix_decay = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd});
+                        layer.time_mix_decay_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim});
+                        layer.time_mix_decay_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size});
+                        layer.time_mix_key = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd});
+                        layer.time_mix_value = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd});
+                        layer.time_mix_receptance = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd});
+                        layer.time_mix_gate = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd});
+
+                        layer.time_mix_ln = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd});
+                        layer.time_mix_ln_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd});
+                        layer.time_mix_output = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size});
+
+                        layer.channel_mix_lerp_k = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1});
+                        layer.channel_mix_lerp_r = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1});
+
+                        layer.channel_mix_key = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size});
+                        layer.channel_mix_value = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd});
+                        layer.channel_mix_receptance = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd});
+                    }
+
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -8341,14 +8845,13 @@ static bool llm_load_tensors(
         }
     }
 
-    // loading time will be recalculate after the first eval, so
-    // we take page faults deferred by mmap() into consideration
-    model.t_load_us = ggml_time_us() - model.t_start_us;
     return true;
 }
 
 // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
 static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
+    model.t_start_us = ggml_time_us();
+
     try {
         llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
 
@@ -8410,6 +8913,10 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
         return -1;
     }
 
+    // loading time will be recalculate after the first eval, so
+    // we take page faults deferred by mmap() into consideration
+    model.t_load_us = ggml_time_us() - model.t_start_us;
+
     return 0;
 }
 
@@ -8460,6 +8967,11 @@ static struct ggml_tensor * llm_build_inp_embd(
         ggml_set_input(lctx.inp_embd);
     }
 
+    // For Granite architecture
+    if (hparams.f_embedding_scale != 0.0f) {
+        inpL = ggml_scale(ctx, inpL, hparams.f_embedding_scale);
+    }
+
     cb(inpL, "inp_embd", -1);
 
     return inpL;
@@ -8484,8 +8996,7 @@ static void llm_build_kv_store(
 
     GGML_ASSERT(kv.size == n_ctx);
 
-    struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
-            (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
+    struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa)*kv_head);
     cb(k_cache_view, "k_cache_view", il);
 
     // note: storing RoPE-ed version of K in the KV cache
@@ -8496,8 +9007,7 @@ static void llm_build_kv_store(
     struct ggml_tensor * v_cache_view = nullptr;
 
     if (cparams.flash_attn) {
-        v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa,
-                (kv_head)*ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa));
+        v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa)*kv_head);
     } else {
         // note: the V cache is transposed when not using flash attention
         v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
@@ -8984,8 +9494,7 @@ static struct ggml_tensor * llm_build_kv(
 
     struct ggml_tensor * cur;
 
-    cur  = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b,
-            q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
+    cur  = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
     cb(cur, "kqv_out", il);
 
     return cur;
@@ -9013,7 +9522,7 @@ static struct ggml_tensor * llm_build_copy_mask_state(
     // FIXME: zero-out NANs?
     states = ggml_mul(ctx, states, state_mask);
 
-    // copy states which won't be changed further (between n_seqs and n_rs)
+    // copy states which won't be changed further (between n_seqs and n_kv)
     ggml_build_forward_expand(graph,
         ggml_cpy(ctx,
             ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*ggml_element_size(states)),
@@ -9159,7 +9668,172 @@ static struct ggml_tensor * llm_build_mamba(
     return cur;
 }
 
-struct llm_build_context {
+static struct ggml_tensor * llm_build_rwkv6_time_mix(
+        struct llama_context & lctx,
+        struct ggml_context * ctx,
+        const struct llama_layer * layer,
+        struct ggml_tensor * cur,
+        struct ggml_tensor * x_prev,
+        struct ggml_tensor ** wkv_state) {
+    size_t n_embd       = cur->ne[0];
+    size_t n_seq_tokens = cur->ne[1];
+    size_t n_seqs       = cur->ne[2];
+
+    size_t head_size  = layer->time_mix_first->ne[0];
+    size_t head_count = layer->time_mix_first->ne[1];
+
+    size_t n_tokens = n_seqs * n_seq_tokens;
+
+    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
+
+    sx  = ggml_reshape_2d(ctx, sx,  n_embd, n_tokens);
+    cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
+
+    struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_x), cur);
+
+    xxx = ggml_reshape_4d(
+        ctx,
+        ggml_tanh(
+            ctx,
+            ggml_mul_mat(ctx, layer->time_mix_w1, xxx)
+        ),
+        layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
+    );
+
+    xxx = ggml_cont(ctx, ggml_permute(ctx, xxx, 0, 1, 3, 2));
+
+    xxx = ggml_mul_mat(
+        ctx,
+        ggml_reshape_4d(
+            ctx,
+            layer->time_mix_w2,
+            layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5
+        ),
+        xxx
+    );
+
+    struct ggml_tensor *mw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+    struct ggml_tensor *mk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+    struct ggml_tensor *mv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+    struct ggml_tensor *mr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+    struct ggml_tensor *mg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+
+    struct ggml_tensor * xw = ggml_add(
+        ctx,
+        ggml_mul(
+            ctx,
+            ggml_add(ctx, mw, layer->time_mix_lerp_w),
+            sx
+        ),
+        cur
+    );
+
+    struct ggml_tensor * xk = ggml_add(
+        ctx,
+        ggml_mul(
+            ctx,
+            ggml_add(ctx, mk, layer->time_mix_lerp_k),
+            sx
+        ),
+        cur
+    );
+
+    struct ggml_tensor * xv = ggml_add(
+        ctx,
+        ggml_mul(
+            ctx,
+            ggml_add(ctx, mv, layer->time_mix_lerp_v),
+            sx
+        ),
+        cur
+    );
+
+    struct ggml_tensor * xr = ggml_add(
+        ctx,
+        ggml_mul(
+            ctx,
+            ggml_add(ctx, mr, layer->time_mix_lerp_r),
+            sx
+        ),
+        cur
+    );
+
+    struct ggml_tensor * xg = ggml_add(
+        ctx,
+        ggml_mul(
+            ctx,
+            ggml_add(ctx, mg, layer->time_mix_lerp_g),
+            sx
+        ),
+        cur
+    );
+
+    struct ggml_tensor * r = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, 1,         head_count, n_tokens);
+    struct ggml_tensor * k = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key,        xk), 1,         head_size, head_count, n_tokens);
+    struct ggml_tensor * v = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value,      xv), head_size, 1,         head_count, n_tokens);
+    struct ggml_tensor * g = ggml_silu(
+        ctx,
+        llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg)
+    );
+
+    struct ggml_tensor * w = ggml_mul_mat(
+        ctx,
+        layer->time_mix_decay_w2,
+        ggml_tanh(
+            ctx,
+            ggml_mul_mat(ctx, layer->time_mix_decay_w1, xw)
+        )
+    );
+
+    w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embd));
+    w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
+    w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
+
+    k = ggml_transpose(ctx, k);
+    v = ggml_transpose(ctx, v);
+    r = ggml_transpose(ctx, r);
+
+    struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
+    cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
+    *wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
+
+    // group norm with head_count groups
+    cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
+    cur = ggml_norm(ctx, cur, 64e-5f);
+
+    // Convert back to regular vectors.
+    cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
+    cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
+
+    cur = ggml_mul(ctx, cur, g);
+    cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
+
+    return ggml_reshape_3d(ctx, cur, n_embd, n_seq_tokens, n_seqs);
+}
+
+static struct ggml_tensor * llm_build_rwkv6_channel_mix(
+        struct llama_context & lctx,
+        struct ggml_context * ctx,
+        const struct llama_layer * layer,
+        struct ggml_tensor * cur,
+        struct ggml_tensor * x_prev) {
+    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
+    struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_k), cur);
+    struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_r), cur);
+
+    struct ggml_tensor * r = ggml_sigmoid(ctx, llm_build_lora_mm(lctx, ctx, layer->channel_mix_receptance, xr));
+    struct ggml_tensor * k = ggml_sqr(
+        ctx,
+        ggml_relu(
+            ctx,
+            llm_build_lora_mm(lctx, ctx, layer->channel_mix_key, xk)
+        )
+    );
+
+    return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k));
+}
+
+struct llm_build_context {
     const llama_model    & model;
           llama_context  & lctx;
     const llama_hparams  & hparams;
@@ -9299,17 +9973,36 @@ struct llm_build_context {
             const int64_t n_head_kv = hparams.n_head_kv(il);
             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
             struct ggml_tensor * rope_factors = build_rope_factors(il);
-            struct ggml_tensor * tmp =
+            struct ggml_tensor * k =
+                ggml_view_3d(ctx0, kv_self.k_l[il],
+                    n_embd_head_k, n_head_kv, n_ctx,
+                    ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
+                    ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+                    0);
+
+            struct ggml_tensor * tmp;
+            if (ggml_is_quantized(k->type)) {
+                // dequantize to f32 -> RoPE -> quantize back
+                tmp = ggml_cast(ctx0, k, GGML_TYPE_F32);
+                cb(tmp, "K_f32", il);
+                for (auto * backend : lctx.backends) {
+                    // Figure out which backend KV cache belongs to
+                    if (ggml_backend_supports_buft(backend, lctx.model.buft_layer[il].buft)) {
+                        ggml_backend_sched_set_tensor_backend(lctx.sched, tmp, backend);
+                        break;
+                    }
+                }
+                tmp = ggml_rope_ext_inplace(ctx0, tmp,
+                        lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(tmp, "K_shifted_f32", il);
+                tmp = ggml_cpy(ctx0, tmp, k);
+            } else {
                 // we rotate only the first n_rot dimensions
-                ggml_rope_ext_inplace(ctx0,
-                        ggml_view_3d(ctx0, kv_self.k_l[il],
-                            n_embd_head_k, n_head_kv, n_ctx,
-                            ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
-                            ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                            0),
+                tmp = ggml_rope_ext_inplace(ctx0, k,
                         lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
-
+            }
             cb(tmp, "K_shifted", il);
             ggml_build_forward_expand(gf, tmp);
         }
@@ -9467,8 +10160,8 @@ struct llm_build_context {
     struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
         // find result_norm tensor for input
         struct ggml_tensor * inp = nullptr;
-        for (int i = gf->n_nodes - 1; i >= 0; --i) {
-            inp = gf->nodes[i];
+        for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
+            inp = ggml_graph_node(gf, i);
             if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
                 break;
             } else {
@@ -9576,6 +10269,7 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
@@ -9628,7 +10322,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -9639,6 +10333,11 @@ struct llm_build_context {
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
 
+            // For Granite architecture
+            if (hparams.f_residual_scale) {
+                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            }
+
             struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
 
@@ -9675,6 +10374,11 @@ struct llm_build_context {
                 cb(cur, "ffn_moe_out", il);
             }
 
+            // For Granite architecture
+            if (hparams.f_residual_scale) {
+                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            }
+
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
@@ -9694,6 +10398,12 @@ struct llm_build_context {
 
         // lm_head
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        // For Granite architecture
+        if (hparams.f_logit_scale) {
+            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
+        }
+
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -12427,6 +13137,215 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_minicpm3() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        //TODO: if the model varies, these parameters need to be read from the model
+        const int64_t n_embd_base = 256;
+        const float scale_embd  = 12.0f;
+        const float scale_depth = 1.4f;
+        const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k));
+
+        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+        const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+        const uint32_t kv_lora_rank = hparams.n_lora_kv;
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // scale the input embeddings
+        inpL = ggml_scale(ctx0, inpL, scale_embd);
+        cb(inpL, "inp_scaled", -1);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            struct ggml_tensor * rope_factors = build_rope_factors(il);
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self_attention
+            {
+                struct ggml_tensor * q = NULL;
+                // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
+                q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
+                cb(q, "q", il);
+
+                q = llm_build_norm(ctx0, q, hparams,
+                        model.layers[il].attn_q_a_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(q, "q", il);
+
+                // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
+                q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
+                cb(q, "q", il);
+
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        0);
+                cb(q_nope, "q_nope", il);
+
+                // and {n_head * n_embd_head_qk_rope, n_tokens}
+                struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        ggml_row_size(q->type, n_embd_head_qk_nope));
+                cb(q_pe, "q_pe", il);
+
+                // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
+                struct ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
+                cb(kv_pe_compresseed, "kv_pe_compresseed", il);
+
+                // split into {kv_lora_rank, n_tokens}
+                struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
+                        kv_pe_compresseed->nb[1],
+                        0);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // and {n_embd_head_qk_rope, n_tokens}
+                struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
+                        kv_pe_compresseed->nb[1],
+                        kv_pe_compresseed->nb[1],
+                        ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
+                cb(k_pe, "k_pe", il);
+
+                kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+                kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
+                        model.layers[il].attn_kv_a_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
+                struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
+                cb(kv, "kv", il);
+
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
+                        ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        0);
+                cb(k_nope, "k_nope", il);
+
+                // and {n_head * n_embd_head_v, n_tokens}
+                struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope)));
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_cont(ctx0, v_states);
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
+                    ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
+                    0);
+                cb(v_states, "v_states", il);
+
+                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+                q_pe = ggml_rope_ext(
+                    ctx0, q_pe, inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(q_pe, "q_pe", il);
+
+                // shared RoPE key
+                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+                k_pe = ggml_rope_ext(
+                    ctx0, k_pe, inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(k_pe, "k_pe", il);
+
+                struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
+                cb(q_states, "q_states", il);
+
+                struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
+                cb(k_states, "k_states", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // scale_res - scale the hidden states for residual connection
+            const float scale_res = scale_depth/sqrtf(float(n_layer));
+            cur = ggml_scale(ctx0, cur, scale_res);
+            cb(cur, "hidden_scaled", il);
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            // scale the hidden states for residual connection
+            cur = ggml_scale(ctx0, cur, scale_res);
+            cb(cur, "hidden_scaled_ffn", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head scaling
+        const float scale_lmhead = float(n_embd_base)/float(n_embd);
+        cur = ggml_scale(ctx0, cur, scale_lmhead);
+        cb(cur, "lmhead_scaling", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_gemma() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
@@ -13098,8 +14017,136 @@ struct llm_build_context {
             cb(cur, "ffn_out", il);
 
             cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
+            cb(cur, "ffn_out", il);
+
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                NULL, NULL,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    // based on the build_qwen2moe() function, changes:
+    //   * removed shared experts
+    //   * removed bias
+    //   * added q, k norm
+    struct ggml_cgraph * build_olmoe() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self_attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur_rope", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur_rope", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // MoE branch
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_moe_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, false,
+                    false, 0.0,
+                    cb, il);
+            cb(cur, "ffn_moe_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
             cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
@@ -13110,8 +14157,8 @@ struct llm_build_context {
         cur = inpL;
 
         cur = llm_build_norm(ctx0, cur, hparams,
-                NULL, NULL,
-                LLM_NORM, cb, -1);
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
 
         // lm_head
@@ -13779,7 +14826,9 @@ struct llm_build_context {
             {
                 // compute Q and K and RoPE them
                 struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
+                if (model.layers[il].wq_scale) {
+                    Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
+                }
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
@@ -13788,7 +14837,9 @@ struct llm_build_context {
 
                 // B1.K
                 struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
+                if (model.layers[il].wk_scale) {
+                    Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
+                }
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
@@ -13797,7 +14848,9 @@ struct llm_build_context {
 
                 // B1.V
                 struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
+                if (model.layers[il].wv_scale) {
+                    Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
+                }
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -13828,7 +14881,9 @@ struct llm_build_context {
                 cb(cur, "attn_sub_norm", il);
 
                 cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
-                cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
+                if (model.layers[il].wo_scale) {
+                    cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
+                }
                 if (model.layers[il].bo) {
                     cur = ggml_add(ctx0, cur, model.layers[il].bo);
                 }
@@ -13865,7 +14920,9 @@ struct llm_build_context {
             cb(cur, "ffn_sub_norm", il);
 
             cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_down, cur);
-            cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
+            if (model.layers[il].ffn_down_scale) {
+                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
+            }
             cb(cur, "ffn_down", il);
 
             cur = ggml_add(ctx0, cur, ffn_inp);
@@ -14680,6 +15737,117 @@ struct llm_build_context {
 
         return gf;
     }
+
+    ggml_cgraph * build_rwkv6() {
+        ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        // Token shift state dimensions should be 2 * n_emb
+        GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2);
+
+        const int64_t n_seqs = batch.n_seqs;
+        const int64_t n_seq_tokens = batch.n_seq_tokens;
+        const int64_t n_tokens = batch.n_tokens;
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(batch.equal_seqs);
+        GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+        struct ggml_tensor * state_copy = build_inp_s_copy();
+        struct ggml_tensor * state_mask = build_inp_s_mask();
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
+
+        for (int il = 0; il < n_layer; ++il) {
+            const llama_layer * layer = &model.layers[il];
+
+            // (ab)using the KV cache to store the states
+            struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
+                    gf, kv_self.k_l[il], state_copy, state_mask,
+                    hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
+            struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
+                    gf, kv_self.v_l[il], state_copy, state_mask,
+                    hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);
+
+            cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
+            token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 2, n_seqs);
+
+            struct ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
+            struct ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
+
+            struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, il);
+            struct ggml_tensor * x_prev = ggml_concat(
+                ctx0,
+                att_shift,
+                ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
+                1
+            );
+
+            cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states));
+            ggml_build_forward_expand(gf, cur);
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    wkv_states,
+                    ggml_view_1d(
+                        ctx0,
+                        kv_self.v_l[il],
+                        hparams.n_embd_v_s() * n_seqs,
+                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
+                    )
+                )
+            );
+
+            struct ggml_tensor * x_norm_ffn = llm_build_norm(ctx0, cur, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, il);
+            x_prev = ggml_concat(
+                ctx0,
+                ffn_shift,
+                ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
+                1
+            );
+            cur = ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, x_norm_ffn, x_prev));
+            ggml_build_forward_expand(gf, cur);
+
+            struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
+            struct ggml_tensor * last_norm_ffn = ggml_view_3d(ctx0, x_norm_ffn, n_embd, 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_ffn));
+
+            token_shift = ggml_concat(ctx0, last_norm_att, last_norm_ffn, 1);
+
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * 2, 0),
+                    ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
+                )
+            );
+
+            if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) {
+                cur = ggml_scale(ctx0, cur, 0.5F);
+            }
+
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+
+        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        cb(cur, "result_output", -1);
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -14761,6 +15929,7 @@ static struct ggml_cgraph * llama_build_graph(
 
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
+        case LLM_ARCH_GRANITE:
             {
                 result = llm.build_llama();
             } break;
@@ -14846,6 +16015,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_minicpm();
             } break;
+        case LLM_ARCH_MINICPM3:
+            {
+                result = llm.build_minicpm3();
+            } break;
         case LLM_ARCH_GEMMA:
             {
                 result = llm.build_gemma();
@@ -14878,6 +16051,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_olmo();
             } break;
+        case LLM_ARCH_OLMOE:
+            {
+                result = llm.build_olmoe();
+            } break;
         case LLM_ARCH_OPENELM:
             {
                 result = llm.build_openelm();
@@ -14926,6 +16103,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_exaone();
             } break;
+        case LLM_ARCH_RWKV6:
+            {
+                result = llm.build_rwkv6();
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -15285,7 +16466,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
 
             // clear unused states
             for (int i = 0; i < n_kv; ++i) {
-                uint32_t        cell_id = i + kv_self.head;
+                const uint32_t  cell_id = i + kv_self.head;
                 llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
 
                 data[i] = (float) (kv_cell.src >= 0);
@@ -15494,9 +16675,10 @@ static void llama_output_reorder(struct llama_context * ctx) {
 }
 
 static void llama_graph_compute(
-        llama_context & lctx,
-          ggml_cgraph * gf,
-                  int   n_threads) {
+          llama_context & lctx,
+            ggml_cgraph * gf,
+                    int   n_threads,
+        ggml_threadpool * threadpool) {
 #ifdef GGML_USE_METAL
     if (ggml_backend_is_metal(lctx.backend_metal)) {
         ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads);
@@ -15505,6 +16687,7 @@ static void llama_graph_compute(
 
     if (lctx.backend_cpu != nullptr) {
         ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
+        ggml_backend_cpu_set_threadpool(lctx.backend_cpu, threadpool);
         ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data);
     }
 #ifdef GGML_USE_BLAS
@@ -15535,7 +16718,7 @@ static int llama_decode_internal(
     const uint32_t n_tokens_all = batch_all.n_tokens;
 
     if (n_tokens_all == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
@@ -15545,6 +16728,15 @@ static int llama_decode_internal(
 
     GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
 
+    if (batch_all.token) {
+        for (uint32_t i = 0; i < n_tokens_all; ++i) {
+            if (batch_all.token[i] < 0 || (uint32_t)batch_all.token[i] >= model.vocab.n_vocab) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch_all.token[i]);
+                return -1;
+            }
+        }
+    }
+
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
 
     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
@@ -15625,6 +16817,8 @@ static int llama_decode_internal(
         }
 
         int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+        ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+
         GGML_ASSERT(n_threads > 0);
 
         // non-causal masks do not use the KV cache
@@ -15659,8 +16853,8 @@ static int llama_decode_internal(
         ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
 
         // the output is always the last tensor in the graph
-        struct ggml_tensor * res  = gf->nodes[gf->n_nodes - 1];
-        struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
+        struct ggml_tensor * res  = ggml_graph_node(gf, -1);
+        struct ggml_tensor * embd = ggml_graph_node(gf, -2);
 
         if (lctx.n_outputs == 0) {
             // no output
@@ -15669,9 +16863,9 @@ static int llama_decode_internal(
         } else if (cparams.embeddings) {
             res  = nullptr; // do not extract logits for embedding case
             embd = nullptr;
-            for (int i = gf->n_nodes - 1; i >= 0; --i) {
-                if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
-                    embd = gf->nodes[i];
+            for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
+                if (strcmp(ggml_graph_node(gf, i)->name, "result_embd_pooled") == 0) {
+                    embd = ggml_graph_node(gf, i);
                     break;
                 }
             }
@@ -15686,7 +16880,7 @@ static int llama_decode_internal(
 
         llama_set_inputs(lctx, ubatch);
 
-        llama_graph_compute(lctx, gf, n_threads);
+        llama_graph_compute(lctx, gf, n_threads, threadpool);
 
         // update the kv ring buffer
         {
@@ -15825,7 +17019,7 @@ static int llama_encode_internal(
     const uint32_t n_tokens = batch.n_tokens;
 
     if (n_tokens == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
@@ -15835,6 +17029,15 @@ static int llama_encode_internal(
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
+    if (batch.token) {
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
+                return -1;
+            }
+        }
+    }
+
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
     GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
 
@@ -15863,7 +17066,9 @@ static int llama_encode_internal(
     lctx.inp_embd_enc = NULL;
     lctx.n_outputs = n_tokens;
 
-    const int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+
     GGML_ASSERT(n_threads > 0);
 
     ggml_backend_sched_reset(lctx.sched);
@@ -15877,15 +17082,15 @@ static int llama_encode_internal(
     // there are two cases here
     if (llama_model_has_decoder(&lctx.model)) {
         // first case is an encoder-decoder T5 model where embeddings are passed to decoder
-        embd = gf->nodes[gf->n_nodes - 1];
+        embd = ggml_graph_node(gf, -1);
         GGML_ASSERT(strcmp(embd->name, "result_norm") == 0 && "missing result_output tensor");
     } else {
         // second case is an encoder-only T5 model
         if (cparams.embeddings) {
             // only output embeddings if required
-            embd = gf->nodes[gf->n_nodes - 1];
+            embd = ggml_graph_node(gf, -1);
             if (strcmp(embd->name, "result_embd_pooled") != 0) {
-                embd = gf->nodes[gf->n_nodes - 2];
+                embd = ggml_graph_node(gf, -2);
             }
             GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
         }
@@ -15895,7 +17100,7 @@ static int llama_encode_internal(
 
     llama_set_inputs(lctx, ubatch);
 
-    llama_graph_compute(lctx, gf, n_threads);
+    llama_graph_compute(lctx, gf, n_threads, threadpool);
 
     // extract embeddings
     if (embd) {
@@ -16177,7 +17382,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
 
     ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
 
-    llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
+    llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
 #endif
 
     //const int64_t t_end = ggml_time_us();
@@ -16203,7 +17408,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
 
             llama_set_k_shift(lctx);
 
-            llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
+            llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
 
             need_reserve = true;
         }
@@ -16414,6 +17619,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
                      new_type == GGML_TYPE_Q4_0_8_8) {
                 new_type = GGML_TYPE_Q4_0;
             }
+            else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) {
+                new_type = GGML_TYPE_Q4_K;
+            }
         }
     } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
                ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M    || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
@@ -16613,6 +17821,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
     }
     if (convert_incompatible_tensor) {
         switch (new_type) {
+            case GGML_TYPE_TQ1_0:
+            case GGML_TYPE_TQ2_0:  new_type = GGML_TYPE_Q4_0; break;  // TODO: use a symmetric type instead
             case GGML_TYPE_IQ2_XXS:
             case GGML_TYPE_IQ2_XS:
             case GGML_TYPE_IQ2_S:
@@ -16718,6 +17928,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_Q5_K_S:
         case LLAMA_FTYPE_MOSTLY_Q5_K_M:  default_type = GGML_TYPE_Q5_K;    break;
         case LLAMA_FTYPE_MOSTLY_Q6_K:    default_type = GGML_TYPE_Q6_K;    break;
+        case LLAMA_FTYPE_MOSTLY_TQ1_0:   default_type = GGML_TYPE_TQ1_0;   break;
+        case LLAMA_FTYPE_MOSTLY_TQ2_0:   default_type = GGML_TYPE_TQ2_0;   break;
         case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break;
         case LLAMA_FTYPE_MOSTLY_IQ2_XS:  default_type = GGML_TYPE_IQ2_XS;  break;
         case LLAMA_FTYPE_MOSTLY_IQ2_S:   default_type = GGML_TYPE_IQ2_XS;  break;
@@ -16964,6 +18176,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         // NOTE: can't use LLM_TN here because the layer number is not known
         quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
 
+        // do not quantize RWKV's time_mix_first tensors
+        quantize &= name.find("time_mix_first.weight") == std::string::npos;
+        quantize &= name.find("time_mix_w1.weight") == std::string::npos;
+        quantize &= name.find("time_mix_w2.weight") == std::string::npos;
+        quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
+        quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
+
         // do not quantize relative position bias (T5)
         quantize &= name.find("attn_rel_b.weight") == std::string::npos;
 
@@ -17347,7 +18566,6 @@ struct llama_model_params llama_model_default_params() {
 
 struct llama_context_params llama_context_default_params() {
     struct llama_context_params result = {
-        /*.seed                        =*/ LLAMA_DEFAULT_SEED,
         /*.n_ctx                       =*/ 512,
         /*.n_batch                     =*/ 2048,
         /*.n_ubatch                    =*/ 512,
@@ -17373,6 +18591,7 @@ struct llama_context_params llama_context_default_params() {
         /*.embeddings                  =*/ false,
         /*.offload_kqv                 =*/ true,
         /*.flash_attn                  =*/ false,
+        /*.no_perf                     =*/ true,
         /*.abort_callback              =*/ nullptr,
         /*.abort_callback_data         =*/ nullptr,
     };
@@ -17380,6 +18599,14 @@ struct llama_context_params llama_context_default_params() {
     return result;
 }
 
+struct llama_sampler_chain_params llama_sampler_chain_default_params() {
+    struct llama_sampler_chain_params result = {
+        /*.no_perf                     =*/ true,
+    };
+
+    return result;
+}
+
 struct llama_model_quantize_params llama_model_quantize_default_params() {
     struct llama_model_quantize_params result = {
         /*.nthread                     =*/ 0,
@@ -17451,6 +18678,19 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
     }
 }
 
+void llama_attach_threadpool(
+             struct llama_context * ctx,
+        ggml_threadpool_t   threadpool,
+        ggml_threadpool_t   threadpool_batch) {
+    ctx->threadpool       = threadpool;
+    ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
+}
+
+void llama_detach_threadpool(struct llama_context * ctx) {
+    ctx->threadpool       = nullptr;
+    ctx->threadpool_batch = nullptr;
+}
+
 void llama_backend_free(void) {
     ggml_quantize_free();
 }
@@ -17474,9 +18714,9 @@ struct llama_model * llama_load_model_from_file(
             unsigned percentage = (unsigned) (100 * progress);
             while (percentage > *cur_percentage_p) {
                 *cur_percentage_p = percentage;
-                LLAMA_LOG_INFO(".");
+                LLAMA_LOG_CONT(".");
                 if (percentage >= 100) {
-                    LLAMA_LOG_INFO("\n");
+                    LLAMA_LOG_CONT("\n");
                 }
             }
             return true;
@@ -17562,6 +18802,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.embeddings       = params.embeddings;
     cparams.offload_kqv      = params.offload_kqv;
     cparams.flash_attn       = params.flash_attn;
+    cparams.no_perf          = params.no_perf;
     cparams.pooling_type     = params.pooling_type;
 
     cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
@@ -17620,10 +18861,6 @@ struct llama_context * llama_new_context_with_model(
         cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
     }
 
-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
-    }
-
     LLAMA_LOG_INFO("%s: n_ctx      = %u\n",     __func__, cparams.n_ctx);
     LLAMA_LOG_INFO("%s: n_batch    = %u\n",     __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch   = %u\n",     __func__, cparams.n_ubatch);
@@ -17634,10 +18871,10 @@ struct llama_context * llama_new_context_with_model(
     ctx->abort_callback      = params.abort_callback;
     ctx->abort_callback_data = params.abort_callback_data;
 
-    ctx->sampling.rng = std::mt19937(params.seed);
-    ctx->logits_all   = params.logits_all;
+    ctx->logits_all = params.logits_all;
+
     // build worst-case graph for encoder if a model contains encoder
-    ctx->is_encoding  = llama_model_has_encoder(model);
+    ctx->is_encoding = llama_model_has_encoder(model);
 
     uint32_t kv_size = cparams.n_ctx;
     ggml_type type_k = params.type_k;
@@ -17657,6 +18894,20 @@ struct llama_context * llama_new_context_with_model(
 
     if (!hparams.vocab_only) {
         // initialize backends
+#if defined(GGML_USE_RPC)
+        if (model->n_gpu_layers > 0) {
+            for (const auto & endpoint : model->rpc_servers) {
+                ggml_backend_t backend = ggml_backend_rpc_init(endpoint.c_str());
+                if (backend == nullptr) {
+                    LLAMA_LOG_ERROR("%s: failed to initialize RPC to '%s'\n", __func__, endpoint.c_str());
+                    llama_free(ctx);
+                    return nullptr;
+                }
+                ctx->backends.push_back(backend);
+            }
+        }
+#endif
+
 #if defined(GGML_USE_METAL)
         if (model->n_gpu_layers > 0) {
             ctx->backend_metal = ggml_backend_metal_init();
@@ -17781,19 +19032,6 @@ struct llama_context * llama_new_context_with_model(
         }
 #endif
 
-#if defined(GGML_USE_RPC)
-        if (model->n_gpu_layers > 0) {
-            for (const auto & endpoint : model->rpc_servers) {
-                ggml_backend_t backend = ggml_backend_rpc_init(endpoint.c_str());
-                if (backend == nullptr) {
-                    LLAMA_LOG_ERROR("%s: failed to initialize RPC to '%s'\n", __func__, endpoint.c_str());
-                    llama_free(ctx);
-                    return nullptr;
-                }
-                ctx->backends.push_back(backend);
-            }
-        }
-#endif
         ctx->backend_cpu = ggml_backend_cpu_init();
         if (ctx->backend_cpu == nullptr) {
             LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
@@ -17902,7 +19140,7 @@ struct llama_context * llama_new_context_with_model(
 
             // note: the number of splits during measure is higher than during inference due to the kv shift
             int n_splits = ggml_backend_sched_get_n_splits(ctx->sched);
-            LLAMA_LOG_INFO("%s: graph nodes  = %d\n", __func__, gf->n_nodes);
+            LLAMA_LOG_INFO("%s: graph nodes  = %d\n", __func__, ggml_graph_n_nodes(gf));
             LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits);
         }
     }
@@ -17914,14 +19152,6 @@ void llama_free(struct llama_context * ctx) {
     delete ctx;
 }
 
-const struct llama_model * llama_get_model(const struct llama_context * ctx) {
-    return &ctx->model;
-}
-
-const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx) {
-    return &ctx->model.vocab;
-}
-
 uint32_t llama_n_ctx(const struct llama_context * ctx) {
     return ctx->cparams.n_ctx;
 }
@@ -17942,6 +19172,34 @@ enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
     return model->vocab.type;
 }
 
+int32_t llama_n_vocab(const struct llama_model * model) {
+    return model->hparams.n_vocab;
+}
+
+int32_t llama_n_ctx_train(const struct llama_model * model) {
+    return model->hparams.n_ctx_train;
+}
+
+int32_t llama_n_embd(const struct llama_model * model) {
+    return model->hparams.n_embd;
+}
+
+int32_t llama_n_layer(const struct llama_model * model) {
+    return model->hparams.n_layer;
+}
+
+int32_t llama_n_head(const struct llama_model * model) {
+    return model->hparams.n_head();
+}
+
+const struct llama_model * llama_get_model(const struct llama_context * ctx) {
+    return &ctx->model;
+}
+
+enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
+    return ctx->cparams.pooling_type;
+}
+
 enum llama_rope_type llama_rope_type(const struct llama_model * model) {
     switch (model->arch) {
         // these models do not use RoPE
@@ -17955,6 +19213,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_T5:
         case LLM_ARCH_T5ENCODER:
         case LLM_ARCH_JAIS:
+        case LLM_ARCH_RWKV6:
             return LLAMA_ROPE_TYPE_NONE;
 
         // use what we call a normal RoPE, operating on pairs of consecutive head values
@@ -17971,6 +19230,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_ARCTIC:
         case LLM_ARCH_DEEPSEEK2:
         case LLM_ARCH_CHATGLM:
+        case LLM_ARCH_GRANITE:
             return LLAMA_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2
@@ -17984,6 +19244,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_QWEN:
         case LLM_ARCH_QWEN2:
         case LLM_ARCH_QWEN2MOE:
+        case LLM_ARCH_OLMOE:
         case LLM_ARCH_PHI2:
         case LLM_ARCH_PHI3:
         case LLM_ARCH_GEMMA:
@@ -17994,6 +19255,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_CODESHELL:
         case LLM_ARCH_NEMOTRON:
         case LLM_ARCH_EXAONE:
+        case LLM_ARCH_MINICPM3:
             return LLAMA_ROPE_TYPE_NEOX;
 
         // all model arches should be listed explicitly here
@@ -18004,26 +19266,6 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
     return LLAMA_ROPE_TYPE_NONE;
 }
 
-enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
-    return ctx->cparams.pooling_type;
-}
-
-int32_t llama_n_vocab(const struct llama_model * model) {
-    return model->hparams.n_vocab;
-}
-
-int32_t llama_n_ctx_train(const struct llama_model * model) {
-    return model->hparams.n_ctx_train;
-}
-
-int32_t llama_n_embd(const struct llama_model * model) {
-    return model->hparams.n_embd;
-}
-
-int32_t llama_n_layer(const struct llama_model * model) {
-    return model->hparams.n_layer;
-}
-
 float llama_rope_freq_scale_train(const struct llama_model * model) {
     return model->hparams.rope_freq_scale_train;
 }
@@ -18123,6 +19365,7 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) {
 bool llama_model_is_recurrent(const struct llama_model * model) {
     switch (model->arch) {
         case LLM_ARCH_MAMBA:  return true;
+        case LLM_ARCH_RWKV6:  return true;
         default:              return false;
     }
 }
@@ -18439,14 +19682,14 @@ struct llama_data_write {
         // TODO: add more model-specific info which should prevent loading the session file if not identical
     }
 
-    void write_rng(const std::mt19937 & rng) {
-        std::ostringstream rng_ss;
-        rng_ss << rng;
+    //void write_rng(const std::mt19937 & rng) {
+    //    std::ostringstream rng_ss;
+    //    rng_ss << rng;
 
-        const std::string & rng_str = rng_ss.str();
+    //    const std::string & rng_str = rng_ss.str();
 
-        write_string(rng_str);
-    }
+    //    write_string(rng_str);
+    //}
 
     void write_output_ids(struct llama_context * ctx) {
         llama_output_reorder(ctx);
@@ -18666,17 +19909,17 @@ struct llama_data_read {
         // TODO: add more info which needs to be identical but which is not verified otherwise
     }
 
-    void read_rng(std::mt19937 & rng) {
-        std::string rng_str;
-        read_string(rng_str);
+    //void read_rng(std::mt19937 & rng) {
+    //    std::string rng_str;
+    //    read_string(rng_str);
 
-        std::istringstream rng_ss(rng_str);
-        rng_ss >> rng;
+    //    std::istringstream rng_ss(rng_str);
+    //    rng_ss >> rng;
 
-        if (rng_ss.fail()) {
-            throw std::runtime_error("failed to load RNG state");
-        }
-    }
+    //    if (rng_ss.fail()) {
+    //        throw std::runtime_error("failed to load RNG state");
+    //    }
+    //}
 
     void read_output_ids(struct llama_context * ctx) {
         std::vector<int32_t> output_pos;
@@ -19106,8 +20349,6 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da
 
     data_ctx.write_model_info(ctx);
 
-    data_ctx.write_rng(ctx->sampling.rng);
-
     // copy outputs
     data_ctx.write_output_ids(ctx);
     data_ctx.write_logits(ctx);
@@ -19145,9 +20386,6 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da
 
     data_ctx.read_model_info(ctx);
 
-    // set rng
-    data_ctx.read_rng(ctx->sampling.rng);
-
     // set outputs
     data_ctx.read_output_ids(ctx);
     data_ctx.read_logits(ctx);
@@ -19367,16 +20605,16 @@ size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepa
     }
 }
 
-void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) {
+void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
     ctx->cparams.n_threads       = n_threads;
     ctx->cparams.n_threads_batch = n_threads_batch;
 }
 
-uint32_t llama_n_threads(struct llama_context * ctx) {
+int32_t llama_n_threads(struct llama_context * ctx) {
     return ctx->cparams.n_threads;
 }
 
-uint32_t llama_n_threads_batch(struct llama_context * ctx) {
+int32_t llama_n_threads_batch(struct llama_context * ctx) {
     return ctx->cparams.n_threads_batch;
 }
 
@@ -19490,10 +20728,14 @@ void llama_synchronize(struct llama_context * ctx) {
 
     // add the evaluation to the stats
     if (ctx->n_queued_tokens == 1) {
-        ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us;
+        if (!ctx->cparams.no_perf) {
+            ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us;
+        }
         ctx->n_eval++;
     } else if (ctx->n_queued_tokens > 1) {
-        ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us;
+        if (!ctx->cparams.no_perf) {
+            ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us;
+        }
         ctx->n_p_eval += ctx->n_queued_tokens;
     }
 
@@ -19550,8 +20792,9 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
         LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
         GGML_ABORT("fatal error");
-#endif
+#else
         return nullptr;
+#endif
     }
 }
 
@@ -19599,8 +20842,9 @@ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
         LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
         GGML_ABORT("fatal error");
-#endif
+#else
         return nullptr;
+#endif
     }
 }
 
@@ -20034,124 +21278,18 @@ int32_t llama_chat_apply_template(
 }
 
 //
-// grammar
+// sampling
 //
 
-struct llama_grammar * llama_grammar_init(
-        const llama_grammar_element ** rules,
-        size_t    n_rules,
-        size_t    start_rule_index) {
-    return llama_grammar_init_impl(rules, n_rules, start_rule_index);
-}
-
-void llama_grammar_free(struct llama_grammar * grammar) {
-    llama_grammar_free_impl(grammar);
-}
-
-struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
-    return llama_grammar_copy_impl(grammar);
-}
-
-void llama_grammar_sample(
-      const struct llama_grammar * grammar,
-      const struct llama_context * ctx,
-          llama_token_data_array * candidates) {
-    llama_grammar_sample_impl(grammar, &ctx->model.vocab, &ctx->sampling, candidates);
-}
-
-void llama_sample_grammar(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-      const struct llama_grammar * grammar) {
-    llama_grammar_sample(grammar, ctx, candidates);
-}
-
-void llama_grammar_accept_token(
-            struct llama_grammar * grammar,
-            struct llama_context * ctx,
-                     llama_token   token) {
-    llama_grammar_accept_token_impl(grammar, &ctx->model.vocab, &ctx->sampling, token);
+// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
+struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
+    return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
 }
 
 //
-// sampling
+// model split
 //
 
-void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
-    llama_set_rng_seed_impl(&ctx->sampling, seed);
-}
-
-void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
-    llama_sample_softmax_impl(ctx ? &ctx->sampling : nullptr, candidates);
-}
-
-void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
-    llama_sample_top_k_impl(ctx ? &ctx->sampling : nullptr, candidates, k, min_keep);
-}
-
-void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    llama_sample_top_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
-}
-
-void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
-}
-
-void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
-    llama_sample_tail_free_impl(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep);
-}
-
-void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    llama_sample_typical_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
-}
-
-void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
-    llama_sample_entropy_impl(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val);
-}
-
-void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
-    llama_sample_temp_impl(ctx ? &ctx->sampling : nullptr, candidates_p, temp);
-}
-
-void llama_sample_repetition_penalties(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-               const llama_token * last_tokens,
-                          size_t   penalty_last_n,
-                           float   penalty_repeat,
-                           float   penalty_freq,
-                           float   penalty_present) {
-    llama_sample_repetition_penalties_impl(ctx ? &ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
-}
-
-void llama_sample_apply_guidance(
-          struct llama_context * ctx,
-                         float * logits,
-                         float * logits_guidance,
-                         float   scale) {
-    llama_sample_apply_guidance_impl(&ctx->sampling, logits, logits_guidance, scale);
-}
-
-llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
-    return llama_sample_token_mirostat_impl(&ctx->sampling, candidates, tau, eta, m, mu);
-}
-
-llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
-    return llama_sample_token_mirostat_v2_impl(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu);
-}
-
-llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
-    return llama_sample_token_greedy_impl(ctx ? &ctx->sampling : nullptr, candidates);
-}
-
-llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
-    return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, rng);
-}
-
-llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
-    return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng);
-}
-
 int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
     static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
     if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
@@ -20176,45 +21314,6 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int
     return 0;
 }
 
-struct llama_timings llama_get_timings(struct llama_context * ctx) {
-    struct llama_timings result = {
-        /*.t_start_ms  =*/ 1e-3 * ctx->t_start_us,
-        /*.t_end_ms    =*/ 1.00 * ggml_time_ms(),
-        /*.t_load_ms   =*/ 1e-3 * ctx->t_load_us,
-        /*.t_sample_ms =*/ 1e-3 * ctx->sampling.t_sample_us,
-        /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us,
-        /*.t_eval_ms   =*/ 1e-3 * ctx->t_eval_us,
-
-        /*.n_sample =*/ std::max(1, ctx->sampling.n_sample),
-        /*.n_p_eval =*/ std::max(0, ctx->n_p_eval),
-        /*.n_eval   =*/ std::max(1, ctx->n_eval),
-    };
-
-    return result;
-}
-
-void llama_print_timings(struct llama_context * ctx) {
-    const llama_timings timings = llama_get_timings(ctx);
-
-    LLAMA_LOG_INFO("\n");
-    LLAMA_LOG_INFO("%s:        load time = %10.2f ms\n", __func__, timings.t_load_ms);
-    LLAMA_LOG_INFO("%s:      sample time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, timings.t_sample_ms, timings.n_sample, timings.t_sample_ms / timings.n_sample, 1e3 / timings.t_sample_ms * timings.n_sample);
-    LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
-    LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, timings.t_eval_ms, timings.n_eval, timings.t_eval_ms / timings.n_eval, 1e3 / timings.t_eval_ms * timings.n_eval);
-    LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval));
-}
-
-void llama_reset_timings(struct llama_context * ctx) {
-    ctx->t_start_us  = ggml_time_us();
-    ctx->t_eval_us   = ctx->n_eval   = 0;
-    ctx->t_p_eval_us = ctx->n_p_eval = 0;
-
-    ctx->sampling.reset_timings();
-}
-
 const char * llama_print_system_info(void) {
     static std::string s;
 
@@ -20232,6 +21331,7 @@ const char * llama_print_system_info(void) {
     s += "ARM_FMA = "     + std::to_string(ggml_cpu_has_arm_fma())     + " | ";
     s += "F16C = "        + std::to_string(ggml_cpu_has_f16c())        + " | ";
     s += "FP16_VA = "     + std::to_string(ggml_cpu_has_fp16_va())     + " | ";
+    s += "RISCV_VECT = "  + std::to_string(ggml_cpu_has_riscv_v())     + " | ";
     s += "WASM_SIMD = "   + std::to_string(ggml_cpu_has_wasm_simd())   + " | ";
     s += "BLAS = "        + std::to_string(ggml_cpu_has_blas())        + " | ";
     s += "SSE3 = "        + std::to_string(ggml_cpu_has_sse3())        + " | ";
@@ -20243,7 +21343,43 @@ const char * llama_print_system_info(void) {
     return s.c_str();
 }
 
-void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
+struct llama_perf_context_data llama_perf_context(const struct llama_context * ctx) {
+    struct llama_perf_context_data data = {};
+
+    if (ctx == nullptr) {
+        return data;
+    }
+
+    data.t_start_ms  = 1e-3 * ctx->t_start_us;
+    data.t_load_ms   = 1e-3 * ctx->t_load_us;
+    data.t_p_eval_ms = 1e-3 * ctx->t_p_eval_us;
+    data.t_eval_ms   = 1e-3 * ctx->t_eval_us;
+    data.n_p_eval    = std::max(1, ctx->n_p_eval);
+    data.n_eval      = std::max(1, ctx->n_eval);
+
+    return data;
+}
+
+void llama_perf_context_print(const struct llama_context * ctx) {
+    const auto data = llama_perf_context(ctx);
+
+    const double t_end_ms = 1e-3 * ggml_time_us();
+
+    LLAMA_LOG_INFO("%s:        load time = %10.2f ms\n", __func__, data.t_load_ms);
+    LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
+    LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
+    LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
+}
+
+void llama_perf_context_reset(struct llama_context * ctx) {
+    ctx->t_start_us  = ggml_time_us();
+    ctx->t_eval_us   = ctx->n_eval = 0;
+    ctx->t_p_eval_us = ctx->n_p_eval = 0;
+}
+
+void llama_perf_dump_yaml(FILE * stream, const llama_context * ctx) {
     fprintf(stream, "\n");
     fprintf(stream, "###########\n");
     fprintf(stream, "# Timings #\n");
@@ -20254,21 +21390,15 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
             1.0e-3 * ctx->t_eval_us / ctx->n_eval);
     fprintf(stream, "mst_p_eval: %.2f  # ms / token during prompt processing\n",
             1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval);
-    fprintf(stream, "mst_sample: %.2f  # ms / token during sampling\n",
-            1.0e-3 * ctx->sampling.t_sample_us / ctx->sampling.n_sample);
     fprintf(stream, "n_eval: %d  # number of tokens generated (excluding the first one)\n", ctx->n_eval);
     fprintf(stream, "n_p_eval: %d  # number of tokens processed in batches at the beginning\n", ctx->n_p_eval);
-    fprintf(stream, "n_sample: %d  # number of sampled tokens\n", ctx->sampling.n_sample);
     fprintf(stream, "t_eval_us: %" PRId64 "  # total microseconds spent generating tokens\n", ctx->t_eval_us);
     fprintf(stream, "t_load_us: %" PRId64 "  # total microseconds spent loading the model\n", ctx->t_load_us);
     fprintf(stream, "t_p_eval_us: %" PRId64 "  # total microseconds spent prompt processing\n", ctx->t_p_eval_us);
-    fprintf(stream, "t_sample_us: %" PRId64 "  # total microseconds spent sampling\n", ctx->sampling.t_sample_us);
     fprintf(stream, "ts_eval: %.2f  # tokens / second during generation\n",
             1.0e6 * ctx->n_eval / ctx->t_eval_us);
     fprintf(stream, "ts_p_eval: %.2f  # tokens / second during prompt processing\n",
             1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us);
-    fprintf(stream, "ts_sample: %.2f  # tokens / second during sampling\n",
-            1.0e6 * ctx->sampling.n_sample / ctx->sampling.t_sample_us);
 }
 
 // For internal test use
@@ -20298,8 +21428,8 @@ static void llama_log_internal_v(ggml_log_level level, const char * format, va_l
     if (len < 128) {
         g_state.log_callback(level, buffer, g_state.log_callback_user_data);
     } else {
-        char* buffer2 = new char[len+1];
-        vsnprintf(buffer2, len+1, format, args_copy);
+        char * buffer2 = new char[len + 1];
+        vsnprintf(buffer2, len + 1, format, args_copy);
         buffer2[len] = 0;
         g_state.log_callback(level, buffer2, g_state.log_callback_user_data);
         delete[] buffer2;
index 6cca6320b347d860246d30d34f174b0f5a0affa6..132937a0700e7c9bd69f517f05a7141a04a25a55 100644 (file)
 
 #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
 
+// TODO: use everywhere in the implementation
+#define LLAMA_TOKEN_NULL -1
+
 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 8
+#define LLAMA_SESSION_VERSION 9
 
 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
 #define LLAMA_STATE_SEQ_VERSION 2
@@ -53,8 +56,10 @@ extern "C" {
     // TODO: show sample usage
     //
 
+    // struct llama_vocab; // TODO: add in the future
     struct llama_model;
     struct llama_context;
+    struct llama_sampler;
 
     typedef int32_t llama_pos;
     typedef int32_t llama_token;
@@ -66,6 +71,7 @@ extern "C" {
         LLAMA_VOCAB_TYPE_BPE  = 2, // GPT-2 tokenizer based on byte-level BPE
         LLAMA_VOCAB_TYPE_WPM  = 3, // BERT tokenizer based on WordPiece
         LLAMA_VOCAB_TYPE_UGM  = 4, // T5 tokenizer based on Unigram
+        LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
     };
 
     // pre-tokenization types
@@ -166,6 +172,8 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_Q4_0_4_4      = 33, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_0_4_8      = 34, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_0_8_8      = 35, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_TQ1_0         = 36, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_TQ2_0         = 37, // except 1d tensors
 
         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };
@@ -198,6 +206,7 @@ extern "C" {
         LLAMA_SPLIT_MODE_ROW     = 2, // split rows across GPUs
     };
 
+    // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
     typedef struct llama_token_data {
         llama_token id; // token id
         float logit;    // log-odds of the token
@@ -205,8 +214,10 @@ extern "C" {
     } llama_token_data;
 
     typedef struct llama_token_data_array {
+        // TODO: consider SoA
         llama_token_data * data;
         size_t size;
+        int64_t selected; // this is the index in the data array (i.e. not the token id)
         bool sorted;
     } llama_token_data_array;
 
@@ -267,9 +278,9 @@ extern "C" {
         enum llama_split_mode split_mode; // how to split the model across multiple GPUs
 
         // main_gpu interpretation depends on split_mode:
-        // LLAMA_SPLIT_NONE: the GPU that is used for the entire model
-        // LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results
-        // LLAMA_SPLIT_LAYER: ignored
+        // LLAMA_SPLIT_MODE_NONE: the GPU that is used for the entire model
+        // LLAMA_SPLIT_MODE_ROW: the GPU that is used for small tensors and intermediate results
+        // LLAMA_SPLIT_MODE_LAYER: ignored
         int32_t main_gpu;
 
         // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
@@ -299,13 +310,12 @@ extern "C" {
     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
     //       https://github.com/ggerganov/llama.cpp/pull/7544
     struct llama_context_params {
-        uint32_t seed;              // RNG seed, -1 for random
         uint32_t n_ctx;             // text context, 0 = from model
         uint32_t n_batch;           // logical maximum batch size that can be submitted to llama_decode
         uint32_t n_ubatch;          // physical maximum batch size
         uint32_t n_seq_max;         // max number of sequences (i.e. distinct states for recurrent models)
-        uint32_t n_threads;         // number of threads to use for generation
-        uint32_t n_threads_batch;   // number of threads to use for batch processing
+        int32_t  n_threads;         // number of threads to use for generation
+        int32_t  n_threads_batch;   // number of threads to use for batch processing
 
         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
         enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
@@ -327,11 +337,13 @@ extern "C" {
         enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
         enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
 
-        // Keep the booleans together to avoid misalignment during copy-by-value.
+        // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
+        // TODO: move at the end of the struct
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
         bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
+        bool no_perf;     // whether to measure performance timings
 
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
@@ -355,56 +367,14 @@ extern "C" {
         void * kv_overrides;                 // pointer to vector containing overrides
     } llama_model_quantize_params;
 
-    // grammar types
-    struct llama_grammar;
-
-    // grammar element type
-    enum llama_gretype {
-        // end of rule definition
-        LLAMA_GRETYPE_END            = 0,
-
-        // start of alternate definition for rule
-        LLAMA_GRETYPE_ALT            = 1,
-
-        // non-terminal element: reference to rule
-        LLAMA_GRETYPE_RULE_REF       = 2,
-
-        // terminal element: character (code point)
-        LLAMA_GRETYPE_CHAR           = 3,
-
-        // inverse char(s) ([^a], [^a-b] [^abc])
-        LLAMA_GRETYPE_CHAR_NOT       = 4,
-
-        // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
-        // be an inclusive range ([a-z])
-        LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
-
-        // modifies a preceding LLAMA_GRETYPE_CHAR or
-        // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
-        LLAMA_GRETYPE_CHAR_ALT       = 6,
-
-        // any character (.)
-        LLAMA_GRETYPE_CHAR_ANY       = 7,
-    };
-
-    typedef struct llama_grammar_element {
-        enum llama_gretype type;
-        uint32_t           value; // Unicode code point or rule ID
-    } llama_grammar_element;
-
-    // performance timing information
-    struct llama_timings {
-        double t_start_ms;
-        double t_end_ms;
-        double t_load_ms;
-        double t_sample_ms;
-        double t_p_eval_ms;
-        double t_eval_ms;
+    typedef struct llama_logit_bias {
+        llama_token token;
+        float bias;
+    } llama_logit_bias;
 
-        int32_t n_sample;
-        int32_t n_p_eval;
-        int32_t n_eval;
-    };
+    typedef struct llama_sampler_chain_params {
+        bool no_perf; // whether to measure performance timings
+    } llama_sampler_chain_params;
 
     // used in chat template
     typedef struct llama_chat_message {
@@ -416,8 +386,10 @@ extern "C" {
     struct llama_lora_adapter;
 
     // Helpers for getting default parameters
-    LLAMA_API struct llama_model_params llama_model_default_params(void);
-    LLAMA_API struct llama_context_params llama_context_default_params(void);
+    // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
+    LLAMA_API struct llama_model_params          llama_model_default_params(void);
+    LLAMA_API struct llama_context_params        llama_context_default_params(void);
+    LLAMA_API struct llama_sampler_chain_params  llama_sampler_chain_default_params(void);
     LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
 
     // Initialize the llama + ggml backend
@@ -428,15 +400,23 @@ extern "C" {
     //optional:
     LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa);
 
+    // Optional: an auto threadpool gets created in ggml if not passed explicitly
+    LLAMA_API void llama_attach_threadpool(
+               struct   llama_context * ctx,
+            ggml_threadpool_t   threadpool,
+            ggml_threadpool_t   threadpool_batch);
+    LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
+
     // Call once at the end of the program - currently only used for MPI
     LLAMA_API void llama_backend_free(void);
 
     LLAMA_API struct llama_model * llama_load_model_from_file(
                              const char * path_model,
-            struct llama_model_params     params);
+              struct llama_model_params   params);
 
     LLAMA_API void llama_free_model(struct llama_model * model);
 
+    // TODO: rename to llama_init_from_model
     LLAMA_API struct llama_context * llama_new_context_with_model(
                      struct llama_model * model,
             struct llama_context_params   params);
@@ -452,22 +432,22 @@ extern "C" {
     LLAMA_API bool llama_supports_mlock      (void);
     LLAMA_API bool llama_supports_gpu_offload(void);
 
-    LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
-
     LLAMA_API uint32_t llama_n_ctx      (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ubatch   (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max  (const struct llama_context * ctx);
 
-    LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
-
-    LLAMA_API enum llama_vocab_type   llama_vocab_type  (const struct llama_model * model);
-    LLAMA_API enum llama_rope_type    llama_rope_type   (const struct llama_model * model);
-
     LLAMA_API int32_t llama_n_vocab    (const struct llama_model * model);
     LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
     LLAMA_API int32_t llama_n_embd     (const struct llama_model * model);
     LLAMA_API int32_t llama_n_layer    (const struct llama_model * model);
+    LLAMA_API int32_t llama_n_head     (const struct llama_model * model);
+
+    LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
+
+    LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
+    LLAMA_API enum llama_vocab_type   llama_vocab_type  (const struct llama_model * model);
+    LLAMA_API enum llama_rope_type    llama_rope_type   (const struct llama_model * model);
 
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
@@ -696,7 +676,7 @@ extern "C" {
     //
 
     // Returns the *actual* size in bytes of the state
-    // (rng, logits, embedding and kv_cache)
+    // (logits, embedding and kv_cache)
     // Only use when saving the state, not when restoring it, otherwise the size may be too small.
     LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
     LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
@@ -837,13 +817,13 @@ extern "C" {
     // Set the number of threads used for decoding
     // n_threads is the number of threads used for generation (single token)
     // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
-    LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
+    LLAMA_API void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch);
 
     // Get the number of threads used for generation of a single token.
-    LLAMA_API uint32_t llama_n_threads(struct llama_context * ctx);
+    LLAMA_API int32_t llama_n_threads(struct llama_context * ctx);
 
     // Get the number of threads used for prompt and batch processing (multiple token).
-    LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);
+    LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
 
     // Set whether the model is in embeddings mode or not
     // If true, embeddings will be returned but logits will not
@@ -999,121 +979,114 @@ extern "C" {
                                int32_t   length);
 
     //
-    // Grammar
+    // Sampling API
+    //
+    // Sample usage:
+    //
+    //    // prepare the sampling chain at the start
+    //    auto sparams = llama_sampler_chain_default_params();
+    //
+    //    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+    //
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50));
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1));
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8));
+    //
+    //    // typically, the chain should end with a sampler such as "greedy", "dist" or "mirostat"
+    //    // this sampler will be responsible to select the actual token
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_dist(seed));
+    //
+    //    ...
+    //
+    //    // decoding loop:
+    //    while (...) {
+    //        ...
+    //
+    //        llama_decode(ctx, batch);
+    //
+    //        // sample from the logits of the last token in the batch
+    //        const llama_token id = llama_sampler_sample(smpl, ctx, -1);
+    //
+    //        // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.)
+    //        llama_sampler_accept(smpl, id);
+    //        ...
+    //    }
+    //
+    //    llama_sampler_free(smpl);
+    //
+    // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
+    // TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab
     //
 
-    /// Initialize a llama_grammar.
-    ///
-    /// @param rules The rule elements of the grammar to initialize.
-    /// @param n_rules The number of rules.
-    /// @param start_rule_index The index of the root rule (the starting point of the grammar).
-    /// @return The initialized llama_grammar or nullptr if initialization failed.
-    LLAMA_API struct llama_grammar * llama_grammar_init(
-            const llama_grammar_element ** rules,
-                                 size_t    n_rules,
-                                 size_t    start_rule_index);
-
-    LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
-
-    LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
-
-    /// @details Apply constraints from grammar
-    LLAMA_API void llama_grammar_sample(
-            const struct llama_grammar * grammar,
-            const struct llama_context * ctx,
-                llama_token_data_array * candidates);
-    LLAMA_API DEPRECATED(void llama_sample_grammar(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-      const struct llama_grammar * grammar),
-        "use llama_grammar_sample instead");
+    typedef void * llama_sampler_context_t;
 
-    /// @details Accepts the sampled token into the grammar
-    LLAMA_API void llama_grammar_accept_token(
-            struct llama_grammar * grammar,
-            struct llama_context * ctx,
-                     llama_token   token);
+    // user code can implement the interface below in order to create custom llama_sampler
+    struct llama_sampler_i {
+        const char *           (*name)  (const struct llama_sampler * smpl);                                 // can be NULL
+        void                   (*accept)(      struct llama_sampler * smpl, llama_token token);              // can be NULL
+        void                   (*apply) (      struct llama_sampler * smpl, llama_token_data_array * cur_p); // required
+        void                   (*reset) (      struct llama_sampler * smpl);                                 // can be NULL
+        struct llama_sampler * (*clone) (const struct llama_sampler * smpl);                                 // can be NULL if ctx is NULL
+        void                   (*free)  (      struct llama_sampler * smpl);                                 // can be NULL if ctx is NULL
 
-    //
-    // Sampling functions
-    //
+        // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
+        //void (*apply_ggml) (struct llama_sampler * smpl, ...);
+    };
+
+    struct llama_sampler {
+        struct llama_sampler_i  * iface;
+        llama_sampler_context_t   ctx;
+    };
 
-    // Sets the current rng seed.
-    LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
+    // mirror of llama_sampler_i:
+    LLAMA_API const char *           llama_sampler_name  (const struct llama_sampler * smpl);
+    LLAMA_API void                   llama_sampler_accept(      struct llama_sampler * smpl, llama_token token);
+    LLAMA_API void                   llama_sampler_apply (      struct llama_sampler * smpl, llama_token_data_array * cur_p);
+    LLAMA_API void                   llama_sampler_reset (      struct llama_sampler * smpl);
+    LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl);
+    // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
+    LLAMA_API void                   llama_sampler_free  (      struct llama_sampler * smpl);
 
-    /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
-    /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
-    LLAMA_API void llama_sample_repetition_penalties(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-               const llama_token * last_tokens,
-                          size_t   penalty_last_n,
-                           float   penalty_repeat,
-                           float   penalty_freq,
-                           float   penalty_present);
-
-    /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
-    /// @param logits Logits extracted from the original generation context.
-    /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
-    /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
-    LLAMA_API void llama_sample_apply_guidance(
-              struct llama_context * ctx,
-                             float * logits,
-                             float * logits_guidance,
-                             float   scale);
+    // llama_sampler_chain
+    // a type of llama_sampler that can chain multiple samplers one after another
+
+    LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params);
+
+    // important: takes ownership of the sampler object and will free it when llama_sampler_free is called
+    LLAMA_API void                   llama_sampler_chain_add(      struct llama_sampler * chain, struct llama_sampler * smpl);
+    LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
+    LLAMA_API int                    llama_sampler_chain_n  (const struct llama_sampler * chain);
+
+    // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
+    LLAMA_API struct llama_sampler * llama_sampler_chain_remove(   struct llama_sampler * chain, int32_t i);
+
+    // available samplers:
+
+    LLAMA_API struct llama_sampler * llama_sampler_init_greedy     (void);
+    LLAMA_API struct llama_sampler * llama_sampler_init_dist       (uint32_t seed);
 
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
-    LLAMA_API void llama_sample_softmax(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates);
+    /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
+    LLAMA_API struct llama_sampler * llama_sampler_init_softmax    (void);
 
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-    LLAMA_API void llama_sample_top_k(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                         int32_t   k,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_top_k      (int32_t k);
 
     /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-    LLAMA_API void llama_sample_top_p(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   p,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_top_p      (float   p, size_t min_keep);
 
     /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
-    LLAMA_API void llama_sample_min_p(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   p,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_min_p      (float   p, size_t min_keep);
 
     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
-    LLAMA_API void llama_sample_tail_free(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   z,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_tail_free  (float   z, size_t min_keep);
 
     /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
-    LLAMA_API void llama_sample_typical(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   p,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_typical    (float   p, size_t min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_temp       (float   t);
 
-    /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
-    LLAMA_API void llama_sample_entropy(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates_p,
-                           float   min_temp,
-                           float   max_temp,
-                           float   exponent_val);
-
-    LLAMA_API void llama_sample_temp(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   temp);
+    /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
+    LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext   (float   t, float   delta, float exponent);
 
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
@@ -1121,36 +1094,62 @@ extern "C" {
     /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
     /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
     /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
-    LLAMA_API llama_token llama_sample_token_mirostat(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   tau,
-                           float   eta,
-                         int32_t   m,
-                           float * mu);
+    LLAMA_API struct llama_sampler * llama_sampler_init_mirostat(
+                             int32_t   n_vocab,
+                            uint32_t   seed,
+                               float   tau,
+                               float   eta,
+                             int32_t   m);
 
     /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
     /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
     /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
-    LLAMA_API llama_token llama_sample_token_mirostat_v2(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   tau,
-                           float   eta,
-                           float * mu);
-
-    /// @details Selects the token with the highest probability.
-    ///          Does not compute the token probabilities. Use llama_sample_softmax() instead.
-    LLAMA_API llama_token llama_sample_token_greedy(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates);
-
-    /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
-    LLAMA_API llama_token llama_sample_token(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates);
+    LLAMA_API struct llama_sampler * llama_sampler_init_mirostat_v2(
+                            uint32_t   seed,
+                               float   tau,
+                               float   eta);
+
+    LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
+            const struct llama_model * model,
+                          const char * grammar_str,
+                          const char * grammar_root);
+
+    LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
+                             int32_t   n_vocab,         // llama_n_vocab()
+                         llama_token   special_eos_id,  // llama_token_eos()
+                         llama_token   linefeed_id,     // llama_token_nl()
+                             int32_t   penalty_last_n,  // last n tokens to penalize (0 = disable penalty, -1 = context size)
+                               float   penalty_repeat,  // 1.0 = disabled
+                               float   penalty_freq,    // 0.0 = disabled
+                               float   penalty_present, // 0.0 = disabled
+                                bool   penalize_nl,     // consider newlines as a repeatable token
+                                bool   ignore_eos);     // ignore the end-of-sequence token
+
+    LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
+                             int32_t   n_vocab,
+                             int32_t   n_logit_bias,
+              const llama_logit_bias * logit_bias);
+
+
+    // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
+    LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
+
+    /// @details Sample and accept a token from the idx-th output of the last evaluation
+    //
+    // Shorthand for:
+    //    const auto * logits = llama_get_logits_ith(ctx, idx);
+    //    llama_token_data_array cur_p = { ... init from logits ... };
+    //    llama_sampler_apply(smpl, &cur_p);
+    //    auto token = cur_p.data[cur_p.selected].id;
+    //    llama_sampler_accept(smpl, token);
+    //    return token;
+    // Returns the sampled token
+    LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);
+
+    // TODO: extend in the future
+    //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...);
 
     //
     // Model split
@@ -1166,12 +1165,6 @@ extern "C" {
     //  Returns the split_prefix length.
     LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count);
 
-    // Performance information
-    LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
-
-    LLAMA_API void llama_print_timings(struct llama_context * ctx);
-    LLAMA_API void llama_reset_timings(struct llama_context * ctx);
-
     // Print system information
     LLAMA_API const char * llama_print_system_info(void);
 
@@ -1179,65 +1172,41 @@ extern "C" {
     // If this is not called, or NULL is supplied, everything is output on stderr.
     LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
 
-    LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
-
-#ifdef __cplusplus
-}
-#endif
-
-// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
-#ifdef LLAMA_API_INTERNAL
-
-#include <random>
-#include <string>
-#include <vector>
-
-struct ggml_tensor;
-
-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
-    struct llama_context * ctx
-);
-
-struct llama_partial_utf8 {
-    uint32_t value;    // bit value so far (unshifted)
-    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
-};
-
-struct llama_grammar_candidate {
-    size_t               index;
-    const uint32_t     * code_points;
-    llama_partial_utf8   partial_utf8;
-};
+    //
+    // Performance utils
+    //
+    // NOTE: Used by llama.cpp examples, avoid using in third-party apps. Instead, do your own performance measurements.
+    //
 
-using llama_grammar_rule  = std::vector<      llama_grammar_element>;
-using llama_grammar_stack = std::vector<const llama_grammar_element *>;
+    struct llama_perf_context_data {
+        double t_start_ms;
+        double t_load_ms;
+        double t_p_eval_ms;
+        double t_eval_ms;
 
-using llama_grammar_rules      = std::vector<llama_grammar_rule>;
-using llama_grammar_stacks     = std::vector<llama_grammar_stack>;
-using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
+        int32_t n_p_eval;
+        int32_t n_eval;
+    };
 
-const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
-      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
+    struct llama_perf_sampler_data {
+        double t_sample_ms;
 
-void llama_grammar_accept(
-        const llama_grammar_rules  & rules,
-        const llama_grammar_stacks & stacks,
-        const uint32_t chr,
-              llama_grammar_stacks & new_stacks);
+        int32_t n_sample;
+    };
 
-std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
-        const llama_grammar_rules & rules,
-        const llama_grammar_stack & stack,
-        const llama_grammar_candidates & candidates);
+    LLAMA_API struct llama_perf_context_data llama_perf_context      (const struct llama_context * ctx);
+    LLAMA_API void                           llama_perf_context_print(const struct llama_context * ctx);
+    LLAMA_API void                           llama_perf_context_reset(      struct llama_context * ctx);
 
-std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
-        const std::string & src,
-        llama_partial_utf8 partial_start);
+    // NOTE: the following work only with samplers constructed via llama_sampler_chain_init
+    LLAMA_API struct llama_perf_sampler_data llama_perf_sampler      (const struct llama_sampler * chain);
+    LLAMA_API void                           llama_perf_sampler_print(const struct llama_sampler * chain);
+    LLAMA_API void                           llama_perf_sampler_reset(      struct llama_sampler * chain);
 
-// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
-// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
-llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
+    LLAMA_API void llama_perf_dump_yaml(FILE * stream, const struct llama_context * ctx);
 
-#endif // LLAMA_API_INTERNAL
+#ifdef __cplusplus
+}
+#endif
 
 #endif // LLAMA_H
index 32ed96464c52acad7680441d602397f5e5474b99..1b9de94d724034a4ef65f28b42990a455ba99f90 100644 (file)
@@ -314,7 +314,6 @@ int main(int argc, char ** argv) {
 
     // tune these to your liking
     lcparams.n_ctx      = 2048;
-    lcparams.seed       = 1;
     lcparams.n_threads  = params.n_threads;
     lcparams.flash_attn = params.flash_attn;
 
@@ -402,6 +401,26 @@ int main(int argc, char ** argv) {
 
     llama_batch batch = llama_batch_init(llama_n_ctx(ctx_llama), 0, 1);
 
+    // init sampler
+    const float top_k = 5;
+    const float top_p = 0.80f;
+    const float temp  = 0.30f;
+
+    const int seed = 0;
+
+    auto sparams = llama_sampler_chain_default_params();
+
+    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+    if (temp > 0.0f) {
+        llama_sampler_chain_add(smpl, llama_sampler_init_top_k(top_k));
+        llama_sampler_chain_add(smpl, llama_sampler_init_top_p(top_p, 1));
+        llama_sampler_chain_add(smpl, llama_sampler_init_temp (temp));
+        llama_sampler_chain_add(smpl, llama_sampler_init_dist (seed));
+    } else {
+        llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
+    }
+
     // init session
     std::string path_session = params.path_session;
     std::vector<llama_token> session_tokens;
@@ -700,54 +719,13 @@ int main(int argc, char ** argv) {
 
                     {
                         // out of user input, sample next token
-                        const float top_k          = 5;
-                        const float top_p          = 0.80f;
-                        const float temp           = 0.30f;
-                        const float repeat_penalty = 1.1764f;
-
-                        const int repeat_last_n    = 256;
 
                         if (!path_session.empty() && need_to_save_session) {
                             need_to_save_session = false;
                             llama_state_save_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.size());
                         }
 
-                        llama_token id = 0;
-
-                        {
-                            auto logits = llama_get_logits(ctx_llama);
-                            auto n_vocab = llama_n_vocab(model_llama);
-
-                            logits[llama_token_eos(model_llama)] = 0;
-
-                            std::vector<llama_token_data> candidates;
-                            candidates.reserve(n_vocab);
-                            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-                                candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-                            }
-
-                            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
-                            // apply repeat penalty
-                            const float nl_logit = logits[llama_token_nl(model_llama)];
-
-                            llama_sample_repetition_penalties(ctx_llama, &candidates_p,
-                                    embd_inp.data() + std::max(0, n_past - repeat_last_n),
-                                    repeat_last_n, repeat_penalty, 0.0, 0.0f);
-
-                            logits[llama_token_nl(model_llama)] = nl_logit;
-
-                            if (temp <= 0) {
-                                // Greedy sampling
-                                id = llama_sample_token_greedy(ctx_llama, &candidates_p);
-                            } else {
-                                // Temperature sampling
-                                llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1);
-                                llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1);
-                                llama_sample_temp (ctx_llama, &candidates_p, temp);
-                                id = llama_sample_token(ctx_llama, &candidates_p);
-                            }
-                        }
+                        const llama_token id = llama_sampler_sample(smpl, ctx_llama, -1);
 
                         if (id != llama_token_eos(model_llama)) {
                             // add it to the context
@@ -797,8 +775,14 @@ int main(int argc, char ** argv) {
     whisper_print_timings(ctx_wsp);
     whisper_free(ctx_wsp);
 
-    llama_print_timings(ctx_llama);
+    llama_perf_sampler_print(smpl);
+    llama_perf_context_print(ctx_llama);
+
+    llama_sampler_free(smpl);
+    llama_batch_free(batch);
     llama_free(ctx_llama);
 
+    llama_backend_free();
+
     return 0;
 }
index 46650bff06d15e9146888c332e4c79913ef09fcb..f4e941cd152611034ade3d0f7462f15d320e8ab2 100644 (file)
@@ -5,6 +5,7 @@
 #include "unicode.h"
 #include "unicode-data.h"
 
+#include <algorithm>
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
index 35874aa5abafebe2c2da5dcc70981f834e20830c..585a6fc02e8bb1d28987d697f50964b85de70e5e 100644 (file)
@@ -177,7 +177,7 @@ static bool ggml_graph_compute_helper(
                          int   n_threads,
          ggml_abort_callback   abort_callback,
                         void * abort_callback_data) {
-    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr);
 
     plan.abort_callback      = abort_callback;
     plan.abort_callback_data = abort_callback_data;
@@ -2894,7 +2894,7 @@ static bool whisper_decode_internal(
             ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
         }
 
-        logits = gf->nodes[gf->n_nodes - 1];
+        logits = ggml_graph_node(gf, -1);
 
         if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
             return false;