]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix indentation in llama-grammar [no ci] (#11943)
authorDaniel Bevenius <redacted>
Wed, 19 Feb 2025 05:16:23 +0000 (06:16 +0100)
committerGitHub <redacted>
Wed, 19 Feb 2025 05:16:23 +0000 (06:16 +0100)
This commit adjusts the indentation for the functions `parse_sequence`
and `parse_rule` in src/llama-grammar.cpp.

The motivation is consistency and improve readability.

src/llama-grammar.cpp

index 46e27a96ed728aa9363146bcf2fb18ff964db3b4..98af1ba3976aefe26906d6296530280a5ebe1fc5 100644 (file)
@@ -345,194 +345,194 @@ const char * llama_grammar_parser::parse_sequence(
     size_t last_sym_start = rule.size();
     const char * pos = src;
 
-        auto handle_repetitions = [&](int min_times, int max_times) {
+    auto handle_repetitions = [&](int min_times, int max_times) {
 
-            if (last_sym_start == rule.size()) {
-                throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
-            }
+        if (last_sym_start == rule.size()) {
+            throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
+        }
 
-            // apply transformation to previous symbol (last_sym_start to end) according to
-            // the following rewrite rules:
-            // S{m,n} --> S S S (m times) S'(n-m)
-            //            S'(x)   ::= S S'(x-1) |
-            //            (... n-m definitions of these S' rules ...)
-            //            S'(1)   ::= S |
-            // S{m,} -->  S S S (m times) S'
-            //            S'     ::= S S' |
-            // S*     --> S{0,}
-            //        --> S'     ::= S S' |
-            // S+     --> S{1,}
-            //        --> S S'
-            //            S'     ::= S S' |
-            // S?     --> S{0,1}
-            //        --> S'
-            //            S'     ::= S |
-
-            llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
-            if (min_times == 0) {
-                rule.resize(last_sym_start);
-            } else {
-                // Repeat the previous elements (min_times - 1) times
-                for (int i = 1; i < min_times; i++) {
-                    rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
-                }
+        // apply transformation to previous symbol (last_sym_start to end) according to
+        // the following rewrite rules:
+        // S{m,n} --> S S S (m times) S'(n-m)
+        //            S'(x)   ::= S S'(x-1) |
+        //            (... n-m definitions of these S' rules ...)
+        //            S'(1)   ::= S |
+        // S{m,} -->  S S S (m times) S'
+        //            S'     ::= S S' |
+        // S*     --> S{0,}
+        //        --> S'     ::= S S' |
+        // S+     --> S{1,}
+        //        --> S S'
+        //            S'     ::= S S' |
+        // S?     --> S{0,1}
+        //        --> S'
+        //            S'     ::= S |
+
+        llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
+        if (min_times == 0) {
+            rule.resize(last_sym_start);
+        } else {
+            // Repeat the previous elements (min_times - 1) times
+            for (int i = 1; i < min_times; i++) {
+                rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
             }
+        }
 
-            uint32_t last_rec_rule_id = 0;
-            auto n_opt = max_times < 0 ? 1 : max_times - min_times;
+        uint32_t last_rec_rule_id = 0;
+        auto n_opt = max_times < 0 ? 1 : max_times - min_times;
 
-            llama_grammar_rule rec_rule(prev_rule);
-            for (int i = 0; i < n_opt; i++) {
-                rec_rule.resize(prev_rule.size());
-                uint32_t rec_rule_id = generate_symbol_id( rule_name);
-                if (i > 0 || max_times < 0) {
-                    rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
-                }
-                rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
-                rec_rule.push_back({LLAMA_GRETYPE_END, 0});
-                add_rule( rec_rule_id, rec_rule);
-                last_rec_rule_id = rec_rule_id;
-            }
-            if (n_opt > 0) {
-                rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
+        llama_grammar_rule rec_rule(prev_rule);
+        for (int i = 0; i < n_opt; i++) {
+            rec_rule.resize(prev_rule.size());
+            uint32_t rec_rule_id = generate_symbol_id( rule_name);
+            if (i > 0 || max_times < 0) {
+                rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
             }
-        };
+            rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
+            rec_rule.push_back({LLAMA_GRETYPE_END, 0});
+            add_rule( rec_rule_id, rec_rule);
+            last_rec_rule_id = rec_rule_id;
+        }
+        if (n_opt > 0) {
+            rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
+        }
+    };
 
-        while (*pos) {
-            if (*pos == '"') { // literal string
-                pos++;
-                last_sym_start = rule.size();
-                while (*pos != '"') {
-                    if (!*pos) {
-                        throw std::runtime_error("unexpected end of input");
-                    }
-                    auto char_pair = parse_char(pos);
-                         pos       = char_pair.second;
-                    rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
+    while (*pos) {
+        if (*pos == '"') { // literal string
+            pos++;
+            last_sym_start = rule.size();
+            while (*pos != '"') {
+                if (!*pos) {
+                    throw std::runtime_error("unexpected end of input");
                 }
-                pos = parse_space(pos + 1, is_nested);
-            } else if (*pos == '[') { // char range(s)
+                auto char_pair = parse_char(pos);
+                     pos       = char_pair.second;
+                rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
+            }
+            pos = parse_space(pos + 1, is_nested);
+        } else if (*pos == '[') { // char range(s)
+            pos++;
+            enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
+            if (*pos == '^') {
                 pos++;
-                enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
-                if (*pos == '^') {
-                    pos++;
-                    start_type = LLAMA_GRETYPE_CHAR_NOT;
+                start_type = LLAMA_GRETYPE_CHAR_NOT;
+            }
+            last_sym_start = rule.size();
+            while (*pos != ']') {
+                if (!*pos) {
+                    throw std::runtime_error("unexpected end of input");
                 }
-                last_sym_start = rule.size();
-                while (*pos != ']') {
-                    if (!*pos) {
+                auto char_pair = parse_char(pos);
+                     pos       = char_pair.second;
+                enum llama_gretype type = last_sym_start < rule.size()
+                    ? LLAMA_GRETYPE_CHAR_ALT
+                    : start_type;
+
+                rule.push_back({type, char_pair.first});
+                if (pos[0] == '-' && pos[1] != ']') {
+                    if (!pos[1]) {
                         throw std::runtime_error("unexpected end of input");
                     }
-                    auto char_pair = parse_char(pos);
-                         pos       = char_pair.second;
-                    enum llama_gretype type = last_sym_start < rule.size()
-                        ? LLAMA_GRETYPE_CHAR_ALT
-                        : start_type;
-
-                    rule.push_back({type, char_pair.first});
-                    if (pos[0] == '-' && pos[1] != ']') {
-                        if (!pos[1]) {
-                            throw std::runtime_error("unexpected end of input");
-                        }
-                        auto endchar_pair = parse_char(pos + 1);
-                             pos          = endchar_pair.second;
-                        rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
-                    }
-                }
-                pos = parse_space(pos + 1, is_nested);
-            } else if (is_word_char(*pos)) { // rule reference
-                const char * name_end    = parse_name(pos);
-                uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
-                pos = parse_space(name_end, is_nested);
-                last_sym_start = rule.size();
-                rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
-            } else if (*pos == '(') { // grouping
-                // parse nested alternates into synthesized rule
-                pos = parse_space(pos + 1, true);
-                uint32_t sub_rule_id = generate_symbol_id(rule_name);
-                pos = parse_alternates(pos, rule_name, sub_rule_id, true);
-                last_sym_start = rule.size();
-                // output reference to synthesized rule
-                rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
-                if (*pos != ')') {
-                    throw std::runtime_error(std::string("expecting ')' at ") + pos);
+                    auto endchar_pair = parse_char(pos + 1);
+                         pos          = endchar_pair.second;
+                    rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
                 }
+            }
+            pos = parse_space(pos + 1, is_nested);
+        } else if (is_word_char(*pos)) { // rule reference
+            const char * name_end    = parse_name(pos);
+            uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
+            pos = parse_space(name_end, is_nested);
+            last_sym_start = rule.size();
+            rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
+        } else if (*pos == '(') { // grouping
+            // parse nested alternates into synthesized rule
+            pos = parse_space(pos + 1, true);
+            uint32_t sub_rule_id = generate_symbol_id(rule_name);
+            pos = parse_alternates(pos, rule_name, sub_rule_id, true);
+            last_sym_start = rule.size();
+            // output reference to synthesized rule
+            rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
+            if (*pos != ')') {
+                throw std::runtime_error(std::string("expecting ')' at ") + pos);
+            }
+            pos = parse_space(pos + 1, is_nested);
+        } else if (*pos == '.') { // any char
+            last_sym_start = rule.size();
+            rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
+            pos = parse_space(pos + 1, is_nested);
+        } else if (*pos == '*') {
+            pos = parse_space(pos + 1, is_nested);
+            handle_repetitions(0, -1);
+        } else if (*pos == '+') {
+            pos = parse_space(pos + 1, is_nested);
+            handle_repetitions(1, -1);
+        } else if (*pos == '?') {
+            pos = parse_space(pos + 1, is_nested);
+            handle_repetitions(0, 1);
+        } else if (*pos == '{') {
+            pos = parse_space(pos + 1, is_nested);
+
+            if (!is_digit_char(*pos)) {
+                throw std::runtime_error(std::string("expecting an int at ") + pos);
+            }
+            const char * int_end = parse_int(pos);
+            int min_times = std::stoul(std::string(pos, int_end - pos));
+            pos = parse_space(int_end, is_nested);
+
+            int max_times = -1;
+
+            if (*pos == '}') {
+                max_times = min_times;
                 pos = parse_space(pos + 1, is_nested);
-            } else if (*pos == '.') { // any char
-                last_sym_start = rule.size();
-                rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
-                pos = parse_space(pos + 1, is_nested);
-            } else if (*pos == '*') {
-                pos = parse_space(pos + 1, is_nested);
-                handle_repetitions(0, -1);
-            } else if (*pos == '+') {
-                pos = parse_space(pos + 1, is_nested);
-                handle_repetitions(1, -1);
-            } else if (*pos == '?') {
-                pos = parse_space(pos + 1, is_nested);
-                handle_repetitions(0, 1);
-            } else if (*pos == '{') {
+            } else if (*pos == ',') {
                 pos = parse_space(pos + 1, is_nested);
 
-                if (!is_digit_char(*pos)) {
-                    throw std::runtime_error(std::string("expecting an int at ") + pos);
+                if (is_digit_char(*pos)) {
+                    const char * int_end = parse_int(pos);
+                    max_times = std::stoul(std::string(pos, int_end - pos));
+                    pos = parse_space(int_end, is_nested);
                 }
-                const char * int_end = parse_int(pos);
-                int min_times = std::stoul(std::string(pos, int_end - pos));
-                pos = parse_space(int_end, is_nested);
-
-                int max_times = -1;
-
-                if (*pos == '}') {
-                    max_times = min_times;
-                    pos = parse_space(pos + 1, is_nested);
-                } else if (*pos == ',') {
-                    pos = parse_space(pos + 1, is_nested);
-
-                    if (is_digit_char(*pos)) {
-                        const char * int_end = parse_int(pos);
-                        max_times = std::stoul(std::string(pos, int_end - pos));
-                        pos = parse_space(int_end, is_nested);
-                    }
 
-                    if (*pos != '}') {
-                        throw std::runtime_error(std::string("expecting '}' at ") + pos);
-                    }
-                    pos = parse_space(pos + 1, is_nested);
-                } else {
-                    throw std::runtime_error(std::string("expecting ',' at ") + pos);
+                if (*pos != '}') {
+                    throw std::runtime_error(std::string("expecting '}' at ") + pos);
                 }
-                handle_repetitions(min_times, max_times);
+                pos = parse_space(pos + 1, is_nested);
             } else {
-                break;
+                throw std::runtime_error(std::string("expecting ',' at ") + pos);
             }
+            handle_repetitions(min_times, max_times);
+        } else {
+            break;
         }
-        return pos;
     }
+    return pos;
+}
 
 const char * llama_grammar_parser::parse_rule(const char * src) {
-        const char * name_end = parse_name(src);
-        const char * pos      = parse_space(name_end, false);
-        size_t       name_len = name_end - src;
-        uint32_t     rule_id  = get_symbol_id(src, name_len);
-        const std::string name(src, name_len);
-
-        if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
-            throw std::runtime_error(std::string("expecting ::= at ") + pos);
-        }
-        pos = parse_space(pos + 3, true);
+    const char * name_end = parse_name(src);
+    const char * pos      = parse_space(name_end, false);
+    size_t       name_len = name_end - src;
+    uint32_t     rule_id  = get_symbol_id(src, name_len);
+    const std::string name(src, name_len);
+
+    if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
+        throw std::runtime_error(std::string("expecting ::= at ") + pos);
+    }
+    pos = parse_space(pos + 3, true);
 
-        pos = parse_alternates(pos, name, rule_id, false);
+    pos = parse_alternates(pos, name, rule_id, false);
 
-        if (*pos == '\r') {
-            pos += pos[1] == '\n' ? 2 : 1;
-        } else if (*pos == '\n') {
-            pos++;
-        } else if (*pos) {
-            throw std::runtime_error(std::string("expecting newline or end at ") + pos);
-        }
-        return parse_space(pos, true);
+    if (*pos == '\r') {
+        pos += pos[1] == '\n' ? 2 : 1;
+    } else if (*pos == '\n') {
+        pos++;
+    } else if (*pos) {
+        throw std::runtime_error(std::string("expecting newline or end at ") + pos);
     }
+    return parse_space(pos, true);
+}
 
 bool llama_grammar_parser::parse(const char * src) {
     try {