]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add grammar-based sampling (#1229)
authorEvan Jones <redacted>
Mon, 13 Nov 2023 08:51:34 +0000 (03:51 -0500)
committerGitHub <redacted>
Mon, 13 Nov 2023 08:51:34 +0000 (10:51 +0200)
* whisper : add grammar-based sampling

* build : fix after master merge

* command : fix exception when recognizing the command

* whisper : fine-tuning grammar functionality

* command : grammar-related improvements

- option to read grammar from file
- add sample grammars for colors and chess moves
- fine-tune the performance further

* grammars : add assistant + update comments

* command : enable beam-search, add "no_timestamps", add "context", add p

* whisper : remove comment

---------

Co-authored-by: Georgi Gerganov <redacted>
Makefile
examples/CMakeLists.txt
examples/command/command.cpp
examples/grammar-parser.cpp [new file with mode: 0644]
examples/grammar-parser.h [new file with mode: 0644]
grammars/assistant.gbnf [new file with mode: 0644]
grammars/chess.gbnf [new file with mode: 0644]
grammars/colors.gbnf [new file with mode: 0644]
whisper.cpp
whisper.h

index 20b02c3a8364a0ea83b983e64dd337c631d962cb..23dbb4f3a7f6866d03b6a8316271577fdcea97cf 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -362,8 +362,8 @@ quantize: examples/quantize/quantize.cpp $(WHISPER_OBJ) $(SRC_COMMON)
 stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
        $(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
 
-command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
-       $(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
+command: examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+       $(CXX) $(CXXFLAGS) examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
 
 lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
        $(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
index e089566976781ddce488a27a7a9b7c6fad7bc380..d91019197d9e7d8cb8d412b2f902079f630366f6 100644 (file)
@@ -23,6 +23,7 @@ add_library(${TARGET} STATIC
     common.cpp
     common-ggml.h
     common-ggml.cpp
+    grammar-parser.cpp
     )
 
 include(DefaultTargetOptions)
index 7045f5ff81ee40888755e5dc6d7167ec6c0523d3..51d800a2f2ef1c51f56b247129f7b788bbbb2093 100644 (file)
@@ -9,6 +9,7 @@
 #include "common-sdl.h"
 #include "common.h"
 #include "whisper.h"
+#include "grammar-parser.h"
 
 #include <sstream>
 #include <cassert>
 #include <vector>
 #include <map>
 
+bool file_exists(const std::string & fname) {
+    std::ifstream f(fname.c_str());
+    return f.good();
+}
+
 // command-line parameters
 struct whisper_params {
     int32_t n_threads  = std::min(4, (int32_t) std::thread::hardware_concurrency());
@@ -30,8 +36,12 @@ struct whisper_params {
     int32_t max_tokens = 32;
     int32_t audio_ctx  = 0;
 
-    float vad_thold    = 0.6f;
-    float freq_thold   = 100.0f;
+    float vad_thold  = 0.6f;
+    float freq_thold = 100.0f;
+
+    float grammar_penalty = 100.0f;
+
+    grammar_parser::parse_state grammar_parsed;
 
     bool speed_up      = false;
     bool translate     = false;
@@ -45,6 +55,8 @@ struct whisper_params {
     std::string fname_out;
     std::string commands;
     std::string prompt;
+    std::string context;
+    std::string grammar;
 };
 
 void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -75,6 +87,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-f"   || arg == "--file")          { params.fname_out     = argv[++i]; }
         else if (arg == "-cmd" || arg == "--commands")      { params.commands      = argv[++i]; }
         else if (arg == "-p"   || arg == "--prompt")        { params.prompt        = argv[++i]; }
+        else if (arg == "-ctx" || arg == "--context")       { params.context       = argv[++i]; }
+        else if (                 arg == "--grammar")       { params.grammar       = argv[++i]; }
+        else if (                 arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             whisper_print_usage(argc, argv, params);
@@ -109,16 +124,30 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -f FNAME,   --file FNAME     [%-7s] text output file name\n",                       params.fname_out.c_str());
     fprintf(stderr, "  -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n",             params.commands.c_str());
     fprintf(stderr, "  -p,         --prompt         [%-7s] the required activation prompt\n",              params.prompt.c_str());
+    fprintf(stderr, "  -ctx,       --context        [%-7s] sample text to help the transcription\n",       params.context.c_str());
+    fprintf(stderr, "  --grammar GRAMMAR            [%-7s] GBNF grammar to guide decoding\n",              params.grammar.c_str());
+    fprintf(stderr, "  --grammar-penalty N          [%-7.1f] scales down logits of nongrammar tokens\n",   params.grammar_penalty);
     fprintf(stderr, "\n");
 }
 
-std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
+std::string transcribe(
+                 whisper_context * ctx,
+            const whisper_params & params,
+        const std::vector<float> & pcmf32,
+               const std::string & grammar_rule,
+                           float & logprob_min,
+                           float & logprob_sum,
+                             int & n_tokens,
+                         int64_t & t_ms) {
     const auto t_start = std::chrono::high_resolution_clock::now();
 
-    prob = 0.0f;
+    logprob_min = 0.0f;
+    logprob_sum = 0.0f;
+    n_tokens    = 0;
     t_ms = 0;
 
-    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+    //whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
 
     wparams.print_progress   = false;
     wparams.print_special    = params.print_special;
@@ -126,19 +155,41 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
     wparams.print_timestamps = !params.no_timestamps;
     wparams.translate        = params.translate;
     wparams.no_context       = true;
+    wparams.no_timestamps    = params.no_timestamps;
     wparams.single_segment   = true;
     wparams.max_tokens       = params.max_tokens;
     wparams.language         = params.language.c_str();
     wparams.n_threads        = params.n_threads;
 
-    wparams.audio_ctx        = params.audio_ctx;
-    wparams.speed_up         = params.speed_up;
+    wparams.audio_ctx = params.audio_ctx;
+    wparams.speed_up  = params.speed_up;
+
+    wparams.temperature     = 0.4f;
+    wparams.temperature_inc = 1.0f;
+    wparams.greedy.best_of  = 5;
+
+    wparams.beam_search.beam_size = 5;
+
+    wparams.initial_prompt = params.context.data();
+
+    const auto & grammar_parsed = params.grammar_parsed;
+    auto grammar_rules = grammar_parsed.c_rules();
+
+    if (!params.grammar_parsed.rules.empty() && !grammar_rule.empty()) {
+        if (grammar_parsed.symbol_ids.find(grammar_rule) == grammar_parsed.symbol_ids.end()) {
+            fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, grammar_rule.c_str());
+        } else {
+            wparams.grammar_rules   = grammar_rules.data();
+            wparams.n_grammar_rules = grammar_rules.size();
+            wparams.i_start_rule    = grammar_parsed.symbol_ids.at(grammar_rule);
+            wparams.grammar_penalty = params.grammar_penalty;
+        }
+    }
 
     if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
         return "";
     }
 
-    int prob_n = 0;
     std::string result;
 
     const int n_segments = whisper_full_n_segments(ctx);
@@ -147,19 +198,17 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
 
         result += text;
 
-        const int n_tokens = whisper_full_n_tokens(ctx, i);
-        for (int j = 0; j < n_tokens; ++j) {
+        const int n = whisper_full_n_tokens(ctx, i);
+        for (int j = 0; j < n; ++j) {
             const auto token = whisper_full_get_token_data(ctx, i, j);
 
-            prob += token.p;
-            ++prob_n;
+            if(token.plog > 0.0f) exit(0);
+            logprob_min = std::min(logprob_min, token.plog);
+            logprob_sum += token.plog;
+            ++n_tokens;
         }
     }
 
-    if (prob_n > 0) {
-        prob /= prob_n;
-    }
-
     const auto t_end = std::chrono::high_resolution_clock::now();
     t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
 
@@ -250,7 +299,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
         fprintf(stderr, " ]\n");
     }
 
-    std::string  k_prompt = "select one from the available words: ";
+    std::string k_prompt = "select one from the available words: ";
     for (int i = 0; i < (int) allowed_commands.size(); ++i) {
         if (i > 0) {
             k_prompt += ", ";
@@ -418,7 +467,9 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
     bool is_running = true;
     bool ask_prompt = true;
 
-    float prob = 0.0f;
+    float logprob_min = 0.0f;
+    float logprob_sum = 0.0f;
+    int   n_tokens    = 0;
 
     std::vector<float> pcmf32_cur;
 
@@ -456,7 +507,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
                 // detect the commands
                 audio.get(params.command_ms, pcmf32_cur);
 
-                const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
+                const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", logprob_min, logprob_sum, n_tokens, t_ms));
 
                 const auto words = get_words(txt);
 
@@ -492,18 +543,27 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
 
 // general-purpose mode
 // freely transcribe the voice into text
-int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
+int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
     bool is_running  = true;
     bool have_prompt = false;
     bool ask_prompt  = true;
 
-    float prob0 = 0.0f;
-    float prob  = 0.0f;
+    float logprob_min0 = 0.0f;
+    float logprob_min  = 0.0f;
+
+    float logprob_sum0 = 0.0f;
+    float logprob_sum  = 0.0f;
+
+    int n_tokens0 = 0;
+    int n_tokens  = 0;
 
     std::vector<float> pcmf32_cur;
     std::vector<float> pcmf32_prompt;
 
-    const std::string k_prompt = "Ok Whisper, start listening for commands.";
+    std::string k_prompt = "Ok Whisper, start listening for commands.";
+    if (!params.prompt.empty()) {
+        k_prompt = params.prompt;
+    }
 
     fprintf(stderr, "\n");
     fprintf(stderr, "%s: general-purpose mode\n", __func__);
@@ -536,9 +596,11 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
                     // wait for activation phrase
                     audio.get(params.prompt_ms, pcmf32_cur);
 
-                    const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
+                    const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "prompt", logprob_min0, logprob_sum0, n_tokens0, t_ms));
 
-                    fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
+                    const float p = 100.0f * std::exp(logprob_min0);
+
+                    fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms, p = %.2f%%)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms, p);
 
                     const float sim = similarity(txt, k_prompt);
 
@@ -559,19 +621,30 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
                     // we have heard the activation phrase, now detect the commands
                     audio.get(params.command_ms, pcmf32_cur);
 
+                    //printf("len prompt:  %.4f\n", pcmf32_prompt.size() / (float) WHISPER_SAMPLE_RATE);
+                    //printf("len command: %.4f\n", pcmf32_cur.size() / (float) WHISPER_SAMPLE_RATE);
+
+                    // prepend 3 second of silence
+                    pcmf32_cur.insert(pcmf32_cur.begin(), 3.0f*WHISPER_SAMPLE_RATE, 0.0f);
+
                     // prepend the prompt audio
                     pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
 
-                    const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
+                    const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", logprob_min, logprob_sum, n_tokens, t_ms));
 
-                    prob = 100.0f*(prob - prob0);
+                    //const float p = 100.0f * std::exp((logprob - logprob0) / (n_tokens - n_tokens0));
+                    const float p = 100.0f * std::exp(logprob_min);
 
                     //fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
 
                     // find the prompt in the text
                     float best_sim = 0.0f;
                     size_t best_len = 0;
-                    for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
+                    for (size_t n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
+                        if (n >= txt.size()) {
+                            break;
+                        }
+
                         const auto prompt = txt.substr(0, n);
 
                         const float sim = similarity(prompt, k_prompt);
@@ -584,9 +657,16 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
                         }
                     }
 
-                    const std::string command = ::trim(txt.substr(best_len));
+                    fprintf(stdout, "%s:   DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
+                    if (best_len == 0) {
+                        fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__);
+                    } else {
+                        // cut the prompt from the decoded text
+                        const std::string command = ::trim(txt.substr(best_len));
+
+                        fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
+                    }
 
-                    fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
                     fprintf(stdout, "\n");
                 }
 
