throw std::runtime_error("unexpected end of input");
}
+static std::pair<uint32_t, const char *> parse_token(const llama_vocab * vocab, const char * src) {
+ const char * pos = src;
+ if (*pos != '<') {
+ throw std::runtime_error(std::string("expecting '<' at ") + pos);
+ }
+ pos++;
+
+ // Parse <[id]>
+ if (*pos == '[') {
+ pos++;
+ const char * int_end = parse_int(pos);
+ uint32_t token_id = std::stoul(std::string(pos, int_end - pos));
+ pos = int_end;
+ if (*pos != ']') {
+ throw std::runtime_error(std::string("expecting ']' at ") + pos);
+ }
+ pos++;
+ if (*pos != '>') {
+ throw std::runtime_error(std::string("expecting '>' at ") + pos);
+ }
+ pos++;
+ return std::make_pair(token_id, pos);
+ }
+
+ if (vocab == nullptr) {
+ throw std::runtime_error(std::string("no vocab to parse token at ") + src);
+ }
+
+ // Parse <token> and tokenize to obtain the token id
+ while (*pos != 0 && *pos != '>') {
+ pos++;
+ }
+ if (*pos != '>') {
+ throw std::runtime_error(std::string("expecting '>' at ") + pos);
+ }
+ pos++;
+
+ llama_token tokens[2];
+ int32_t n_tokens = vocab->tokenize(src, static_cast<int32_t>(pos - src), tokens, 2, false, true);
+ if (n_tokens != 1) {
+ // must tokenize to exactly 1 token
+ throw std::runtime_error("invalid token '" + std::string(src, pos - src) + "'");
+ }
+ return std::make_pair(tokens[0], pos);
+}
+
static void print_grammar_char(FILE * file, uint32_t c) {
if (0x20 <= c && c <= 0x7f) {
fprintf(file, "%c", static_cast<char>(c));
case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
+ case LLAMA_GRETYPE_TOKEN: fprintf(file, "TOKEN"); break;
+ case LLAMA_GRETYPE_TOKEN_NOT: fprintf(file, "TOKEN_NOT"); break;
}
switch (elem.type) {
case LLAMA_GRETYPE_END:
print_grammar_char(file, elem.value);
fprintf(file, "\") ");
break;
+ case LLAMA_GRETYPE_TOKEN:
+ fprintf(file, "<[");
+ fprintf(file, "%u", elem.value);
+ fprintf(file, "]> ");
+ break;
+ case LLAMA_GRETYPE_TOKEN_NOT:
+ fprintf(file, "!");
+ fprintf(file, "<[");
+ fprintf(file, "%u", elem.value);
+ fprintf(file, "]> ");
+ break;
}
}
fprintf(file, "\n");
case LLAMA_GRETYPE_CHAR_ANY:
fprintf(file, ".");
break;
+ case LLAMA_GRETYPE_TOKEN:
+ fprintf(file, "<[");
+ fprintf(file, "%u", elem.value);
+ fprintf(file, "]> ");
+ break;
+ case LLAMA_GRETYPE_TOKEN_NOT:
+ fprintf(file, "!");
+ fprintf(file, "<[");
+ fprintf(file, "%u", elem.value);
+ fprintf(file, "]> ");
+ break;
}
if (is_char_element(elem)) {
switch (rule[i + 1].type) {
}
}
pos = parse_space(pos + 1, is_nested);
+ } else if (*pos == '<' || *pos == '!') { // token
+ auto type = LLAMA_GRETYPE_TOKEN;
+ if (*pos == '!') { // token inverse
+ type = LLAMA_GRETYPE_TOKEN_NOT;
+ pos++;
+ }
+ auto token_pair = parse_token(vocab, pos);
+ const char * token_end = token_pair.second;
+ last_sym_start = rule.size();
+ rule.push_back({type, token_pair.first});
+ pos = parse_space(token_end, is_nested);
} else if (is_word_char(*pos)) { // rule reference
const char * name_end = parse_name(pos);
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
return !is_positive_char;
}
+// returns true iff token matches the rule at pos (regular or inverse)
+// asserts that pos is pointing to a token element
+static bool llama_grammar_match_token(
+ const llama_grammar_element * pos,
+ const llama_token token) {
+ GGML_ASSERT(pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT);
+ if (pos->type == LLAMA_GRETYPE_TOKEN) {
+ return pos->value == static_cast<uint32_t>(token);
+ }
+ if (pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
+ return pos->value != static_cast<uint32_t>(token);
+ }
+ return false;
+}
+
// transforms a grammar pushdown stack into N possible stacks, all ending
// at a character range (terminal element)
static void llama_grammar_advance_stack(
case LLAMA_GRETYPE_CHAR:
case LLAMA_GRETYPE_CHAR_NOT:
case LLAMA_GRETYPE_CHAR_ANY:
+ case LLAMA_GRETYPE_TOKEN:
+ case LLAMA_GRETYPE_TOKEN_NOT:
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
// only add the stack if it's not a duplicate of one we already have
new_stacks.emplace_back(stack);
return grammar->stacks;
}
+static void llama_grammar_accept_chr(
+ struct llama_grammar & grammar,
+ const llama_grammar_stack & stack,
+ uint32_t chr,
+ llama_grammar_stacks & new_stacks) {
+ if (stack.empty()) {
+ return;
+ }
+
+ const llama_grammar_element * pos = stack.back();
+
+ // ignore if this turns into a token
+ if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
+ return;
+ }
+
+ auto match = llama_grammar_match_char(pos, chr);
+ if (match.first) {
+ llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
+ if (!llama_grammar_is_end_of_sequence(match.second)) {
+ new_stack.push_back(match.second);
+ }
+ llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks);
+ }
+}
+
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
llama_grammar_stacks stacks_new;
stacks_new.reserve(grammar->stacks.size());
for (const auto & stack : grammar->stacks) {
- if (stack.empty()) {
- continue;
- }
-
- auto match = llama_grammar_match_char(stack.back(), chr);
- if (match.first) {
- const llama_grammar_element * pos = match.second;
-
- // update top of stack to next element, if any
- llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
- if (!llama_grammar_is_end_of_sequence(pos)) {
- new_stack.push_back(pos);
- }
- llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new);
- }
+ llama_grammar_accept_chr(*grammar, stack, chr, stacks_new);
}
grammar->stacks = std::move(stacks_new);
const llama_grammar_element * stack_pos = stack.back();
+ // if the top of the stack is a token rule, then we only need to check the token id
+ if (stack_pos->type == LLAMA_GRETYPE_TOKEN || stack_pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
+ for (const auto & tok : candidates) {
+ if (*tok.code_points == 0) {
+ // reached the end of a token consumed by char rules, reject iff it ended
+ // in a partial response
+ if (tok.partial_utf8.n_remain != 0) {
+ rejects.push_back(tok);
+ }
+ } else if (!llama_grammar_match_token(stack_pos, tok.id)) {
+ rejects.push_back(tok);
+ }
+ }
+ return rejects;
+ }
+
llama_grammar_candidates next_candidates;
next_candidates.reserve(candidates.size());
rejects.push_back(tok);
}
} else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
- next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
+ next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8, tok.id });
} else {
rejects.push_back(tok);
}
auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
for (const auto & tok : next_rejects) {
- rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
+ rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8, tok.id });
}
return rejects;
vocab,
std::move(vec_rules),
std::move(stacks),
- /* .partial_utf8 = */ {},
- /* .lazy =*/ false,
- /* .awaiting_trigger = */ false,
- /* .trigger_buffer = */ "",
- /* .trigger_tokens = */ {},
- /* .trigger_patterns = */ {},
+ /* .partial_utf8 = */ {},
+ /* .lazy = */ false,
+ /* .awaiting_trigger = */ false,
+ /* .trigger_buffer = */ "",
+ /* .trigger_buffer_positions = */ {},
+ /* .trigger_tokens = */ {},
+ /* .trigger_patterns = */ {},
};
}
size_t num_trigger_patterns,
const llama_token * trigger_tokens,
size_t num_trigger_tokens) {
- llama_grammar_parser parser;
+ llama_grammar_parser parser(vocab);
// if there is a grammar, parse it
// rules will be empty (default) if there are parse errors
vocab,
std::move(vec_rules),
std::move(stacks),
- /* .partial_utf8 = */ {},
- /* .lazy = */ lazy,
- /* .awaiting_trigger = */ lazy,
- /* .trigger_buffer = */ "",
+ /* .partial_utf8 = */ {},
+ /* .lazy = */ lazy,
+ /* .awaiting_trigger = */ lazy,
+ /* .trigger_buffer = */ "",
+ /* .trigger_buffer_positions = */ {},
std::move(vec_trigger_tokens),
std::move(vec_trigger_patterns),
};
grammar.lazy,
grammar.awaiting_trigger,
grammar.trigger_buffer,
+ grammar.trigger_buffer_positions,
grammar.trigger_tokens,
grammar.trigger_patterns,
};
cur_p->data[i].logit = -INFINITY;
} else {
candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
- candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
+ candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second, id });
}
}
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
grammar.awaiting_trigger = false;
grammar.trigger_buffer.clear();
- llama_grammar_accept_str(grammar, piece);
+ llama_grammar_accept_token(grammar, token, piece);
LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
return;
} else {
+ auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size());
+ grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
grammar.trigger_buffer += piece;
std::smatch match;
if (start == std::string::npos) {
start = match.position(0);
}
+
+ // replay tokens that overlap with [start, end)
+ for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
+ auto [tok_start, tok_end] = tok_pos;
+ if (tok_end <= start) {
+ continue;
+ }
+
+ size_t piece_start = (tok_start < start) ? start : tok_start; // allow for partial token pieces
+ size_t piece_len = tok_end - piece_start;
+ auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len);
+ llama_grammar_accept_token(grammar, tok, tok_piece);
+ }
+
auto constrained_str = grammar.trigger_buffer.substr(start);
- // std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
grammar.trigger_buffer.clear();
- llama_grammar_accept_str(grammar, constrained_str);
+ grammar.trigger_buffer_positions.clear();
LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
return;
}
GGML_ABORT("fatal error");
}
- llama_grammar_accept_str(grammar, piece);
+ llama_grammar_accept_token(grammar, token, piece);
}
void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
}
}
+
+void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
+ // Note terminating 0 in decoded string
+ const auto decoded = decode_utf8(piece, grammar.partial_utf8);
+ const auto & code_points = decoded.first;
+
+ llama_grammar_stacks stacks_new;
+ stacks_new.reserve(grammar.stacks.size());
+
+ for (const auto & stack : grammar.stacks) {
+ if (stack.empty()) {
+ continue;
+ }
+
+ const llama_grammar_element * pos = stack.back();
+
+ if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
+ if (llama_grammar_match_token(pos, token)) {
+ llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
+ if (!llama_grammar_is_end_of_sequence(pos + 1)) {
+ new_stack.push_back(pos + 1);
+ }
+ llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new);
+ }
+ } else {
+ llama_grammar_stacks current_stacks = {stack};
+
+ for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+ llama_grammar_stacks next_stacks;
+
+ for (const auto & cur_stack : current_stacks) {
+ llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks);
+ }
+
+ current_stacks = std::move(next_stacks);
+ if (current_stacks.empty()) {
+ break;
+ }
+ }
+
+ for (auto & surviving_stack : current_stacks) {
+ if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) {
+ stacks_new.emplace_back(surviving_stack);
+ }
+ }
+ }
+ }
+
+ grammar.stacks = std::move(stacks_new);
+ grammar.partial_utf8 = decoded.second;
+
+ if (grammar.stacks.empty()) {
+ throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
+ }
+}
+
return grammar_fails;
}
+struct token_and_piece {
+ llama_token token;
+ std::string piece;
+};
+
+// token() encodes a 32-bit ID as 5 bytes: a 0xff marker followed by the ID in big-endian order.
+static std::string token(llama_token id) {
+ return std::string{
+ static_cast<char>(0xff),
+ static_cast<char>((id >> 24) & 0xff),
+ static_cast<char>((id >> 16) & 0xff),
+ static_cast<char>((id >> 8) & 0xff),
+ static_cast<char>(id & 0xff)
+ };
+}
+
+// parse_tokens() parses the token encodes above and UTF-8 text.
+static std::vector<token_and_piece> parse_tokens(const std::string & input) {
+ std::vector<token_and_piece> result;
+ result.reserve(input.size());
+ size_t offset = 0;
+ while (offset < input.size()) {
+ try {
+ if (static_cast<unsigned char>(input[offset]) == 0xff) {
+ if (offset + 5 > input.size()) {
+ throw std::runtime_error("not enough bytes for token id");
+ }
+ uint32_t val =
+ (static_cast<unsigned char>(input[offset + 1]) << 24) |
+ (static_cast<unsigned char>(input[offset + 2]) << 16) |
+ (static_cast<unsigned char>(input[offset + 3]) << 8) |
+ (static_cast<unsigned char>(input[offset + 4]));
+ auto piece = "<[" + std::to_string(val) + "]>";
+ result.push_back({static_cast<llama_token>(val), piece});
+ offset += 5;
+ } else {
+ uint32_t cpt = unicode_cpt_from_utf8(input, offset);
+ result.push_back({0, unicode_cpt_to_utf8(cpt)});
+ }
+ } catch (const std::invalid_argument & /*ex*/) {
+ // Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize
+ ++offset;
+ result.push_back({0, unicode_cpt_to_utf8(0xFFFD)}); // replacement character
+ }
+ }
+ return result;
+}
+
static bool match_string(const std::string & input, llama_grammar * grammar) {
- const auto cpts = unicode_cpts_from_utf8(input);
+ const auto parsed = parse_tokens(input);
auto & stacks_cur = llama_grammar_get_stacks(grammar);
- for (const auto & cpt : cpts) {
- llama_grammar_accept(grammar, cpt);
+ for (const auto & in : parsed) {
+ try {
+ llama_grammar_accept_token(*grammar, in.token, in.piece);
+ } catch (const std::runtime_error & /*e*/) {
+ // normally this shouldn't get hit because of llama_grammar_apply
+ return false;
+ }
if (stacks_cur.empty()) {
// no stacks means that the grammar failed to match at this point
"12a45",
}
);
+
+ // Test case for a simple grammar with tokens
+ test_grammar(
+ "simple grammar with tokens",
+ R"""(
+ root ::= <[10]> content <[11]>
+ content ::= (!<[11]>)*)""",
+ // Passing strings
+ {
+ token(10) + "hello world" + token(11),
+ token(10) + "text with " + token(12) + " other tokens " + token(13) + " mixed in" + token(11),
+ token(10) + token(11),
+ token(10) + token(12) + token(13) + token(14) + token(15) + token(11),
+ token(10) + "a" + token(11),
+ },
+ // Failing strings
+ {
+ token(10) + "missing end token",
+ token(10),
+ "missing start token" + token(11),
+ token(10) + token(11) + token(11), // double end token
+ token(11) + "wrong order" + token(10),
+ }
+ );
}
static void test_complex_grammar() {
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
}
);
+
+ // Test case for a more complex grammar with tokens
+ test_grammar(
+ "complex grammar with tokens",
+ R"""(
+ root ::= reasoning+ content tool-call*
+ reasoning ::= <[10]> (!<[11]>)* <[11]>
+ content ::= <[20]> (!<[21]>)* <[21]>
+ tool-call ::= <[12]> name <[13]> args <[14]>
+ name ::= (!<[13]>)+
+ args ::= (!<[14]>)*)""",
+ // Passing strings
+ {
+ token(10) + "I am thinking" + token(11) + token(20) + "hello world!" + token(21) + token(12) + "search" + token(13) + "query=test" + token(14),
+ token(10) + "reasoning 1" + token(11) + token(10) + "reasoning 2" + token(11) + token(20) + token(21) + token(12) + "tool" + token(13) + token(14),
+ token(10) + token(11) + token(20) + "content" + token(21),
+ token(10) + "think" + token(12) + " nested" + token(11) + token(20) + token(10) + "more content" + token(21) + token(12) + "fn" + token(13) + "x=1,y=2" + token(14) + token(12) + "fn2" + token(13) + token(14),
+ token(10) + "reasoning" + token(11) + token(10) + "more" + token(11) + token(10) + "even more" + token(11) + token(20) + "text" + token(21) + token(12) + "a" + token(13) + "b" + token(14) + token(12) + "c" + token(13) + "d" + token(14),
+ },
+ // Failing strings
+ {
+ token(20) + "content only" + token(21),
+ token(10) + "no closing reasoning",
+ token(10) + token(11) + token(20) + "no closing content",
+ token(10) + token(11) + token(20) + token(21) + token(12) + "incomplete tool",
+ token(10) + token(11) + token(11) + token(20) + token(21),
+ }
+ );
}
static void test_special_chars() {