]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sync : whisper.cpp (grammar-parser, skip)
authorGeorgi Gerganov <redacted>
Sun, 7 Apr 2024 14:02:17 +0000 (17:02 +0300)
committerGeorgi Gerganov <redacted>
Sun, 7 Apr 2024 14:02:17 +0000 (17:02 +0300)
examples/whisper/CMakeLists.txt
examples/whisper/grammar-parser.cpp [new file with mode: 0644]
examples/whisper/grammar-parser.h [new file with mode: 0644]
scripts/sync-whisper-am.sh
scripts/sync-whisper.sh

index 63f8cd40891d66ae94bcd6bb7ead2afe61f00b01..fb5b0854b1ebc5e4062383f3aa035fc3a46c1ea6 100644 (file)
@@ -10,7 +10,7 @@ target_link_libraries(whisper-cpp PRIVATE
     )
 
 set(TEST_TARGET whisper)
-add_executable(${TEST_TARGET} main.cpp)
+add_executable(${TEST_TARGET} main.cpp grammar-parser.cpp)
 target_link_libraries(${TEST_TARGET} PRIVATE whisper-cpp common)
 target_include_directories(${TEST_TARGET} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..)
 target_include_directories(${TEST_TARGET} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include/ggml)
diff --git a/examples/whisper/grammar-parser.cpp b/examples/whisper/grammar-parser.cpp
new file mode 100644 (file)
index 0000000..2daaaef
--- /dev/null
@@ -0,0 +1,423 @@
+#include "grammar-parser.h"
+#include <cstdint>
+#include <cwchar>
+#include <string>
+#include <utility>
+#include <stdexcept>
+#include <exception>
+
+namespace grammar_parser {
+    // NOTE: assumes valid utf8 (but checks for overrun)
+    // copied from whisper.cpp
+    std::pair<uint32_t, const char *> decode_utf8(const char * src) {
+        static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+        uint8_t  first_byte = static_cast<uint8_t>(*src);
+        uint8_t  highbits   = first_byte >> 4;
+        int      len        = lookup[highbits];
+        uint8_t  mask       = (1 << (8 - len)) - 1;
+        uint32_t value      = first_byte & mask;
+        const char * end    = src + len; // may overrun!
+        const char * pos    = src + 1;
+        for ( ; pos < end && *pos; pos++) {
+            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+        }
+        return std::make_pair(value, pos);
+    }
+
+    uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
+        uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
+        auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id));
+        return result.first->second;
+    }
+
+    uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
+        uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
+        state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
+        return next_id;
+    }
+
+    void add_rule(
+            parse_state & state,
+            uint32_t      rule_id,
+            const std::vector<whisper_grammar_element> & rule) {
+        if (state.rules.size() <= rule_id) {
+            state.rules.resize(rule_id + 1);
+        }
+        state.rules[rule_id] = rule;
+    }
+
+    bool is_word_char(char c) {
+        return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
+    }
+
+    std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
+        const char * pos   = src;
+        const char * end   = src + size;
+        uint32_t     value = 0;
+        for ( ; pos < end && *pos; pos++) {
+            value <<= 4;
+            char c = *pos;
+            if ('a' <= c && c <= 'f') {
+                value += c - 'a' + 10;
+            } else if ('A' <= c && c <= 'F') {
+                value += c - 'A' + 10;
+            } else if ('0' <= c && c <= '9') {
+                value += c - '0';
+            } else {
+                break;
+            }
+        }
+        if (pos != end) {
+            throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
+        }
+        return std::make_pair(value, pos);
+    }
+
+    const char * parse_space(const char * src, bool newline_ok) {
+        const char * pos = src;
+        while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
+                (newline_ok && (*pos == '\r' || *pos == '\n'))) {
+            if (*pos == '#') {
+                while (*pos && *pos != '\r' && *pos != '\n') {
+                    pos++;
+                }
+            } else {
+                pos++;
+            }
+        }
+        return pos;
+    }
+
+    const char * parse_name(const char * src) {
+        const char * pos = src;
+        while (is_word_char(*pos)) {
+            pos++;
+        }
+        if (pos == src) {
+            throw std::runtime_error(std::string("expecting name at ") + src);
+        }
+        return pos;
+    }
+
+    std::pair<uint32_t, const char *> parse_char(const char * src) {
+        if (*src == '\\') {
+            switch (src[1]) {
+                case 'x': return parse_hex(src + 2, 2);
+                case 'u': return parse_hex(src + 2, 4);
+                case 'U': return parse_hex(src + 2, 8);
+                case 't': return std::make_pair('\t', src + 2);
+                case 'r': return std::make_pair('\r', src + 2);
+                case 'n': return std::make_pair('\n', src + 2);
+                case '\\':
+                case '"':
+                case '[':
+                case ']':
+                    return std::make_pair(src[1], src + 2);
+                default:
+                    throw std::runtime_error(std::string("unknown escape at ") + src);
+            }
+        } else if (*src) {
+            return decode_utf8(src);
+        }
+        throw std::runtime_error("unexpected end of input");
+    }
+
+    const char * parse_alternates(
+            parse_state       & state,
+            const char        * src,
+            const std::string & rule_name,
+            uint32_t            rule_id,
+            bool                is_nested);
+
+    const char * parse_sequence(
+            parse_state                        & state,
+            const char                         * src,
+            const std::string                  & rule_name,
+            std::vector<whisper_grammar_element> & out_elements,
+            bool                                 is_nested) {
+        size_t last_sym_start = out_elements.size();
+        const char * pos = src;
+        while (*pos) {
+            if (*pos == '"') { // literal string
+                pos++;
+                last_sym_start = out_elements.size();
+                while (*pos != '"') {
+                    auto char_pair = parse_char(pos);
+                         pos       = char_pair.second;
+                    out_elements.push_back({WHISPER_GRETYPE_CHAR, char_pair.first});
+                }
+                pos = parse_space(pos + 1, is_nested);
+            } else if (*pos == '[') { // char range(s)
+                pos++;
+                enum whisper_gretype start_type = WHISPER_GRETYPE_CHAR;
+                if (*pos == '^') {
+                    pos++;
+                    start_type = WHISPER_GRETYPE_CHAR_NOT;
+                }
+                last_sym_start = out_elements.size();
+                while (*pos != ']') {
+                    auto char_pair = parse_char(pos);
+                         pos       = char_pair.second;
+                    enum whisper_gretype type = last_sym_start < out_elements.size()
+                        ? WHISPER_GRETYPE_CHAR_ALT
+                        : start_type;
+
+                    out_elements.push_back({type, char_pair.first});
+                    if (pos[0] == '-' && pos[1] != ']') {
+                        auto endchar_pair = parse_char(pos + 1);
+                             pos          = endchar_pair.second;
+                        out_elements.push_back({WHISPER_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
+                    }
+                }
+                pos = parse_space(pos + 1, is_nested);
+            } else if (is_word_char(*pos)) { // rule reference
+                const char * name_end    = parse_name(pos);
+                uint32_t     ref_rule_id = get_symbol_id(state, pos, name_end - pos);
+                pos = parse_space(name_end, is_nested);
+                last_sym_start = out_elements.size();
+                out_elements.push_back({WHISPER_GRETYPE_RULE_REF, ref_rule_id});
+            } else if (*pos == '(') { // grouping
+                // parse nested alternates into synthesized rule
+                pos = parse_space(pos + 1, true);
+                uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
+                pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
+                last_sym_start = out_elements.size();
+                // output reference to synthesized rule
+                out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
+                if (*pos != ')') {
+                    throw std::runtime_error(std::string("expecting ')' at ") + pos);
+                }
+                pos = parse_space(pos + 1, is_nested);
+            } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
+                if (last_sym_start == out_elements.size()) {
+                    throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
+                }
+
+                // apply transformation to previous symbol (last_sym_start to end) according to
+                // rewrite rules:
+                // S* --> S' ::= S S' |
+                // S+ --> S' ::= S S' | S
+                // S? --> S' ::= S |
+                uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
+                std::vector<whisper_grammar_element> sub_rule;
+                // add preceding symbol to generated rule
+                sub_rule.insert(
+                    sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
+                if (*pos == '*' || *pos == '+') {
+                    // cause generated rule to recurse
+                    sub_rule.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
+                }
+                // mark start of alternate def
+                sub_rule.push_back({WHISPER_GRETYPE_ALT, 0});
+                if (*pos == '+') {
+                    // add preceding symbol as alternate only for '+' (otherwise empty)
+                    sub_rule.insert(
+                        sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
+                }
+                sub_rule.push_back({WHISPER_GRETYPE_END, 0});
+                add_rule(state, sub_rule_id, sub_rule);
+
+                // in original rule, replace previous symbol with reference to generated rule
+                out_elements.resize(last_sym_start);
+                out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
+
+                pos = parse_space(pos + 1, is_nested);
+            } else {
+                break;
+            }
+        }
+        return pos;
+    }
+
+    const char * parse_alternates(
+            parse_state       & state,
+            const char        * src,
+            const std::string & rule_name,
+            uint32_t            rule_id,
+            bool                is_nested) {
+        std::vector<whisper_grammar_element> rule;
+        const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
+        while (*pos == '|') {
+            rule.push_back({WHISPER_GRETYPE_ALT, 0});
+            pos = parse_space(pos + 1, true);
+            pos = parse_sequence(state, pos, rule_name, rule, is_nested);
+        }
+        rule.push_back({WHISPER_GRETYPE_END, 0});
+        add_rule(state, rule_id, rule);
+        return pos;
+    }
+
+    const char * parse_rule(parse_state & state, const char * src) {
+        const char * name_end = parse_name(src);
+        const char * pos      = parse_space(name_end, false);
+        size_t       name_len = name_end - src;
+        uint32_t     rule_id  = get_symbol_id(state, src, name_len);
+        const std::string name(src, name_len);
+
+        if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
+            throw std::runtime_error(std::string("expecting ::= at ") + pos);
+        }
+        pos = parse_space(pos + 3, true);
+
+        pos = parse_alternates(state, pos, name, rule_id, false);
+
+        if (*pos == '\r') {
+            pos += pos[1] == '\n' ? 2 : 1;
+        } else if (*pos == '\n') {
+            pos++;
+        } else if (*pos) {
+            throw std::runtime_error(std::string("expecting newline or end at ") + pos);
+        }
+        return parse_space(pos, true);
+    }
+
+    parse_state parse(const char * src) {
+        try {
+            parse_state state;
+            const char * pos = parse_space(src, true);
+            while (*pos) {
+                pos = parse_rule(state, pos);
+            }
+            return state;
+        } catch (const std::exception & err) {
+            fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
+            return parse_state();
+        }
+    }
+
+    void print_grammar_char(FILE * file, uint32_t c) {
+        if (0x20 <= c && c <= 0x7f) {
+            fprintf(file, "%c", static_cast<char>(c));
+        } else {
+            // cop out of encoding UTF-8
+            fprintf(file, "<U+%04X>", c);
+        }
+    }
+
+    bool is_char_element(whisper_grammar_element elem) {
+        switch (elem.type) {
+            case WHISPER_GRETYPE_CHAR:           return true;
+            case WHISPER_GRETYPE_CHAR_NOT:       return true;
+            case WHISPER_GRETYPE_CHAR_ALT:       return true;
+            case WHISPER_GRETYPE_CHAR_RNG_UPPER: return true;
+            default:                           return false;
+        }
+    }
+
+    void print_rule_binary(FILE * file, const std::vector<whisper_grammar_element> & rule) {
+        for (auto elem : rule) {
+            switch (elem.type) {
+                case WHISPER_GRETYPE_END:            fprintf(file, "END");            break;
+                case WHISPER_GRETYPE_ALT:            fprintf(file, "ALT");            break;
+                case WHISPER_GRETYPE_RULE_REF:       fprintf(file, "RULE_REF");       break;
+                case WHISPER_GRETYPE_CHAR:           fprintf(file, "CHAR");           break;
+                case WHISPER_GRETYPE_CHAR_NOT:       fprintf(file, "CHAR_NOT");       break;
+                case WHISPER_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
+                case WHISPER_GRETYPE_CHAR_ALT:       fprintf(file, "CHAR_ALT");       break;
+            }
+            switch (elem.type) {
+                case WHISPER_GRETYPE_END:
+                case WHISPER_GRETYPE_ALT:
+                case WHISPER_GRETYPE_RULE_REF:
+                    fprintf(file, "(%u) ", elem.value);
+                    break;
+                case WHISPER_GRETYPE_CHAR:
+                case WHISPER_GRETYPE_CHAR_NOT:
+                case WHISPER_GRETYPE_CHAR_RNG_UPPER:
+                case WHISPER_GRETYPE_CHAR_ALT:
+                    fprintf(file, "(\"");
+                    print_grammar_char(file, elem.value);
+                    fprintf(file, "\") ");
+                    break;
+            }
+        }
+        fprintf(file, "\n");
+    }
+
+    void print_rule(
+            FILE     * file,
+            uint32_t   rule_id,
+            const std::vector<whisper_grammar_element> & rule,
+            const std::map<uint32_t, std::string>    & symbol_id_names) {
+        if (rule.empty() || rule.back().type != WHISPER_GRETYPE_END) {
+            throw std::runtime_error(
+                "malformed rule, does not end with WHISPER_GRETYPE_END: " + std::to_string(rule_id));
+        }
+        fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
+        for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
+            whisper_grammar_element elem = rule[i];
+            switch (elem.type) {
+                case WHISPER_GRETYPE_END:
+                    throw std::runtime_error(
+                        "unexpected end of rule: " + std::to_string(rule_id) + "," +
+                        std::to_string(i));
+                case WHISPER_GRETYPE_ALT:
+                    fprintf(file, "| ");
+                    break;
+                case WHISPER_GRETYPE_RULE_REF:
+                    fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
+                    break;
+                case WHISPER_GRETYPE_CHAR:
+                    fprintf(file, "[");
+                    print_grammar_char(file, elem.value);
+                    break;
+                case WHISPER_GRETYPE_CHAR_NOT:
+                    fprintf(file, "[^");
+                    print_grammar_char(file, elem.value);
+                    break;
+                case WHISPER_GRETYPE_CHAR_RNG_UPPER:
+                    if (i == 0 || !is_char_element(rule[i - 1])) {
+                        throw std::runtime_error(
+                            "WHISPER_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
+                            std::to_string(rule_id) + "," + std::to_string(i));
+                    }
+                    fprintf(file, "-");
+                    print_grammar_char(file, elem.value);
+                    break;
+                case WHISPER_GRETYPE_CHAR_ALT:
+                    if (i == 0 || !is_char_element(rule[i - 1])) {
+                        throw std::runtime_error(
+                            "WHISPER_GRETYPE_CHAR_ALT without preceding char: " +
+                            std::to_string(rule_id) + "," + std::to_string(i));
+                    }
+                    print_grammar_char(file, elem.value);
+                    break;
+            }
+            if (is_char_element(elem)) {
+                switch (rule[i + 1].type) {
+                    case WHISPER_GRETYPE_CHAR_ALT:
+                    case WHISPER_GRETYPE_CHAR_RNG_UPPER:
+                        break;
+                    default:
+                        fprintf(file, "] ");
+                }
+            }
+        }
+        fprintf(file, "\n");
+    }
+
+    void print_grammar(FILE * file, const parse_state & state) {
+        try {
+            std::map<uint32_t, std::string> symbol_id_names;
+            for (auto kv : state.symbol_ids) {
+                symbol_id_names[kv.second] = kv.first;
+            }
+            for (size_t i = 0, end = state.rules.size(); i < end; i++) {
+                // fprintf(file, "%zu: ", i);
+                // print_rule_binary(file, state.rules[i]);
+                print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
+                // fprintf(file, "\n");
+            }
+        } catch (const std::exception & err) {
+            fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
+        }
+    }
+
+    std::vector<const whisper_grammar_element *> parse_state::c_rules() const{
+        std::vector<const whisper_grammar_element *> ret;
+        for (const auto & rule : rules) {
+            ret.push_back(rule.data());
+        }
+        return ret;
+    }
+}
diff --git a/examples/whisper/grammar-parser.h b/examples/whisper/grammar-parser.h
new file mode 100644 (file)
index 0000000..47d019c
--- /dev/null
@@ -0,0 +1,29 @@
+// Implements a parser for an extended Backus-Naur form (BNF), producing the
+// binary context-free grammar format specified by whisper.h. Supports character
+// ranges, grouping, and repetition operators. As an example, a grammar for
+// arithmetic might look like:
+//
+// root  ::= expr
+// expr  ::= term ([-+*/] term)*
+// term  ::= num | "(" space expr ")" space
+// num   ::= [0-9]+ space
+// space ::= [ \t\n]*
+
+#pragma once
+#include "whisper.h"
+#include <vector>
+#include <map>
+#include <cstdint>
+#include <string>
+
+namespace grammar_parser {
+    struct parse_state {
+        std::map<std::string, uint32_t>                   symbol_ids;
+        std::vector<std::vector<whisper_grammar_element>> rules;
+
+        std::vector<const whisper_grammar_element *>      c_rules() const;
+    };
+
+    parse_state parse(const char * src);
+    void print_grammar(FILE * file, const parse_state & state);
+}
index ef66f65fd4ff58fa15bf115e10b1e05d85d70457..c19da23aa8283cc5ba6dc3a83bec69bb24186884 100755 (executable)
@@ -66,6 +66,8 @@ while read c; do
     examples/common.cpp \
     examples/common-ggml.h \
     examples/common-ggml.cpp \