@@ -654,12 +734,36 @@ int main(int argc, char ** argv) {
 
     int  ret_val = 0;
 
-    if (!params.commands.empty()) {
-        ret_val = process_command_list(ctx, audio, params);
-    } else if (!params.prompt.empty()) {
-        ret_val = always_prompt_transcription(ctx, audio, params);
-    } else {
-        ret_val = process_general_transcription(ctx, audio, params);
+    if (!params.grammar.empty()) {
+        auto & grammar = params.grammar_parsed;
+        if (file_exists(params.grammar.c_str())) {
+            // read grammar from file
+            std::ifstream ifs(params.grammar.c_str());
+            const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
+            grammar = grammar_parser::parse(txt.c_str());
+        } else {
+            // read grammar from string
+            grammar = grammar_parser::parse(params.grammar.c_str());
+        }
+
+        // will be empty (default) if there are parse errors
+        if (grammar.rules.empty()) {
+            ret_val = 1;
+        } else {
+            fprintf(stderr, "%s: grammar:\n", __func__);
+            grammar_parser::print_grammar(stderr, grammar);
+            fprintf(stderr, "\n");
+        }
+    }
+
+    if (ret_val == 0) {
+        if (!params.commands.empty()) {
+            ret_val = process_command_list(ctx, audio, params);
+        } else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) {
+            ret_val = always_prompt_transcription(ctx, audio, params);
+        } else {
+            ret_val = process_general_transcription(ctx, audio, params);
+        }
     }
 
     audio.pause();
diff --git a/examples/grammar-parser.cpp b/examples/grammar-parser.cpp
new file mode 100644 (file)
index 0000000..2daaaef
--- /dev/null
@@ -0,0 +1,423 @@
+#include "grammar-parser.h"
+#include <cstdint>
+#include <cwchar>
+#include <string>
+#include <utility>
+#include <stdexcept>
+#include <exception>
+
+namespace grammar_parser {
+    // NOTE: assumes valid utf8 (but checks for overrun)
+    // copied from whisper.cpp
+    std::pair<uint32_t, const char *> decode_utf8(const char * src) {
+        static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+        uint8_t  first_byte = static_cast<uint8_t>(*src);
+        uint8_t  highbits   = first_byte >> 4;
+        int      len        = lookup[highbits];
+        uint8_t  mask       = (1 << (8 - len)) - 1;
+        uint32_t value      = first_byte & mask;
+        const char * end    = src + len; // may overrun!
+        const char * pos    = src + 1;
+        for ( ; pos < end && *pos; pos++) {
+            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+        }
+        return std::make_pair(value, pos);
+    }
+
+    uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
+        uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
+        auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id));
+        return result.first->second;
+    }
+
+    uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
+        uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
+        state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
+        return next_id;
+    }
+
+    void add_rule(
+            parse_state & state,
+            uint32_t      rule_id,
+            const std::vector<whisper_grammar_element> & rule) {
+        if (state.rules.size() <= rule_id) {
+            state.rules.resize(rule_id + 1);
+        }
+        state.rules[rule_id] = rule;
+    }
+
+    bool is_word_char(char c) {
+        return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
+    }
+
+    std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
+        const char * pos   = src;
+        const char * end   = src + size;
+        uint32_t     value = 0;
+        for ( ; pos < end && *pos; pos++) {
+            value <<= 4;
+            char c = *pos;
+            if ('a' <= c && c <= 'f') {
+                value += c - 'a' + 10;
+            } else if ('A' <= c && c <= 'F') {
+                value += c - 'A' + 10;
+            } else if ('0' <= c && c <= '9') {
+                value += c - '0';
+            } else {
+                break;
+            }
+        }
+        if (pos != end) {
+            throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
+        }
+        return std::make_pair(value, pos);
+    }
+
+    const char * parse_space(const char * src, bool newline_ok) {
+        const char * pos = src;
+        while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
+                (newline_ok && (*pos == '\r' || *pos == '\n'))) {
+            if (*pos == '#') {
+                while (*pos && *pos != '\r' && *pos != '\n') {
+                    pos++;
+                }
+            } else {
+                pos++;
+            }
+        }
+        return pos;
+    }
+
+    const char * parse_name(const char * src) {
+        const char * pos = src;
+        while (is_word_char(*pos)) {
+            pos++;
+        }
+        if (pos == src) {
+            throw std::runtime_error(std::string("expecting name at ") + src);
+        }
+        return pos;
+    }
+
+    std::pair<uint32_t, const char *> parse_char(const char * src) {
+        if (*src == '\\') {
+            switch (src[1]) {
+                case 'x': return parse_hex(src + 2, 2);
+                case 'u': return parse_hex(src + 2, 4);
+                case 'U': return parse_hex(src + 2, 8);
+                case 't': return std::make_pair('\t', src + 2);
+                case 'r': return std::make_pair('\r', src + 2);
+                case 'n': return std::make_pair('\n', src + 2);
+                case '\\':
+                case '"':
+                case '[':
+                case ']':
+                    return std::make_pair(src[1], src + 2);
+                default:
+                    throw std::runtime_error(std::string("unknown escape at ") + src);
+            }
+        } else if (*src) {
+            return decode_utf8(src);
+        }
+        throw std::runtime_error("unexpected end of input");
+    }
+
+    const char * parse_alternates(
+            parse_state       & state,
+            const char        * src,
+            const std::string & rule_name,
+            uint32_t            rule_id,
+            bool                is_nested);
+
+    const char * parse_sequence(
+            parse_state                        & state,
+            const char                         * src,
+            const std::string                  & rule_name,
+            std::vector<whisper_grammar_element> & out_elements,
+            bool                                 is_nested) {
+        size_t last_sym_start = out_elements.size();
+        const char * pos = src;
+        while (*pos) {
+            if (*pos == '"') { // literal string
+                pos++;
+                last_sym_start = out_elements.size();
+                while (*pos != '"') {
+                    auto char_pair = parse_char(pos);
+                         pos       = char_pair.second;
+                    out_elements.push_back({WHISPER_GRETYPE_CHAR, char_pair.first});
+                }
+                pos = parse_space(pos + 1, is_nested);
+            } else if (*pos == '[') { // char range(s)
+                pos++;
+                enum whisper_gretype start_type = WHISPER_GRETYPE_CHAR;
+                if (*pos == '^') {
+                    pos++;
+                    start_type = WHISPER_GRETYPE_CHAR_NOT;
+                }
+                last_sym_start = out_elements.size();
+                while (*pos != ']') {
+                    auto char_pair = parse_char(pos);
+                         pos       = char_pair.second;
+                    enum whisper_gretype type = last_sym_start < out_elements.size()
+                        ? WHISPER_GRETYPE_CHAR_ALT
+                        : start_type;
+
+                    out_elements.push_back({type, char_pair.first});
+                    if (pos[0] == '-' && pos[1] != ']') {
+                        auto endchar_pair = parse_char(pos + 1);
+                             pos          = endchar_pair.second;
+                        out_elements.push_back({WHISPER_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
+                    }
+                }
+                pos = parse_space(pos + 1, is_nested);
+            } else if (is_word_char(*pos)) { // rule reference
+                const char * name_end    = parse_name(pos);
+                uint32_t     ref_rule_id = get_symbol_id(state, pos, name_end - pos);
+                pos = parse_space(name_end, is_nested);
+                last_sym_start = out_elements.size();
+                out_elements.push_back({WHISPER_GRETYPE_RULE_REF, ref_rule_id});
+            } else if (*pos == '(') { // grouping
+                // parse nested alternates into synthesized rule
+                pos = parse_space(pos + 1, true);
+                uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
+                pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
+                last_sym_start = out_elements.size();
+                // output reference to synthesized rule
+                out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
+                if (*pos != ')') {
+                    throw std::runtime_error(std::string("expecting ')' at ") + pos);
+                }
+                pos = parse_space(pos + 1, is_nested);
+            } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
+                if (last_sym_start == out_elements.size()) {
+                    throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
+                }
+
+                // apply transformation to previous symbol (last_sym_start to end) according to
+                // rewrite rules:
+                // S* --> S' ::= S S' |
+                // S+ --> S' ::= S S' | S
+                // S? --> S' ::= S |
+                uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
+                std::vector<whisper_grammar_element> sub_rule;
+                // add preceding symbol to generated rule
+                sub_rule.insert(
+                    sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
+                if (*pos == '*' || *pos == '+') {
+                    // cause generated rule to recurse
+                    sub_rule.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
+                }
+                // mark start of alternate def
+                sub_rule.push_back({WHISPER_GRETYPE_ALT, 0});
+                if (*pos == '+') {
+                    // add preceding symbol as alternate only for '+' (otherwise empty)
+                    sub_rule.insert(
+                        sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
+                }
+                sub_rule.push_back({WHISPER_GRETYPE_END, 0});
+                add_rule(state, sub_rule_id, sub_rule);
+
+                // in original rule, replace previous symbol with reference to generated rule
+                out_elements.resize(last_sym_start);
+                out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
+
+                pos = parse_space(pos + 1, is_nested);
+            } else {
+                break;
+            }
+        }
+        return pos;
+    }
+
+    const char * parse_alternates(
+            parse_state       & state,
+            const char        * src,
+            const std::string & rule_name,
+            uint32_t            rule_id,
+            bool                is_nested) {
+        std::vector<whisper_grammar_element> rule;
+        const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
+        while (*pos == '|') {
+            rule.push_back({WHISPER_GRETYPE_ALT, 0});
+            pos = parse_space(pos + 1, true);
+            pos = parse_sequence(state, pos, rule_name, rule, is_nested);
+        }
+        rule.push_back({WHISPER_GRETYPE_END, 0});
+        add_rule(state, rule_id, rule);
+        return pos;
+    }
+
+    const char * parse_rule(parse_state & state, const char * src) {
+        const char * name_end = parse_name(src);
+        const char * pos      = parse_space(name_end, false);
+        size_t       name_len = name_end - src;
+        uint32_t     rule_id  = get_symbol_id(state, src, name_len);
+        const std::string name(src, name_len);
+
+        if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
+            throw std::runtime_error(std::string("expecting ::= at ") + pos);
+        }
+        pos = parse_space(pos + 3, true);
+
+        pos = parse_alternates(state, pos, name, rule_id, false);
+
+        if (*pos == '\r') {
+            pos += pos[1] == '\n' ? 2 : 1;
+        } else if (*pos == '\n') {
+            pos++;
+        } else if (*pos) {
+            throw std::runtime_error(std::string("expecting newline or end at ") + pos);
+        }
+        return parse_space(pos, true);
+    }
+
+    parse_state parse(const char * src) {
+        try {
+            parse_state state;
+            const char * pos = parse_space(src, true);
+            while (*pos) {
+                pos = parse_rule(state, pos);
+            }
+            return state;
+        } catch (const std::exception & err) {
+            fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
+            return parse_state();
+        }
+    }
+
+    void print_grammar_char(FILE * file, uint32_t c) {
+        if (0x20 <= c && c <= 0x7f) {
+            fprintf(file, "%c", static_cast<char>(c));
+        } else {
+            // cop out of encoding UTF-8
+            fprintf(file, "<U+%04X>", c);
+        }
+    }
+
+    bool is_char_element(whisper_grammar_element elem) {
+        switch (elem.type) {
+            case WHISPER_GRETYPE_CHAR:           return true;
+            case WHISPER_GRETYPE_CHAR_NOT:       return true;
+            case WHISPER_GRETYPE_CHAR_ALT:       return true;
+            case WHISPER_GRETYPE_CHAR_RNG_UPPER: return true;
+            default:                           return false;
+        }
+    }
+
+    void print_rule_binary(FILE * file, const std::vector<whisper_grammar_element> & rule) {
+        for (auto elem : rule) {
+            switch (elem.type) {
+                case WHISPER_GRETYPE_END:            fprintf(file, "END");            break;
+                case WHISPER_GRETYPE_ALT:            fprintf(file, "ALT");            break;
+                case WHISPER_GRETYPE_RULE_REF:       fprintf(file, "RULE_REF");       break;
+                case WHISPER_GRETYPE_CHAR:           fprintf(file, "CHAR");           break;
+                case WHISPER_GRETYPE_CHAR_NOT:       fprintf(file, "CHAR_NOT");       break;
+                case WHISPER_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
+                case WHISPER_GRETYPE_CHAR_ALT:       fprintf(file, "CHAR_ALT");       break;
+            }
+            switch (elem.type) {
+                case WHISPER_GRETYPE_END:
+                case WHISPER_GRETYPE_ALT:
+                case WHISPER_GRETYPE_RULE_REF:
+                    fprintf(file, "(%u) ", elem.value);
+                    break;
+                case WHISPER_GRETYPE_CHAR:
+                case WHISPER_GRETYPE_CHAR_NOT:
+                case WHISPER_GRETYPE_CHAR_RNG_UPPER:
+                case WHISPER_GRETYPE_CHAR_ALT:
+                    fprintf(file, "(\"");
+                    print_grammar_char(file, elem.value);
+                    fprintf(file, "\") ");
+                    break;
+            }
+        }
+        fprintf(file, "\n");
+    }
+
+    void print_rule(
+            FILE     * file,
+            uint32_t   rule_id,
+            const std::vector<whisper_grammar_element> & rule,
+            const std::map<uint32_t, std::string>    & symbol_id_names) {
+        if (rule.empty() || rule.back().type != WHISPER_GRETYPE_END) {
+            throw std::runtime_error(
+                "malformed rule, does not end with WHISPER_GRETYPE_END: " + std::to_string(rule_id));
+        }
+        fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
+        for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
+            whisper_grammar_element elem = rule[i];
+            switch (elem.type) {
+                case WHISPER_GRETYPE_END:
+                    throw std::runtime_error(
+                        "unexpected end of rule: " + std::to_string(rule_id) + "," +
+                        std::to_string(i));
+                case WHISPER_GRETYPE_ALT:
+                    fprintf(file, "| ");
+                    break;
+                case WHISPER_GRETYPE_RULE_REF:
+                    fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
+                    break;
+                case WHISPER_GRETYPE_CHAR:
+                    fprintf(file, "[");
+                    print_grammar_char(file, elem.value);
+                    break;
+                case WHISPER_GRETYPE_CHAR_NOT:
+                    fprintf(file, "[^");
+                    print_grammar_char(file, elem.value);
+                    break;
+                case WHISPER_GRETYPE_CHAR_RNG_UPPER:
+                    if (i == 0 || !is_char_element(rule[i - 1])) {
+                        throw std::runtime_error(
+                            "WHISPER_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
+                            std::to_string(rule_id) + "," + std::to_string(i));
+                    }
+                    fprintf(file, "-");
+                    print_grammar_char(file, elem.value);
+                    break;
+                case WHISPER_GRETYPE_CHAR_ALT:
+                    if (i == 0 || !is_char_element(rule[i - 1])) {
+                        throw std::runtime_error(
+                            "WHISPER_GRETYPE_CHAR_ALT without preceding char: " +
+                            std::to_string(rule_id) + "," + std::to_string(i));
+                    }
+                    print_grammar_char(file, elem.value);
+                    break;
+            }
+            if (is_char_element(elem)) {
+                switch (rule[i + 1].type) {
+                    case WHISPER_GRETYPE_CHAR_ALT:
+                    case WHISPER_GRETYPE_CHAR_RNG_UPPER:
+                        break;
+                    default:
+                        fprintf(file, "] ");
+                }
+            }
+        }
+        fprintf(file, "\n");
+    }
+
+    void print_grammar(FILE * file, const parse_state & state) {
+        try {
+            std::map<uint32_t, std::string> symbol_id_names;
+            for (auto kv : state.symbol_ids) {
+                symbol_id_names[kv.second] = kv.first;
+            }
+            for (size_t i = 0, end = state.rules.size(); i < end; i++) {
+                // fprintf(file, "%zu: ", i);
+                // print_rule_binary(file, state.rules[i]);
+                print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
+                // fprintf(file, "\n");
+            }
+        } catch (const std::exception & err) {
+            fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
+        }
+    }
+
+    std::vector<const whisper_grammar_element *> parse_state::c_rules() const{
+        std::vector<const whisper_grammar_element *> ret;
+        for (const auto & rule : rules) {
+            ret.push_back(rule.data());
+        }
+        return ret;
+    }
+}
diff --git a/examples/grammar-parser.h b/examples/grammar-parser.h
new file mode 100644 (file)
index 0000000..47d019c
--- /dev/null
@@ -0,0 +1,29 @@
+// Implements a parser for an extended Backus-Naur form (BNF), producing the
+// binary context-free grammar format specified by whisper.h. Supports character
+// ranges, grouping, and repetition operators. As an example, a grammar for
+// arithmetic might look like:
+//
+// root  ::= expr
+// expr  ::= term ([-+*/] term)*
+// term  ::= num | "(" space expr ")" space
+// num   ::= [0-9]+ space
+// space ::= [ \t\n]*
+
+#pragma once
+#include "whisper.h"
+#include <vector>
+#include <map>
+#include <cstdint>
+#include <string>
+
+namespace grammar_parser {
+    struct parse_state {
+        std::map<std::string, uint32_t>                   symbol_ids;
+        std::vector<std::vector<whisper_grammar_element>> rules;
+
+        std::vector<const whisper_grammar_element *>      c_rules() const;
+    };
+
+    parse_state parse(const char * src);
+    void print_grammar(FILE * file, const parse_state & state);
+}
diff --git a/grammars/assistant.gbnf b/grammars/assistant.gbnf
new file mode 100644 (file)
index 0000000..c445778
--- /dev/null
@@ -0,0 +1,57 @@
+# - "turn on lights."
+# - "set thermostat to 22."
+# - "increase TV by 10."
+# - "decrease oven by 50."
+# - "play music."
+# - "stop podcast."
+# - "schedule cleaning at 3pm."
+# - "cancel cleaning."
+# - "remind me to buy milk at 5pm."
+# - "show me security system."
+# - "hide washing machine."
+# - "what is the lights status?"
+# - "what is the current thermostat value?"
+# - "what is the security system status?"
+# - "what is the door lock status?"
+# - "what is the camera battery level?"
+# - "what is the weather like today?"
+# - "what is the forecast for tomorrow?"
+# - "what is the time?"
+# - "what is my schedule for today?"
+# - "what tasks do I have?"
+# - "what reminders do I have?"
+#
+# example:
+#
+#   ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/assistant.gbnf --prompt "Ok Whisper, start listening for commands." --context "Whisper is a home assistant. It recognizes voice commands. Time is 11pm." --grammar-penalty 10
+#
+
+root   ::= init " " (command | question) "."
+prompt ::= init
+
+# leading space is very important!
+init ::= " Ok Whisper, start listening for commands."
+
+command ::= "Turn " ("on" | "off") " " device | "Set " device " to " value |
+            "Increase " device " by " value | "Decrease " device " by " value |
+            "Play " media | "Stop " media | "Schedule " task " at " time | "Cancel " task |
+            "Remind me to " task " at " time | "Show me " device | "Hide " device
+
+question ::= "What is the " device " status?" | "What is the current " device " value?" |
+             "What is the " device " temperature?" | "What is the " device " humidity?" |
+             "What is the " device " power consumption?" | "What is the " device " battery level?" |
+             "What is the weather like today?" | "What is the forecast for tomorrow?" |
+             "What is the time?" | "What is my schedule for today?" | "What tasks do I have?" |
+             "What reminders do I have?"
+
+device ::= "lights" | "thermostat" | "security system" | "door lock" | "camera" | "speaker" | "TV" |
+           "music player" | "coffee machine" | "oven" | "refrigerator" | "washing machine" |
+           "vacuum cleaner"
+
+value ::= [0-9]+
+
+media ::= "music" | "radio" | "podcast" | "audiobook" | "TV show" | "movie"
+
+task ::= [a-zA-Z]+ (" " [a-zA-Z]+)?
+
+time ::= [0-9] [0-9]? ("am" | "pm")?
diff --git a/grammars/chess.gbnf b/grammars/chess.gbnf
new file mode 100644 (file)
index 0000000..ec8c842
--- /dev/null
@@ -0,0 +1,29 @@
+# - bishop to c3
+# - rook to d4
+# - knight to e5
+# - d4 d5 knight to c3
+# - c3 queen to d4 king b1
+# - pawn to a1 bishop to b2 knight to c3
+#
+# The prompt (--prompt) is the initial phrase that the user has to say.
+# This is used to prime Whisper with how the user is expected to speak.
+#
+# Provide long context (--context) with sample moves to help Whisper decode the correct sequence.
+# Longer context is better, but it slightly increases the processing time.
+#
+# example:
+#
+#   ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "rook to b4, f3," --context "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8," --grammar-penalty 100
+#
+
+root   ::= init move move? move? "."
+prompt ::= init "."
+
+# leading space is very important!
+init ::= " rook to b4, f3"
+
+move ::= ", " ((piece | pawn | king) " " "to "?)? [a-h] [1-8]
+
+piece ::= "bishop" | "rook" | "knight" | "queen"
+king  ::= "king"
+pawn  ::= "pawn"
diff --git a/grammars/colors.gbnf b/grammars/colors.gbnf
new file mode 100644 (file)
index 0000000..1d99450
--- /dev/null
@@ -0,0 +1,16 @@
+# - red
+# - green
+# - blue
+#
+# example:
+#
+#   ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/colors.gbnf --prompt "red, green, blue," --context "green, red, blue,"
+#
+
+root   ::= init color "."
+prompt ::= init "."
+
+# leading space is very important!
+init ::= " red, green, blue"
+
+color ::= ", " ("red" | "green" | "blue")
index 1c7d7e9440dcb82cdedf3172aec31ee01c74e46e..c0e91152703b1ead197e2f710dbe5eb1e0858ec7 100644 (file)
@@ -579,6 +579,25 @@ struct whisper_model {
     std::map<std::string, struct ggml_tensor *> tensors;
 };
 
