size_t last_sym_start = rule.size();
const char * pos = src;
- auto handle_repetitions = [&](int min_times, int max_times) {
+ auto handle_repetitions = [&](int min_times, int max_times) {
- if (last_sym_start == rule.size()) {
- throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
- }
+ if (last_sym_start == rule.size()) {
+ throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
+ }
- // apply transformation to previous symbol (last_sym_start to end) according to
- // the following rewrite rules:
- // S{m,n} --> S S S (m times) S'(n-m)
- // S'(x) ::= S S'(x-1) |
- // (... n-m definitions of these S' rules ...)
- // S'(1) ::= S |
- // S{m,} --> S S S (m times) S'
- // S' ::= S S' |
- // S* --> S{0,}
- // --> S' ::= S S' |
- // S+ --> S{1,}
- // --> S S'
- // S' ::= S S' |
- // S? --> S{0,1}
- // --> S'
- // S' ::= S |
-
- llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
- if (min_times == 0) {
- rule.resize(last_sym_start);
- } else {
- // Repeat the previous elements (min_times - 1) times
- for (int i = 1; i < min_times; i++) {
- rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
- }
+ // apply transformation to previous symbol (last_sym_start to end) according to
+ // the following rewrite rules:
+ // S{m,n} --> S S S (m times) S'(n-m)
+ // S'(x) ::= S S'(x-1) |
+ // (... n-m definitions of these S' rules ...)
+ // S'(1) ::= S |
+ // S{m,} --> S S S (m times) S'
+ // S' ::= S S' |
+ // S* --> S{0,}
+ // --> S' ::= S S' |
+ // S+ --> S{1,}
+ // --> S S'
+ // S' ::= S S' |
+ // S? --> S{0,1}
+ // --> S'
+ // S' ::= S |
+
+ llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
+ if (min_times == 0) {
+ rule.resize(last_sym_start);
+ } else {
+ // Repeat the previous elements (min_times - 1) times
+ for (int i = 1; i < min_times; i++) {
+ rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
}
+ }
- uint32_t last_rec_rule_id = 0;
- auto n_opt = max_times < 0 ? 1 : max_times - min_times;
+ uint32_t last_rec_rule_id = 0;
+ auto n_opt = max_times < 0 ? 1 : max_times - min_times;
- llama_grammar_rule rec_rule(prev_rule);
- for (int i = 0; i < n_opt; i++) {
- rec_rule.resize(prev_rule.size());
- uint32_t rec_rule_id = generate_symbol_id( rule_name);
- if (i > 0 || max_times < 0) {
- rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
- }
- rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
- rec_rule.push_back({LLAMA_GRETYPE_END, 0});
- add_rule( rec_rule_id, rec_rule);
- last_rec_rule_id = rec_rule_id;
- }
- if (n_opt > 0) {
- rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
+ llama_grammar_rule rec_rule(prev_rule);
+ for (int i = 0; i < n_opt; i++) {
+ rec_rule.resize(prev_rule.size());
+ uint32_t rec_rule_id = generate_symbol_id( rule_name);
+ if (i > 0 || max_times < 0) {
+ rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
}
- };
+ rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
+ rec_rule.push_back({LLAMA_GRETYPE_END, 0});
+ add_rule( rec_rule_id, rec_rule);
+ last_rec_rule_id = rec_rule_id;
+ }
+ if (n_opt > 0) {
+ rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
+ }
+ };
- while (*pos) {
- if (*pos == '"') { // literal string
- pos++;
- last_sym_start = rule.size();
- while (*pos != '"') {
- if (!*pos) {
- throw std::runtime_error("unexpected end of input");
- }
- auto char_pair = parse_char(pos);
- pos = char_pair.second;
- rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
+ while (*pos) {
+ if (*pos == '"') { // literal string
+ pos++;
+ last_sym_start = rule.size();
+ while (*pos != '"') {
+ if (!*pos) {
+ throw std::runtime_error("unexpected end of input");
}
- pos = parse_space(pos + 1, is_nested);
- } else if (*pos == '[') { // char range(s)
+ auto char_pair = parse_char(pos);
+ pos = char_pair.second;
+ rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
+ }
+ pos = parse_space(pos + 1, is_nested);
+ } else if (*pos == '[') { // char range(s)
+ pos++;
+ enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
+ if (*pos == '^') {
pos++;
- enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
- if (*pos == '^') {
- pos++;
- start_type = LLAMA_GRETYPE_CHAR_NOT;
+ start_type = LLAMA_GRETYPE_CHAR_NOT;
+ }
+ last_sym_start = rule.size();
+ while (*pos != ']') {
+ if (!*pos) {
+ throw std::runtime_error("unexpected end of input");
}
- last_sym_start = rule.size();
- while (*pos != ']') {
- if (!*pos) {
+ auto char_pair = parse_char(pos);
+ pos = char_pair.second;
+ enum llama_gretype type = last_sym_start < rule.size()
+ ? LLAMA_GRETYPE_CHAR_ALT
+ : start_type;
+
+ rule.push_back({type, char_pair.first});
+ if (pos[0] == '-' && pos[1] != ']') {
+ if (!pos[1]) {
throw std::runtime_error("unexpected end of input");
}
- auto char_pair = parse_char(pos);
- pos = char_pair.second;
- enum llama_gretype type = last_sym_start < rule.size()
- ? LLAMA_GRETYPE_CHAR_ALT
- : start_type;
-
- rule.push_back({type, char_pair.first});
- if (pos[0] == '-' && pos[1] != ']') {
- if (!pos[1]) {
- throw std::runtime_error("unexpected end of input");
- }
- auto endchar_pair = parse_char(pos + 1);
- pos = endchar_pair.second;
- rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
- }
- }
- pos = parse_space(pos + 1, is_nested);
- } else if (is_word_char(*pos)) { // rule reference
- const char * name_end = parse_name(pos);
- uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
- pos = parse_space(name_end, is_nested);
- last_sym_start = rule.size();
- rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
- } else if (*pos == '(') { // grouping
- // parse nested alternates into synthesized rule
- pos = parse_space(pos + 1, true);
- uint32_t sub_rule_id = generate_symbol_id(rule_name);
- pos = parse_alternates(pos, rule_name, sub_rule_id, true);
- last_sym_start = rule.size();
- // output reference to synthesized rule
- rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
- if (*pos != ')') {
- throw std::runtime_error(std::string("expecting ')' at ") + pos);
+ auto endchar_pair = parse_char(pos + 1);
+ pos = endchar_pair.second;
+ rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
}
+ }
+ pos = parse_space(pos + 1, is_nested);
+ } else if (is_word_char(*pos)) { // rule reference
+ const char * name_end = parse_name(pos);
+ uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
+ pos = parse_space(name_end, is_nested);
+ last_sym_start = rule.size();
+ rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
+ } else if (*pos == '(') { // grouping
+ // parse nested alternates into synthesized rule
+ pos = parse_space(pos + 1, true);
+ uint32_t sub_rule_id = generate_symbol_id(rule_name);
+ pos = parse_alternates(pos, rule_name, sub_rule_id, true);
+ last_sym_start = rule.size();
+ // output reference to synthesized rule
+ rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
+ if (*pos != ')') {
+ throw std::runtime_error(std::string("expecting ')' at ") + pos);
+ }
+ pos = parse_space(pos + 1, is_nested);
+ } else if (*pos == '.') { // any char
+ last_sym_start = rule.size();
+ rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
+ pos = parse_space(pos + 1, is_nested);
+ } else if (*pos == '*') {
+ pos = parse_space(pos + 1, is_nested);
+ handle_repetitions(0, -1);
+ } else if (*pos == '+') {
+ pos = parse_space(pos + 1, is_nested);
+ handle_repetitions(1, -1);
+ } else if (*pos == '?') {
+ pos = parse_space(pos + 1, is_nested);
+ handle_repetitions(0, 1);
+ } else if (*pos == '{') {
+ pos = parse_space(pos + 1, is_nested);
+
+ if (!is_digit_char(*pos)) {
+ throw std::runtime_error(std::string("expecting an int at ") + pos);
+ }
+ const char * int_end = parse_int(pos);
+ int min_times = std::stoul(std::string(pos, int_end - pos));
+ pos = parse_space(int_end, is_nested);
+
+ int max_times = -1;
+
+ if (*pos == '}') {
+ max_times = min_times;
pos = parse_space(pos + 1, is_nested);
- } else if (*pos == '.') { // any char
- last_sym_start = rule.size();
- rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
- pos = parse_space(pos + 1, is_nested);
- } else if (*pos == '*') {
- pos = parse_space(pos + 1, is_nested);
- handle_repetitions(0, -1);
- } else if (*pos == '+') {
- pos = parse_space(pos + 1, is_nested);
- handle_repetitions(1, -1);
- } else if (*pos == '?') {
- pos = parse_space(pos + 1, is_nested);
- handle_repetitions(0, 1);
- } else if (*pos == '{') {
+ } else if (*pos == ',') {
pos = parse_space(pos + 1, is_nested);
- if (!is_digit_char(*pos)) {
- throw std::runtime_error(std::string("expecting an int at ") + pos);
+ if (is_digit_char(*pos)) {
+ const char * int_end = parse_int(pos);
+ max_times = std::stoul(std::string(pos, int_end - pos));
+ pos = parse_space(int_end, is_nested);
}
- const char * int_end = parse_int(pos);
- int min_times = std::stoul(std::string(pos, int_end - pos));
- pos = parse_space(int_end, is_nested);
-
- int max_times = -1;
-
- if (*pos == '}') {
- max_times = min_times;
- pos = parse_space(pos + 1, is_nested);
- } else if (*pos == ',') {
- pos = parse_space(pos + 1, is_nested);
-
- if (is_digit_char(*pos)) {
- const char * int_end = parse_int(pos);
- max_times = std::stoul(std::string(pos, int_end - pos));
- pos = parse_space(int_end, is_nested);
- }
- if (*pos != '}') {
- throw std::runtime_error(std::string("expecting '}' at ") + pos);
- }
- pos = parse_space(pos + 1, is_nested);
- } else {
- throw std::runtime_error(std::string("expecting ',' at ") + pos);
+ if (*pos != '}') {
+ throw std::runtime_error(std::string("expecting '}' at ") + pos);
}
- handle_repetitions(min_times, max_times);
+ pos = parse_space(pos + 1, is_nested);
} else {
- break;
+ throw std::runtime_error(std::string("expecting ',' at ") + pos);
}
+ handle_repetitions(min_times, max_times);
+ } else {
+ break;
}
- return pos;
}
+ return pos;
+}
// Parses a single grammar rule of the form `name ::= alternates` starting at `src`.
//
// Steps, in order:
//   1. parse_name(src) finds the end of the rule name; the name text is used both to
//      obtain the rule's numeric id (get_symbol_id) and as the `name` string handed
//      to parse_alternates.
//   2. After skipping space, the next three characters must be exactly "::=".
//   3. parse_alternates() parses the right-hand side for this rule id.
//   4. The rule must be followed by one line ending ("\r\n", a lone "\r", or "\n")
//      or by the terminating NUL of the input.
//
// Parameters:
//   src - NUL-terminated grammar text, positioned at the first character of the
//         rule name.
//
// Returns:
//   The result of parse_space(pos, true) applied just past the rule's line ending
//   (or at the NUL terminator if the input ended there).
//
// Throws:
//   std::runtime_error("expecting ::= at ...") when "::=" does not follow the name;
//   std::runtime_error("expecting newline or end at ...") when something other than
//   a line ending or end of input follows the alternates. Both messages append the
//   remaining input text. Helpers called here may throw as well (not visible here).
//
// NOTE(review): the boolean passed to parse_space presumably controls whether
// newlines are skipped as whitespace (false after the name, true after "::=" and
// after the rule) - confirm against parse_space's definition.
const char * llama_grammar_parser::parse_rule(const char * src) {
- const char * name_end = parse_name(src);
- const char * pos = parse_space(name_end, false);
- size_t name_len = name_end - src;
- uint32_t rule_id = get_symbol_id(src, name_len);
- const std::string name(src, name_len);
-
- if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
- throw std::runtime_error(std::string("expecting ::= at ") + pos);
- }
- pos = parse_space(pos + 3, true);
+ const char * name_end = parse_name(src);
+ const char * pos = parse_space(name_end, false);
+ size_t name_len = name_end - src;
+ uint32_t rule_id = get_symbol_id(src, name_len);
+ const std::string name(src, name_len);
+
+ if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
+ throw std::runtime_error(std::string("expecting ::= at ") + pos);
+ }
+ pos = parse_space(pos + 3, true);
- pos = parse_alternates(pos, name, rule_id, false);
+ pos = parse_alternates(pos, name, rule_id, false);
- if (*pos == '\r') {
- pos += pos[1] == '\n' ? 2 : 1;
- } else if (*pos == '\n') {
- pos++;
- } else if (*pos) {
- throw std::runtime_error(std::string("expecting newline or end at ") + pos);
- }
- return parse_space(pos, true);
+ if (*pos == '\r') {
+ pos += pos[1] == '\n' ? 2 : 1;
+ } else if (*pos == '\n') {
+ pos++;
+ } else if (*pos) {
+ throw std::runtime_error(std::string("expecting newline or end at ") + pos);
}
+ return parse_space(pos, true);
+}
bool llama_grammar_parser::parse(const char * src) {
try {