+    examples/grammar-parser.h \
+    examples/grammar-parser.cpp \
     examples/main/main.cpp \
     examples/quantize/quantize.cpp \
     >> $SRC_GGML/whisper-src.patch
@@ -127,6 +129,8 @@ if [ -f $SRC_GGML/whisper-src.patch ]; then
     # examples/common.cpp            -> examples/common.cpp
     # examples/common-ggml.h         -> examples/common-ggml.h
     # examples/common-ggml.cpp       -> examples/common-ggml.cpp
+    # examples/grammar-parser.h      -> examples/whisper/grammar-parser.h
+    # examples/grammar-parser.cpp    -> examples/whisper/grammar-parser.cpp
     # examples/main/main.cpp         -> examples/whisper/main.cpp
     # examples/quantize/quantize.cpp -> examples/whisper/quantize.cpp
 
@@ -163,6 +167,8 @@ if [ -f $SRC_GGML/whisper-src.patch ]; then
         -e 's/\/examples\/common\.cpp/\/examples\/common.cpp/g' \
         -e 's/\/examples\/common-ggml\.h/\/examples\/common-ggml.h/g' \
         -e 's/\/examples\/common-ggml\.cpp/\/examples\/common-ggml.cpp/g' \
+        -e 's/\/examples\/grammar-parser\.h/\/examples\/whisper\/grammar-parser.h/g' \
+        -e 's/\/examples\/grammar-parser\.cpp/\/examples\/whisper\/grammar-parser.cpp/g' \
         -e 's/\/examples\/main\/main\.cpp/\/examples\/whisper\/main.cpp/g' \
         -e 's/\/examples\/quantize\/quantize\.cpp/\/examples\/whisper\/quantize.cpp/g' \
         > whisper-src.patch.tmp
index 7ccde84717912a35eb09249e8e5352a31d67c8c0..b6cbac81e107bce19904d650f239ce5c50209c63 100755 (executable)
@@ -33,6 +33,8 @@ cp -rpv ../whisper.cpp/examples/common.h              examples/common.h
 cp -rpv ../whisper.cpp/examples/common.cpp            examples/common.cpp
 cp -rpv ../whisper.cpp/examples/common-ggml.h         examples/common-ggml.h
 cp -rpv ../whisper.cpp/examples/common-ggml.cpp       examples/common-ggml.cpp
+cp -rpv ../whisper.cpp/examples/grammar-parser.h      examples/whisper/grammar-parser.h
+cp -rpv ../whisper.cpp/examples/grammar-parser.cpp    examples/whisper/grammar-parser.cpp
 
 cp -rpv ../whisper.cpp/whisper.h                      examples/whisper/whisper.h
 cp -rpv ../whisper.cpp/whisper.cpp                    examples/whisper/whisper.cpp