+struct whisper_partial_utf8 {
+    uint32_t value;    // bit value so far (unshifted)
+    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
+};
+
+struct whisper_grammar {
+    /*const*/ std::vector<std::vector<whisper_grammar_element>>   rules;
+    std::vector<std::vector<const whisper_grammar_element *>> stacks;
+
+    // buffer for partially generated UTF-8 sequence from accepted tokens
+    whisper_partial_utf8                                      partial_utf8;
+};
+
+struct whisper_grammar_candidate {
+    whisper_token          id;
+    const uint32_t       * code_points;
+    whisper_partial_utf8   partial_utf8;
+};
+
 struct whisper_sequence {
     std::vector<whisper_token_data> tokens;
 
@@ -600,6 +619,9 @@ struct whisper_decoder {
     // the currently generated sequence of tokens
     whisper_sequence sequence;
 
+    // grammar parse state of generated sequence of tokens
+    whisper_grammar  grammar;
+
     int seek_delta; // the window shift found so far based on the decoded timestamp tokens
 
     bool failed;    // has the current segment failed to decode?
@@ -3685,6 +3707,425 @@ const char * whisper_print_system_info(void) {
     return s.c_str();
 }
 
+//////////////////////////////////
+// Grammar - ported from llama.cpp
+//////////////////////////////////
+
+// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
+// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
+std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
+        const char         * src,
+        whisper_partial_utf8   partial_start) {
+    static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
+    const char          * pos      = src;
+    std::vector<uint32_t> code_points;
+    uint32_t              value    = partial_start.value;
+    int                   n_remain = partial_start.n_remain;
+
+    // continue previous decode, if applicable
+    while (*pos != 0 && n_remain > 0) {
+        uint8_t next_byte = static_cast<uint8_t>(*pos);
+        if ((next_byte >> 6) != 2) {
+            // invalid sequence, abort
+            code_points.push_back(0);
+            return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
+        }
+        value = (value << 6) + (next_byte & 0x3F);
+        ++pos;
+        --n_remain;
+    }
+
+    if (partial_start.n_remain > 0 && n_remain == 0) {
+        code_points.push_back(value);
+    }
+
+    // decode any subsequent utf-8 sequences, which may end in an incomplete one
+    while (*pos != 0) {
+        uint8_t  first_byte = static_cast<uint8_t>(*pos);
+        uint8_t  highbits   = first_byte >> 4;
+                 n_remain   = lookup[highbits] - 1;
+
+        if (n_remain < 0) {
+            // invalid sequence, abort
+            code_points.clear();
+            code_points.push_back(0);
+            return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
+        }
+
+        uint8_t  mask       = (1 << (7 - n_remain)) - 1;
+                 value      = first_byte & mask;
+        ++pos;
+        while (*pos != 0 && n_remain > 0) {
+            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+            ++pos;
+            --n_remain;
+        }
+        if (n_remain == 0) {
+            code_points.push_back(value);
+        }
+    }
+    code_points.push_back(0);
+
+    return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
+}
+
+// returns true iff pos points to the end of one of the definitions of a rule
+static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
+    switch (pos->type) {
+        case WHISPER_GRETYPE_END: return true;  // NOLINT
+        case WHISPER_GRETYPE_ALT: return true;  // NOLINT
+        default:                return false;
+    }
+}
+
+// returns true iff chr satisfies the char range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
+        const whisper_grammar_element * pos,
+        const uint32_t                chr) {
+
+    bool found            = false;
+    bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
+
+    WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
+
+    do {
+        if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
+            // inclusive range, e.g. [a-z]
+            found = found || (pos->value <= chr && chr <= pos[1].value);
+            pos += 2;
+        } else {
+            // exact char match, e.g. [a] or "a"
+            found = found || pos->value == chr;
+            pos += 1;
+        }
+    } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+
+    return std::make_pair(found == is_positive_char, pos);
+}
+
+// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
+// range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static bool whisper_grammar_match_partial_char(
+        const whisper_grammar_element * pos,
+        const whisper_partial_utf8      partial_utf8) {
+
+    bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
+    WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT);
+
+    uint32_t partial_value = partial_utf8.value;
+    int      n_remain      = partial_utf8.n_remain;
+
+    // invalid sequence or 7-bit char split across 2 bytes (overlong)
+    if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
+        return false;
+    }
+
+    // range of possible code points this partial UTF-8 sequence could complete to
+    uint32_t low  = partial_value << (n_remain * 6);
+    uint32_t high = low | ((1 << (n_remain * 6)) - 1);
+
+    if (low == 0) {
+        if (n_remain == 2) {
+            low = 1 << 11;
+        } else if (n_remain == 3) {
+            low = 1 << 16;
+        }
+    }
+
+    do {
+        if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
+            // inclusive range, e.g. [a-z]
+            if (pos->value <= high && low <= pos[1].value) {
+                return is_positive_char;
+            }
+            pos += 2;
+        } else {
+            // exact char match, e.g. [a] or "a"
+            if (low <= pos->value && pos->value <= high) {
+                return is_positive_char;
+            }
+            pos += 1;
+        }
+    } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+
+    return !is_positive_char;
+}
+
+
+// transforms a grammar pushdown stack into N possible stacks, all ending
+// at a character range (terminal element)
+static void whisper_grammar_advance_stack(
+        const std::vector<std::vector<whisper_grammar_element>>   & rules,
+        const std::vector<const whisper_grammar_element *>        & stack,
+        std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
+
+    if (stack.empty()) {
+        new_stacks.push_back(stack);
+        return;
+    }
+
+    const whisper_grammar_element * pos = stack.back();
+
+    switch (pos->type) {
+        case WHISPER_GRETYPE_RULE_REF: {
+            const size_t                  rule_id = static_cast<size_t>(pos->value);
+            const whisper_grammar_element * subpos  = rules[rule_id].data();
+            do {
+                // init new stack without the top (pos)
+                std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
+                if (!whisper_grammar_is_end_of_sequence(pos + 1)) {
+                    // if this rule ref is followed by another element, add that to stack
+                    new_stack.push_back(pos + 1);
+                }
+                if (!whisper_grammar_is_end_of_sequence(subpos)) {
+                    // if alternate is nonempty, add to stack
+                    new_stack.push_back(subpos);
+                }
+                whisper_grammar_advance_stack(rules, new_stack, new_stacks);
+                while (!whisper_grammar_is_end_of_sequence(subpos)) {
+                    // scan to end of alternate def
+                    subpos++;
+                }
+                if (subpos->type == WHISPER_GRETYPE_ALT) {
+                    // there's another alternate def of this rule to process
+                    subpos++;
+                } else {
+                    break;
+                }
+            } while (true);
+            break;
+        }
+        case WHISPER_GRETYPE_CHAR:
+        case WHISPER_GRETYPE_CHAR_NOT:
+            new_stacks.push_back(stack);
+            break;
+        default:
+            // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range
+            // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
+            // those
+            WHISPER_ASSERT(false);
+    }
+}
+
+// takes a set of possible pushdown stacks on a grammar, which are required to
+// be positioned at a character range (see `whisper_grammar_advance_stack`), and
+// produces the N possible stacks if the given char is accepted at those
+// positions
+static std::vector<std::vector<const whisper_grammar_element *>> whisper_grammar_accept(
+        const std::vector<std::vector<whisper_grammar_element>>         & rules,
+        const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+        const uint32_t                                                  chr) {
+
+    std::vector<std::vector<const whisper_grammar_element *>> new_stacks;
+
+    for (const auto & stack : stacks) {
+        if (stack.empty()) {
+            continue;
+        }
+
+        auto match = whisper_grammar_match_char(stack.back(), chr);
+        if (match.first) {
+            const whisper_grammar_element * pos = match.second;
+
+            // update top of stack to next element, if any
+            std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
+            if (!whisper_grammar_is_end_of_sequence(pos)) {
+                new_stack.push_back(pos);
+            }
+            whisper_grammar_advance_stack(rules, new_stack, new_stacks);
+        }
+    }
+
+    return new_stacks;
+}
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
+        const std::vector<std::vector<whisper_grammar_element>>         & rules,
+        const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+        const std::vector<whisper_grammar_candidate>                    & candidates);
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_for_stack(
+        const std::vector<std::vector<whisper_grammar_element>> & rules,
+        const std::vector<const whisper_grammar_element *>      & stack,
+        const std::vector<whisper_grammar_candidate>            & candidates) {
+
+    std::vector<whisper_grammar_candidate> rejects;
+
+    if (stack.empty()) {
+        for (auto tok : candidates) {
+            if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
+                rejects.push_back(tok);
+            }
+        }
+        return rejects;
+    }
+
+    const whisper_grammar_element * stack_pos = stack.back();
+
+    std::vector<whisper_grammar_candidate> next_candidates;
+    for (auto tok : candidates) {
+        if (*tok.code_points == 0) {
+            // reached end of full codepoints in token, reject iff it ended in a partial sequence
+            // that cannot satisfy this position in grammar
+            if (tok.partial_utf8.n_remain != 0 &&
+                    !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
+                rejects.push_back(tok);
+            }
+        } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) {
+            next_candidates.push_back({ tok.id, tok.code_points + 1, tok.partial_utf8 });
+        } else {
+            rejects.push_back(tok);
+        }
+    }
+
+    const auto * stack_pos_after = whisper_grammar_match_char(stack_pos, 0).second;
+
+    // update top of stack to next element, if any
+    std::vector<const whisper_grammar_element *> stack_after(stack.begin(), stack.end() - 1);
+    if (!whisper_grammar_is_end_of_sequence(stack_pos_after)) {
+        stack_after.push_back(stack_pos_after);
+    }
+    std::vector<std::vector<const whisper_grammar_element *>> next_stacks;
+    whisper_grammar_advance_stack(rules, stack_after, next_stacks);
+
+    auto next_rejects = whisper_grammar_reject_candidates(rules, next_stacks, next_candidates);
+    for (auto tok : next_rejects) {
+        rejects.push_back({ tok.id, tok.code_points - 1, tok.partial_utf8 });
+    }
+
+    return rejects;
+}
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
+        const std::vector<std::vector<whisper_grammar_element>>         & rules,
+        const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+        const std::vector<whisper_grammar_candidate>                    & candidates) {
+    if (candidates.empty() || stacks.empty()) {
+        return std::vector<whisper_grammar_candidate>();
+    }
+
+    auto rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
+
+    for (size_t i = 1, size = stacks.size(); i < size; ++i) {
+        rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
+    }
+    return rejects;
+}
+
+static struct whisper_grammar whisper_grammar_init(
+            const whisper_grammar_element ** rules,
+                                 size_t      n_rules,
+                                 size_t      i_start_rule) {
+    const whisper_grammar_element * pos;
+
+    // copy rule definitions into vectors
+    std::vector<std::vector<whisper_grammar_element>> vec_rules(n_rules);
+    for (size_t i = 0; i < n_rules; i++) {
+        for (pos = rules[i]; pos->type != WHISPER_GRETYPE_END; pos++) {
+            vec_rules[i].push_back(*pos);
+        }
+        vec_rules[i].push_back({WHISPER_GRETYPE_END, 0});
+    }
+
+    // loop over alternates of start rule to build initial stacks
+    std::vector<std::vector<const whisper_grammar_element *>> stacks;
+    pos = rules[i_start_rule];
+    do {
+        std::vector<const whisper_grammar_element *> stack;
+        if (!whisper_grammar_is_end_of_sequence(pos)) {
+            // if alternate is nonempty, add to stack
+            stack.push_back(pos);
+        }
+        whisper_grammar_advance_stack(vec_rules, stack, stacks);
+        while (!whisper_grammar_is_end_of_sequence(pos)) {
+            // scan to end of alternate def
+            pos++;
+        }
+        if (pos->type == WHISPER_GRETYPE_ALT) {
+            // there's another alternate def of this rule to process
+            pos++;
+        } else {
+            break;
+        }
+    } while (true);
+
+    return { std::move(vec_rules), std::move(stacks), {} };
+}
+
+static void whisper_suppress_invalid_grammar(
+             whisper_context  & ctx,
+    const whisper_full_params & params,
+           std::vector<float> & logits,
+    const     whisper_grammar & grammar) {
+
+    if (grammar.rules.empty() || grammar.stacks.empty()) {
+        return;
+    }
+
+    //bool allow_eot = false;
+    //for (const auto & stack : grammar.stacks) {
+    //    if (stack.empty()) {
+    //        allow_eot = true;
+    //        break;
+    //    }
+    //}
+
+    const whisper_token eot = whisper_token_eot(&ctx);
+
+    std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
+    std::vector<whisper_grammar_candidate>                              candidates_grammar;
+
+    for (whisper_token id = 0; id < eot; ++id) {
+        const std::string & text = ctx.vocab.id_to_token[id];
+        if (!text.empty()) {
+            candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8));
+            candidates_grammar.push_back({ id, candidates_decoded.back().first.data(), candidates_decoded.back().second });
+        }
+    }
+
+    const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
+
+    for (const auto & reject : rejects) {
+        logits[reject.id] -= params.grammar_penalty;
+    }
+
+    // when the grammar allows a continuation, we penalize the end-of-text token
+    //if (!allow_eot) {
+    //    logits[eot] -= params.grammar_penalty;
+    //}
+    //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
+}
+
+static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) {
+    if (grammar.rules.empty() || grammar.stacks.empty()) {
+        return;
+    }
+
+    //fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str());
+
+    const std::string & text = ctx.vocab.id_to_token[token];
+
+    if (text.rfind("[_", 0) == 0) {
+        // fprintf(stderr, " (skipped)\n");
+        return;
+    }
+    // fprintf(stderr, "\n");
+
+    // Note terminating 0 in decoded string
+    const auto   decoded     = decode_utf8(text.c_str(), grammar.partial_utf8);
+    const auto & code_points = decoded.first;
+    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+        grammar.stacks = whisper_grammar_accept(grammar.rules, grammar.stacks, *it);
+    }
+    grammar.partial_utf8 = decoded.second;
+}
+
+//////////////
+// END grammar
+//////////////
+
 ////////////////////////////////////////////////////////////////////////////
 
 struct whisper_context_params * whisper_context_default_params_by_ref() {
@@ -3714,6 +4155,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 
         /*.translate         =*/ false,
         /*.no_context        =*/ true,
+        /*.no_timestamps     =*/ false,
         /*.single_segment    =*/ false,
         /*.print_special     =*/ false,
         /*.print_progress    =*/ true,
@@ -3776,6 +4218,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 
         /*.logits_filter_callback           =*/ nullptr,
         /*.logits_filter_callback_user_data =*/ nullptr,
+
+        /*.grammar_rules   =*/ nullptr,
+        /*.n_grammar_rules =*/ 0,
+        /*.i_start_rule    =*/ 0,
+        /*.grammar_penalty =*/ 100.0f,
     };
 
     switch (strategy) {
@@ -3927,6 +4374,11 @@ static void whisper_process_logits(
         // suppress <|notimestamps|> token
         // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
         logits[vocab.token_not] = -INFINITY;
+        if (params.no_timestamps) {
+            for (int i = vocab.token_beg; i < n_logits; ++i) {
+                logits[i] = -INFINITY;
+            }
+        }
 
         // suppress sot and nosp tokens
         logits[vocab.token_sot]  = -INFINITY;
@@ -3942,6 +4394,14 @@ static void whisper_process_logits(
         logits[vocab.token_transcribe] = -INFINITY;
         logits[vocab.token_prev]       = -INFINITY;
 
+        // suppress lang tokens
+        for (size_t i = 0; i < g_lang.size(); ++i) {
+            logits[whisper_token_lang(&ctx, i)] = -INFINITY;
+        }
+
+        // suppress prev token
+        logits[vocab.token_prev] = -INFINITY;
+
         if (params.logits_filter_callback) {
             params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
         }
@@ -4052,10 +4512,33 @@ static void whisper_process_logits(
             //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
 
             if (timestamp_logprob > max_text_token_logprob) {
+                //printf("sampling timestamp\n");
                 for (int i = 0; i < vocab.token_beg; ++i) {
                     logits[i]   = -INFINITY;
                     logprobs[i] = -INFINITY;
                 }
+            } else if (params.n_grammar_rules > 0) {
+                whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
+
+                // populate the logprobs array (log_softmax)
+                {
+                    const float logit_max = *std::max_element(logits.begin(), logits.end());
+                    float logsumexp = 0.0f;
+                    for (int i = 0; i < n_logits; ++i) {
+                        if (logits[i] > -INFINITY) {
+                            logsumexp += expf(logits[i] - logit_max);
+                        }
+                    }
+                    logsumexp = logf(logsumexp) + logit_max;
+
+                    for (int i = 0; i < n_logits; ++i) {
+                        if (logits[i] > -INFINITY) {
+                            logprobs[i] = logits[i] - logsumexp;
+                        } else {
+                            logprobs[i] = -INFINITY;
+                        }
+                    }
+                }
             }
         }
     }
@@ -4073,32 +4556,55 @@ static void whisper_process_logits(
 
 #if 0
     // print first 100 logits - token string : logit
-    for (int i = 0; i < 100; i++) {
-        const auto token   = vocab.id_to_token.at(i);
-        const auto prob    = probs[i];
-        const auto logit   = logits[i];
-        const auto logprob = logprobs[i];
-        printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
+    //for (int i = 0; i < 10; i++) {
+    //    const auto token   = vocab.id_to_token.at(i);
+    //    const auto prob    = probs[i];
+    //    const auto logit   = logits[i];
+    //    const auto logprob = logprobs[i];
+    //    printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
+    //}
+
+    // print sorted
+    {
+        std::vector<std::pair<float, int>> pairs;
+
+        for (int i = 0; i < n_logits; ++i) {
+            pairs.push_back(std::make_pair(probs[i], i));
+        }
+
+        std::sort(pairs.begin(), pairs.end(), [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
+            return a.first > b.first;
+        });
+
+        for (int i = 0; i < 10; i++) {
+            const auto token   = vocab.id_to_token.at(pairs[i].second);
+            const auto prob    = pairs[i].first;
+            const auto logit   = logits[pairs[i].second];
+            const auto logprob = logprobs[pairs[i].second];
+            printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str());
+        }
+
+        printf("----------------\n");
     }
 
     // "And", "and", " And", " and"
-    printf("logits[\"and\"]  = %f\n", logits[vocab.token_to_id.at("and")]);
-    printf("logits[\"And\"]  = %f\n", logits[vocab.token_to_id.at("And")]);
-    printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
-    printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
-    printf("logits[\" so\"]  = %f\n", logits[vocab.token_to_id.at(" so")]);
-
-    printf("logprobs[\"and\"]  = %f\n", logprobs[vocab.token_to_id.at("and")]);
-    printf("logprobs[\"And\"]  = %f\n", logprobs[vocab.token_to_id.at("And")]);
-    printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
-    printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
-    printf("logprobs[\" so\"]  = %f\n", logprobs[vocab.token_to_id.at(" so")]);
-
-    printf("probs[\"and\"]  = %f\n", probs[vocab.token_to_id.at("and")]);
-    printf("probs[\"And\"]  = %f\n", probs[vocab.token_to_id.at("And")]);
-    printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
-    printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
-    printf("probs[\" so\"]  = %f\n", probs[vocab.token_to_id.at(" so")]);
+    //printf("logits[\"and\"]  = %f\n", logits[vocab.token_to_id.at("and")]);
+    //printf("logits[\"And\"]  = %f\n", logits[vocab.token_to_id.at("And")]);
+    //printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
+    //printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
+    //printf("logits[\" so\"]  = %f\n", logits[vocab.token_to_id.at(" so")]);
+
+    //printf("logprobs[\"and\"]  = %f\n", logprobs[vocab.token_to_id.at("and")]);
+    //printf("logprobs[\"And\"]  = %f\n", logprobs[vocab.token_to_id.at("And")]);
+    //printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
+    //printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
+    //printf("logprobs[\" so\"]  = %f\n", logprobs[vocab.token_to_id.at(" so")]);
+
+    //printf("probs[\"and\"]  = %f\n", probs[vocab.token_to_id.at("and")]);
+    //printf("probs[\"And\"]  = %f\n", probs[vocab.token_to_id.at("And")]);
+    //printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
+    //printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
+    //printf("probs[\" so\"]  = %f\n", probs[vocab.token_to_id.at(" so")]);
 #endif
 }
 
@@ -4223,8 +4729,11 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
         ptsum = sum_ts;
     }
 
+    std::discrete_distribution<> dist(probs.begin(), probs.end());
+
     for (int i = 0; i < k; ++i) {
-        const auto id = logits_id[i].second;
+        const auto id = dist(state.rng);
+        //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
 
         result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
 
@@ -4553,7 +5062,7 @@ int whisper_full_with_state(
     state->exp_n_audio_ctx = params.audio_ctx;
 
     // these tokens determine the task that will be performed
-    std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
+    std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx), };
 
     if (whisper_is_multilingual(ctx)) {
         const int lang_id = whisper_lang_id(params.language);
@@ -4566,17 +5075,19 @@ int whisper_full_with_state(
         }
     }
 
+    // distilled models require the "no_timestamps" token
     {
         const bool is_distil = ctx->model.hparams.n_text_layer == 2;
-
-        // distilled models require the "no_timestamps" token
-        // TODO: add input parameter (#1229)
-        if (is_distil) {
+        if (is_distil && !params.no_timestamps) {
             WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__);
-            prompt_init.push_back(whisper_token_not(ctx));
+            params.no_timestamps = true;
         }
     }
 
+    if (params.no_timestamps) {
+        prompt_init.push_back(whisper_token_not(ctx));
+    }
+
     int seek = seek_start;
 
     std::vector<whisper_token> prompt;
@@ -4652,7 +5163,7 @@ int whisper_full_with_state(
 
             n_decoders_cur = std::max(1, n_decoders_cur);
 
-            WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
+            WHISPER_PRINT_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur);
 
             // TAGS: WHISPER_DECODER_INIT
             for (int j = 0; j < n_decoders_cur; ++j) {
@@ -4673,6 +5184,13 @@ int whisper_full_with_state(
                 decoder.failed    = false;
                 decoder.completed = false;
                 decoder.has_ts    = false;
+
+                if (params.grammar_rules != nullptr) {
+                    decoder.grammar = whisper_grammar_init(
+                        params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
+                } else {
+                    decoder.grammar = {};
+                }
             }
 
             // init prompt and kv cache for the current iteration
@@ -4790,6 +5308,10 @@ int whisper_full_with_state(
                             continue;
                         }
 
+                        if (cur_c >= beam_candidates.size()) {
+                            cur_c = 0;
+                        }
+
                         auto & cur = beam_candidates[cur_c++];
 
                         while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
@@ -4844,6 +5366,8 @@ int whisper_full_with_state(
                             has_ts = true;
                         }
 
+                        whisper_grammar_accept_token(*ctx, decoder.grammar, token.id);
+
 #ifdef WHISPER_DEBUG
                         {
                             const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
index 0ea5237e5f2f7c0d51a2432439f8d425e23b78e6..50f84a822d083e8ac6965b66093f2b5af1146538 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -109,6 +109,37 @@ extern "C" {
         void  (*close)(void * ctx);
     } whisper_model_loader;
 
+    // grammar element type
+    enum whisper_gretype {
+        // end of rule definition
+        WHISPER_GRETYPE_END            = 0,
+
+        // start of alternate definition for rule
+        WHISPER_GRETYPE_ALT            = 1,
+
+        // non-terminal element: reference to rule
+        WHISPER_GRETYPE_RULE_REF       = 2,
+
+        // terminal element: character (code point)
+        WHISPER_GRETYPE_CHAR           = 3,
+
+        // inverse char(s) ([^a], [^a-b] [^abc])
+        WHISPER_GRETYPE_CHAR_NOT       = 4,
+
+        // modifies a preceding WHISPER_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
+        // be an inclusive range ([a-z])
+        WHISPER_GRETYPE_CHAR_RNG_UPPER = 5,
+
+        // modifies a preceding WHISPER_GRETYPE_CHAR or
+        // WHISPER_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
+        WHISPER_GRETYPE_CHAR_ALT       = 6,
+    };
+
+    typedef struct whisper_grammar_element {
+        enum whisper_gretype type;
+        uint32_t             value; // Unicode code point or rule ID
+    } whisper_grammar_element;
+
     // Various functions for loading a ggml whisper model.
     // Allocate (almost) all memory needed for the model.
     // Return NULL on failure
@@ -402,6 +433,7 @@ extern "C" {
 
         bool translate;
         bool no_context;        // do not use past transcription (if any) as initial prompt for the decoder
+        bool no_timestamps;     // do not generate timestamps
         bool single_segment;    // force single segment output (useful for streaming)
         bool print_special;     // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
         bool print_progress;    // print progress information
@@ -479,6 +511,11 @@ extern "C" {
         // called by each decoder to filter obtained logits
         whisper_logits_filter_callback logits_filter_callback;
         void * logits_filter_callback_user_data;
+
+        const whisper_grammar_element ** grammar_rules;
+        size_t                           n_grammar_rules;
+        size_t                           i_start_rule;
+        float                            grammar_penalty;
     };
 
     // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()