stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
-command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
- $(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
+command: examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+ $(CXX) $(CXXFLAGS) examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
common.cpp
common-ggml.h
common-ggml.cpp
+ grammar-parser.cpp
)
include(DefaultTargetOptions)
#include "common-sdl.h"
#include "common.h"
#include "whisper.h"
+#include "grammar-parser.h"
#include <sstream>
#include <cassert>
#include <vector>
#include <map>
+bool file_exists(const std::string & fname) {
+ std::ifstream f(fname.c_str());
+ return f.good();
+}
+
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
- float vad_thold = 0.6f;
- float freq_thold = 100.0f;
+ float vad_thold = 0.6f;
+ float freq_thold = 100.0f;
+
+ float grammar_penalty = 100.0f;
+
+ grammar_parser::parse_state grammar_parsed;
bool speed_up = false;
bool translate = false;
std::string fname_out;
std::string commands;
std::string prompt;
+ std::string context;
+ std::string grammar;
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
+ else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
+ else if ( arg == "--grammar") { params.grammar = argv[++i]; }
+ else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
+ fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
+ fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
+ fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
fprintf(stderr, "\n");
}
-std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
+std::string transcribe(
+ whisper_context * ctx,
+ const whisper_params & params,
+ const std::vector<float> & pcmf32,
+ const std::string & grammar_rule,
+ float & logprob_min,
+ float & logprob_sum,
+ int & n_tokens,
+ int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();
- prob = 0.0f;
+ logprob_min = 0.0f;
+ logprob_sum = 0.0f;
+ n_tokens = 0;
t_ms = 0;
- whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+ //whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+ whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
+ wparams.no_timestamps = params.no_timestamps;
wparams.single_segment = true;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
- wparams.audio_ctx = params.audio_ctx;
- wparams.speed_up = params.speed_up;
+ wparams.audio_ctx = params.audio_ctx;
+ wparams.speed_up = params.speed_up;
+
+ wparams.temperature = 0.4f;
+ wparams.temperature_inc = 1.0f;
+ wparams.greedy.best_of = 5;
+
+ wparams.beam_search.beam_size = 5;
+
+ wparams.initial_prompt = params.context.data();
+
+ const auto & grammar_parsed = params.grammar_parsed;
+ auto grammar_rules = grammar_parsed.c_rules();
+
+ if (!params.grammar_parsed.rules.empty() && !grammar_rule.empty()) {
+ if (grammar_parsed.symbol_ids.find(grammar_rule) == grammar_parsed.symbol_ids.end()) {
+ fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, grammar_rule.c_str());
+ } else {
+ wparams.grammar_rules = grammar_rules.data();
+ wparams.n_grammar_rules = grammar_rules.size();
+ wparams.i_start_rule = grammar_parsed.symbol_ids.at(grammar_rule);
+ wparams.grammar_penalty = params.grammar_penalty;
+ }
+ }
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return "";
}
- int prob_n = 0;
std::string result;
const int n_segments = whisper_full_n_segments(ctx);
result += text;
- const int n_tokens = whisper_full_n_tokens(ctx, i);
- for (int j = 0; j < n_tokens; ++j) {
+ const int n = whisper_full_n_tokens(ctx, i);
+ for (int j = 0; j < n; ++j) {
const auto token = whisper_full_get_token_data(ctx, i, j);
- prob += token.p;
- ++prob_n;
+ if(token.plog > 0.0f) exit(0);
+ logprob_min = std::min(logprob_min, token.plog);
+ logprob_sum += token.plog;
+ ++n_tokens;
}
}
- if (prob_n > 0) {
- prob /= prob_n;
- }
-
const auto t_end = std::chrono::high_resolution_clock::now();
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
fprintf(stderr, " ]\n");
}
- std::string k_prompt = "select one from the available words: ";
+ std::string k_prompt = "select one from the available words: ";
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
if (i > 0) {
k_prompt += ", ";
bool is_running = true;
bool ask_prompt = true;
- float prob = 0.0f;
+ float logprob_min = 0.0f;
+ float logprob_sum = 0.0f;
+ int n_tokens = 0;
std::vector<float> pcmf32_cur;
// detect the commands
audio.get(params.command_ms, pcmf32_cur);
- const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
+ const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", logprob_min, logprob_sum, n_tokens, t_ms));
const auto words = get_words(txt);
// general-purpose mode
// freely transcribe the voice into text
-int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
+int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
bool is_running = true;
bool have_prompt = false;
bool ask_prompt = true;
- float prob0 = 0.0f;
- float prob = 0.0f;
+ float logprob_min0 = 0.0f;
+ float logprob_min = 0.0f;
+
+ float logprob_sum0 = 0.0f;
+ float logprob_sum = 0.0f;
+
+ int n_tokens0 = 0;
+ int n_tokens = 0;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
- const std::string k_prompt = "Ok Whisper, start listening for commands.";
+ std::string k_prompt = "Ok Whisper, start listening for commands.";
+ if (!params.prompt.empty()) {
+ k_prompt = params.prompt;
+ }
fprintf(stderr, "\n");
fprintf(stderr, "%s: general-purpose mode\n", __func__);
// wait for activation phrase
audio.get(params.prompt_ms, pcmf32_cur);
- const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
+ const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "prompt", logprob_min0, logprob_sum0, n_tokens0, t_ms));
- fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
+ const float p = 100.0f * std::exp(logprob_min0);
+
+ fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms, p = %.2f%%)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms, p);
const float sim = similarity(txt, k_prompt);
// we have heard the activation phrase, now detect the commands
audio.get(params.command_ms, pcmf32_cur);
+ //printf("len prompt: %.4f\n", pcmf32_prompt.size() / (float) WHISPER_SAMPLE_RATE);
+ //printf("len command: %.4f\n", pcmf32_cur.size() / (float) WHISPER_SAMPLE_RATE);
+
+ // prepend 3 second of silence
+ pcmf32_cur.insert(pcmf32_cur.begin(), 3.0f*WHISPER_SAMPLE_RATE, 0.0f);
+
// prepend the prompt audio
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
- const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
+ const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", logprob_min, logprob_sum, n_tokens, t_ms));
- prob = 100.0f*(prob - prob0);
+ //const float p = 100.0f * std::exp((logprob - logprob0) / (n_tokens - n_tokens0));
+ const float p = 100.0f * std::exp(logprob_min);
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
// find the prompt in the text
float best_sim = 0.0f;
size_t best_len = 0;
- for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
+ for (size_t n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
+ if (n >= txt.size()) {
+ break;
+ }
+
const auto prompt = txt.substr(0, n);
const float sim = similarity(prompt, k_prompt);
}
}
- const std::string command = ::trim(txt.substr(best_len));
+ fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
+ if (best_len == 0) {
+ fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__);
+ } else {
+ // cut the prompt from the decoded text
+ const std::string command = ::trim(txt.substr(best_len));
+
+ fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
+ }
- fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
fprintf(stdout, "\n");
}
int ret_val = 0;
- if (!params.commands.empty()) {
- ret_val = process_command_list(ctx, audio, params);
- } else if (!params.prompt.empty()) {
- ret_val = always_prompt_transcription(ctx, audio, params);
- } else {
- ret_val = process_general_transcription(ctx, audio, params);
+ if (!params.grammar.empty()) {
+ auto & grammar = params.grammar_parsed;
+ if (file_exists(params.grammar.c_str())) {
+ // read grammar from file
+ std::ifstream ifs(params.grammar.c_str());
+ const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
+ grammar = grammar_parser::parse(txt.c_str());
+ } else {
+ // read grammar from string
+ grammar = grammar_parser::parse(params.grammar.c_str());
+ }
+
+ // will be empty (default) if there are parse errors
+ if (grammar.rules.empty()) {
+ ret_val = 1;
+ } else {
+ fprintf(stderr, "%s: grammar:\n", __func__);
+ grammar_parser::print_grammar(stderr, grammar);
+ fprintf(stderr, "\n");
+ }
+ }
+
+ if (ret_val == 0) {
+ if (!params.commands.empty()) {
+ ret_val = process_command_list(ctx, audio, params);
+ } else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) {
+ ret_val = always_prompt_transcription(ctx, audio, params);
+ } else {
+ ret_val = process_general_transcription(ctx, audio, params);
+ }
}
audio.pause();
--- /dev/null
+#include "grammar-parser.h"
+#include <cstdint>
+#include <cwchar>
+#include <string>
+#include <utility>
+#include <stdexcept>
+#include <exception>
+
+namespace grammar_parser {
+ // NOTE: assumes valid utf8 (but checks for overrun)
+ // copied from whisper.cpp
+ std::pair<uint32_t, const char *> decode_utf8(const char * src) {
+ static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+ uint8_t first_byte = static_cast<uint8_t>(*src);
+ uint8_t highbits = first_byte >> 4;
+ int len = lookup[highbits];
+ uint8_t mask = (1 << (8 - len)) - 1;
+ uint32_t value = first_byte & mask;
+ const char * end = src + len; // may overrun!
+ const char * pos = src + 1;
+ for ( ; pos < end && *pos; pos++) {
+ value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+ }
+ return std::make_pair(value, pos);
+ }
+
+ uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
+ uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
+ auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id));
+ return result.first->second;
+ }
+
+ uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
+ uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
+ state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
+ return next_id;
+ }
+
+ void add_rule(
+ parse_state & state,
+ uint32_t rule_id,
+ const std::vector<whisper_grammar_element> & rule) {
+ if (state.rules.size() <= rule_id) {
+ state.rules.resize(rule_id + 1);
+ }
+ state.rules[rule_id] = rule;
+ }
+
+ bool is_word_char(char c) {
+ return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
+ }
+
+ std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
+ const char * pos = src;
+ const char * end = src + size;
+ uint32_t value = 0;
+ for ( ; pos < end && *pos; pos++) {
+ value <<= 4;
+ char c = *pos;
+ if ('a' <= c && c <= 'f') {
+ value += c - 'a' + 10;
+ } else if ('A' <= c && c <= 'F') {
+ value += c - 'A' + 10;
+ } else if ('0' <= c && c <= '9') {
+ value += c - '0';
+ } else {
+ break;
+ }
+ }
+ if (pos != end) {
+ throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
+ }
+ return std::make_pair(value, pos);
+ }
+
+ const char * parse_space(const char * src, bool newline_ok) {
+ const char * pos = src;
+ while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
+ (newline_ok && (*pos == '\r' || *pos == '\n'))) {
+ if (*pos == '#') {
+ while (*pos && *pos != '\r' && *pos != '\n') {
+ pos++;
+ }
+ } else {
+ pos++;
+ }
+ }
+ return pos;
+ }
+
+ const char * parse_name(const char * src) {
+ const char * pos = src;
+ while (is_word_char(*pos)) {
+ pos++;
+ }
+ if (pos == src) {
+ throw std::runtime_error(std::string("expecting name at ") + src);
+ }
+ return pos;
+ }
+
+ std::pair<uint32_t, const char *> parse_char(const char * src) {
+ if (*src == '\\') {
+ switch (src[1]) {
+ case 'x': return parse_hex(src + 2, 2);
+ case 'u': return parse_hex(src + 2, 4);
+ case 'U': return parse_hex(src + 2, 8);
+ case 't': return std::make_pair('\t', src + 2);
+ case 'r': return std::make_pair('\r', src + 2);
+ case 'n': return std::make_pair('\n', src + 2);
+ case '\\':
+ case '"':
+ case '[':
+ case ']':
+ return std::make_pair(src[1], src + 2);
+ default:
+ throw std::runtime_error(std::string("unknown escape at ") + src);
+ }
+ } else if (*src) {
+ return decode_utf8(src);
+ }
+ throw std::runtime_error("unexpected end of input");
+ }
+
+ const char * parse_alternates(
+ parse_state & state,
+ const char * src,
+ const std::string & rule_name,
+ uint32_t rule_id,
+ bool is_nested);
+
+ const char * parse_sequence(
+ parse_state & state,
+ const char * src,
+ const std::string & rule_name,
+ std::vector<whisper_grammar_element> & out_elements,
+ bool is_nested) {
+ size_t last_sym_start = out_elements.size();
+ const char * pos = src;
+ while (*pos) {
+ if (*pos == '"') { // literal string
+ pos++;
+ last_sym_start = out_elements.size();
+ while (*pos != '"') {
+ auto char_pair = parse_char(pos);
+ pos = char_pair.second;
+ out_elements.push_back({WHISPER_GRETYPE_CHAR, char_pair.first});
+ }
+ pos = parse_space(pos + 1, is_nested);
+ } else if (*pos == '[') { // char range(s)
+ pos++;
+ enum whisper_gretype start_type = WHISPER_GRETYPE_CHAR;
+ if (*pos == '^') {
+ pos++;
+ start_type = WHISPER_GRETYPE_CHAR_NOT;
+ }
+ last_sym_start = out_elements.size();
+ while (*pos != ']') {
+ auto char_pair = parse_char(pos);
+ pos = char_pair.second;
+ enum whisper_gretype type = last_sym_start < out_elements.size()
+ ? WHISPER_GRETYPE_CHAR_ALT
+ : start_type;
+
+ out_elements.push_back({type, char_pair.first});
+ if (pos[0] == '-' && pos[1] != ']') {
+ auto endchar_pair = parse_char(pos + 1);
+ pos = endchar_pair.second;
+ out_elements.push_back({WHISPER_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
+ }
+ }
+ pos = parse_space(pos + 1, is_nested);
+ } else if (is_word_char(*pos)) { // rule reference
+ const char * name_end = parse_name(pos);
+ uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
+ pos = parse_space(name_end, is_nested);
+ last_sym_start = out_elements.size();
+ out_elements.push_back({WHISPER_GRETYPE_RULE_REF, ref_rule_id});
+ } else if (*pos == '(') { // grouping
+ // parse nested alternates into synthesized rule
+ pos = parse_space(pos + 1, true);
+ uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
+ pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
+ last_sym_start = out_elements.size();
+ // output reference to synthesized rule
+ out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
+ if (*pos != ')') {
+ throw std::runtime_error(std::string("expecting ')' at ") + pos);
+ }
+ pos = parse_space(pos + 1, is_nested);
+ } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
+ if (last_sym_start == out_elements.size()) {
+ throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
+ }
+
+ // apply transformation to previous symbol (last_sym_start to end) according to
+ // rewrite rules:
+ // S* --> S' ::= S S' |
+ // S+ --> S' ::= S S' | S
+ // S? --> S' ::= S |
+ uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
+ std::vector<whisper_grammar_element> sub_rule;
+ // add preceding symbol to generated rule
+ sub_rule.insert(
+ sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
+ if (*pos == '*' || *pos == '+') {
+ // cause generated rule to recurse
+ sub_rule.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
+ }
+ // mark start of alternate def
+ sub_rule.push_back({WHISPER_GRETYPE_ALT, 0});
+ if (*pos == '+') {
+ // add preceding symbol as alternate only for '+' (otherwise empty)
+ sub_rule.insert(
+ sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
+ }
+ sub_rule.push_back({WHISPER_GRETYPE_END, 0});
+ add_rule(state, sub_rule_id, sub_rule);
+
+ // in original rule, replace previous symbol with reference to generated rule
+ out_elements.resize(last_sym_start);
+ out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
+
+ pos = parse_space(pos + 1, is_nested);
+ } else {
+ break;
+ }
+ }
+ return pos;
+ }
+
+ const char * parse_alternates(
+ parse_state & state,
+ const char * src,
+ const std::string & rule_name,
+ uint32_t rule_id,
+ bool is_nested) {
+ std::vector<whisper_grammar_element> rule;
+ const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
+ while (*pos == '|') {
+ rule.push_back({WHISPER_GRETYPE_ALT, 0});
+ pos = parse_space(pos + 1, true);
+ pos = parse_sequence(state, pos, rule_name, rule, is_nested);
+ }
+ rule.push_back({WHISPER_GRETYPE_END, 0});
+ add_rule(state, rule_id, rule);
+ return pos;
+ }
+
+ const char * parse_rule(parse_state & state, const char * src) {
+ const char * name_end = parse_name(src);
+ const char * pos = parse_space(name_end, false);
+ size_t name_len = name_end - src;
+ uint32_t rule_id = get_symbol_id(state, src, name_len);
+ const std::string name(src, name_len);
+
+ if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
+ throw std::runtime_error(std::string("expecting ::= at ") + pos);
+ }
+ pos = parse_space(pos + 3, true);
+
+ pos = parse_alternates(state, pos, name, rule_id, false);
+
+ if (*pos == '\r') {
+ pos += pos[1] == '\n' ? 2 : 1;
+ } else if (*pos == '\n') {
+ pos++;
+ } else if (*pos) {
+ throw std::runtime_error(std::string("expecting newline or end at ") + pos);
+ }
+ return parse_space(pos, true);
+ }
+
+ parse_state parse(const char * src) {
+ try {
+ parse_state state;
+ const char * pos = parse_space(src, true);
+ while (*pos) {
+ pos = parse_rule(state, pos);
+ }
+ return state;
+ } catch (const std::exception & err) {
+ fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
+ return parse_state();
+ }
+ }
+
+ void print_grammar_char(FILE * file, uint32_t c) {
+ if (0x20 <= c && c <= 0x7f) {
+ fprintf(file, "%c", static_cast<char>(c));
+ } else {
+ // cop out of encoding UTF-8
+ fprintf(file, "<U+%04X>", c);
+ }
+ }
+
+ bool is_char_element(whisper_grammar_element elem) {
+ switch (elem.type) {
+ case WHISPER_GRETYPE_CHAR: return true;
+ case WHISPER_GRETYPE_CHAR_NOT: return true;
+ case WHISPER_GRETYPE_CHAR_ALT: return true;
+ case WHISPER_GRETYPE_CHAR_RNG_UPPER: return true;
+ default: return false;
+ }
+ }
+
+ void print_rule_binary(FILE * file, const std::vector<whisper_grammar_element> & rule) {
+ for (auto elem : rule) {
+ switch (elem.type) {
+ case WHISPER_GRETYPE_END: fprintf(file, "END"); break;
+ case WHISPER_GRETYPE_ALT: fprintf(file, "ALT"); break;
+ case WHISPER_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
+ case WHISPER_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
+ case WHISPER_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
+ case WHISPER_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
+ case WHISPER_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
+ }
+ switch (elem.type) {
+ case WHISPER_GRETYPE_END:
+ case WHISPER_GRETYPE_ALT:
+ case WHISPER_GRETYPE_RULE_REF:
+ fprintf(file, "(%u) ", elem.value);
+ break;
+ case WHISPER_GRETYPE_CHAR:
+ case WHISPER_GRETYPE_CHAR_NOT:
+ case WHISPER_GRETYPE_CHAR_RNG_UPPER:
+ case WHISPER_GRETYPE_CHAR_ALT:
+ fprintf(file, "(\"");
+ print_grammar_char(file, elem.value);
+ fprintf(file, "\") ");
+ break;
+ }
+ }
+ fprintf(file, "\n");
+ }
+
+ void print_rule(
+ FILE * file,
+ uint32_t rule_id,
+ const std::vector<whisper_grammar_element> & rule,
+ const std::map<uint32_t, std::string> & symbol_id_names) {
+ if (rule.empty() || rule.back().type != WHISPER_GRETYPE_END) {
+ throw std::runtime_error(
+ "malformed rule, does not end with WHISPER_GRETYPE_END: " + std::to_string(rule_id));
+ }
+ fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
+ for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
+ whisper_grammar_element elem = rule[i];
+ switch (elem.type) {
+ case WHISPER_GRETYPE_END:
+ throw std::runtime_error(
+ "unexpected end of rule: " + std::to_string(rule_id) + "," +
+ std::to_string(i));
+ case WHISPER_GRETYPE_ALT:
+ fprintf(file, "| ");
+ break;
+ case WHISPER_GRETYPE_RULE_REF:
+ fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
+ break;
+ case WHISPER_GRETYPE_CHAR:
+ fprintf(file, "[");
+ print_grammar_char(file, elem.value);
+ break;
+ case WHISPER_GRETYPE_CHAR_NOT:
+ fprintf(file, "[^");
+ print_grammar_char(file, elem.value);
+ break;
+ case WHISPER_GRETYPE_CHAR_RNG_UPPER:
+ if (i == 0 || !is_char_element(rule[i - 1])) {
+ throw std::runtime_error(
+ "WHISPER_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
+ std::to_string(rule_id) + "," + std::to_string(i));
+ }
+ fprintf(file, "-");
+ print_grammar_char(file, elem.value);
+ break;
+ case WHISPER_GRETYPE_CHAR_ALT:
+ if (i == 0 || !is_char_element(rule[i - 1])) {
+ throw std::runtime_error(
+ "WHISPER_GRETYPE_CHAR_ALT without preceding char: " +
+ std::to_string(rule_id) + "," + std::to_string(i));
+ }
+ print_grammar_char(file, elem.value);
+ break;
+ }
+ if (is_char_element(elem)) {
+ switch (rule[i + 1].type) {
+ case WHISPER_GRETYPE_CHAR_ALT:
+ case WHISPER_GRETYPE_CHAR_RNG_UPPER:
+ break;
+ default:
+ fprintf(file, "] ");
+ }
+ }
+ }
+ fprintf(file, "\n");
+ }
+
+ void print_grammar(FILE * file, const parse_state & state) {
+ try {
+ std::map<uint32_t, std::string> symbol_id_names;
+ for (auto kv : state.symbol_ids) {
+ symbol_id_names[kv.second] = kv.first;
+ }
+ for (size_t i = 0, end = state.rules.size(); i < end; i++) {
+ // fprintf(file, "%zu: ", i);
+ // print_rule_binary(file, state.rules[i]);
+ print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
+ // fprintf(file, "\n");
+ }
+ } catch (const std::exception & err) {
+ fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
+ }
+ }
+
+ std::vector<const whisper_grammar_element *> parse_state::c_rules() const{
+ std::vector<const whisper_grammar_element *> ret;
+ for (const auto & rule : rules) {
+ ret.push_back(rule.data());
+ }
+ return ret;
+ }
+}
--- /dev/null
+// Implements a parser for an extended Backus-Naur form (BNF), producing the
+// binary context-free grammar format specified by whisper.h. Supports character
+// ranges, grouping, and repetition operators. As an example, a grammar for
+// arithmetic might look like:
+//
+// root ::= expr
+// expr ::= term ([-+*/] term)*
+// term ::= num | "(" space expr ")" space
+// num ::= [0-9]+ space
+// space ::= [ \t\n]*
+
+#pragma once
+#include "whisper.h"
+#include <vector>
+#include <map>
+#include <cstdint>
+#include <string>
+
+namespace grammar_parser {
+ struct parse_state {
+ std::map<std::string, uint32_t> symbol_ids;
+ std::vector<std::vector<whisper_grammar_element>> rules;
+
+ std::vector<const whisper_grammar_element *> c_rules() const;
+ };
+
+ parse_state parse(const char * src);
+ void print_grammar(FILE * file, const parse_state & state);
+}
--- /dev/null
+# - "turn on lights."
+# - "set thermostat to 22."
+# - "increase TV by 10."
+# - "decrease oven by 50."
+# - "play music."
+# - "stop podcast."
+# - "schedule cleaning at 3pm."
+# - "cancel cleaning."
+# - "remind me to buy milk at 5pm."
+# - "show me security system."
+# - "hide washing machine."
+# - "what is the lights status?"
+# - "what is the current thermostat value?"
+# - "what is the security system status?"
+# - "what is the door lock status?"
+# - "what is the camera battery level?"
+# - "what is the weather like today?"
+# - "what is the forecast for tomorrow?"
+# - "what is the time?"
+# - "what is my schedule for today?"
+# - "what tasks do I have?"
+# - "what reminders do I have?"
+#
+# example:
+#
+# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/assistant.gbnf --prompt "Ok Whisper, start listening for commands." --context "Whisper is a home assistant. It recognizes voice commands. Time is 11pm." --grammar-penalty 10
+#
+
+root ::= init " " (command | question) "."
+prompt ::= init
+
+# leading space is very important!
+init ::= " Ok Whisper, start listening for commands."
+
+command ::= "Turn " ("on" | "off") " " device | "Set " device " to " value |
+ "Increase " device " by " value | "Decrease " device " by " value |
+ "Play " media | "Stop " media | "Schedule " task " at " time | "Cancel " task |
+ "Remind me to " task " at " time | "Show me " device | "Hide " device
+
+question ::= "What is the " device " status?" | "What is the current " device " value?" |
+ "What is the " device " temperature?" | "What is the " device " humidity?" |
+ "What is the " device " power consumption?" | "What is the " device " battery level?" |
+ "What is the weather like today?" | "What is the forecast for tomorrow?" |
+ "What is the time?" | "What is my schedule for today?" | "What tasks do I have?" |
+ "What reminders do I have?"
+
+device ::= "lights" | "thermostat" | "security system" | "door lock" | "camera" | "speaker" | "TV" |
+ "music player" | "coffee machine" | "oven" | "refrigerator" | "washing machine" |
+ "vacuum cleaner"
+
+value ::= [0-9]+
+
+media ::= "music" | "radio" | "podcast" | "audiobook" | "TV show" | "movie"
+
+task ::= [a-zA-Z]+ (" " [a-zA-Z]+)?
+
+time ::= [0-9] [0-9]? ("am" | "pm")?
--- /dev/null
+# - bishop to c3
+# - rook to d4
+# - knight to e5
+# - d4 d5 knight to c3
+# - c3 queen to d4 king b1
+# - pawn to a1 bishop to b2 knight to c3
+#
+# The prompt (--prompt) is the initial phrase that the user has to say.
+# This is used to prime Whisper with how the user is expected to speak.
+#
+# Provide long context (--context) with sample moves to help Whisper decode the correct sequence.
+# Longer context is better, but it slightly increases the processing time.
+#
+# example:
+#
+# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "rook to b4, f3," --context "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8," --grammar-penalty 100
+#
+
+root ::= init move move? move? "."
+prompt ::= init "."
+
+# leading space is very important!
+init ::= " rook to b4, f3"
+
+move ::= ", " ((piece | pawn | king) " " "to "?)? [a-h] [1-8]
+
+piece ::= "bishop" | "rook" | "knight" | "queen"
+king ::= "king"
+pawn ::= "pawn"
--- /dev/null
+# - red
+# - green
+# - blue
+#
+# example:
+#
+# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/colors.gbnf --prompt "red, green, blue," --context "green, red, blue,"
+#
+
+root ::= init color "."
+prompt ::= init "."
+
+# leading space is very important!
+init ::= " red, green, blue"
+
+color ::= ", " ("red" | "green" | "blue")
std::map<std::string, struct ggml_tensor *> tensors;
};
+struct whisper_partial_utf8 {
+ uint32_t value; // bit value so far (unshifted)
+ int n_remain; // num bytes remaining; -1 indicates invalid sequence
+};
+
+struct whisper_grammar {
+ /*const*/ std::vector<std::vector<whisper_grammar_element>> rules;
+ std::vector<std::vector<const whisper_grammar_element *>> stacks;
+
+ // buffer for partially generated UTF-8 sequence from accepted tokens
+ whisper_partial_utf8 partial_utf8;
+};
+
+struct whisper_grammar_candidate {
+ whisper_token id;
+ const uint32_t * code_points;
+ whisper_partial_utf8 partial_utf8;
+};
+
struct whisper_sequence {
std::vector<whisper_token_data> tokens;
// the currently generated sequence of tokens
whisper_sequence sequence;
+ // grammar parse state of generated sequence of tokens
+ whisper_grammar grammar;
+
int seek_delta; // the window shift found so far based on the decoded timestamp tokens
bool failed; // has the current segment failed to decode?
return s.c_str();
}
+//////////////////////////////////
+// Grammar - ported from llama.cpp
+//////////////////////////////////
+
+// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
+// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
+std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
+ const char * src,
+ whisper_partial_utf8 partial_start) {
+ static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
+ const char * pos = src;
+ std::vector<uint32_t> code_points;
+ uint32_t value = partial_start.value;
+ int n_remain = partial_start.n_remain;
+
+ // continue previous decode, if applicable
+ while (*pos != 0 && n_remain > 0) {
+ uint8_t next_byte = static_cast<uint8_t>(*pos);
+ if ((next_byte >> 6) != 2) {
+ // invalid sequence, abort
+ code_points.push_back(0);
+ return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
+ }
+ value = (value << 6) + (next_byte & 0x3F);
+ ++pos;
+ --n_remain;
+ }
+
+ if (partial_start.n_remain > 0 && n_remain == 0) {
+ code_points.push_back(value);
+ }
+
+ // decode any subsequent utf-8 sequences, which may end in an incomplete one
+ while (*pos != 0) {
+ uint8_t first_byte = static_cast<uint8_t>(*pos);
+ uint8_t highbits = first_byte >> 4;
+ n_remain = lookup[highbits] - 1;
+
+ if (n_remain < 0) {
+ // invalid sequence, abort
+ code_points.clear();
+ code_points.push_back(0);
+ return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
+ }
+
+ uint8_t mask = (1 << (7 - n_remain)) - 1;
+ value = first_byte & mask;
+ ++pos;
+ while (*pos != 0 && n_remain > 0) {
+ value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+ ++pos;
+ --n_remain;
+ }
+ if (n_remain == 0) {
+ code_points.push_back(value);
+ }
+ }
+ code_points.push_back(0);
+
+ return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
+}
+
+// returns true iff pos points to the end of one of the definitions of a rule
+static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
+ switch (pos->type) {
+ case WHISPER_GRETYPE_END: return true; // NOLINT
+ case WHISPER_GRETYPE_ALT: return true; // NOLINT
+ default: return false;
+ }
+}
+
+// returns true iff chr satisfies the char range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
+ const whisper_grammar_element * pos,
+ const uint32_t chr) {
+
+ bool found = false;
+ bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
+
+ WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
+
+ do {
+ if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
+ // inclusive range, e.g. [a-z]
+ found = found || (pos->value <= chr && chr <= pos[1].value);
+ pos += 2;
+ } else {
+ // exact char match, e.g. [a] or "a"
+ found = found || pos->value == chr;
+ pos += 1;
+ }
+ } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+
+ return std::make_pair(found == is_positive_char, pos);
+}
+
+// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
+// range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static bool whisper_grammar_match_partial_char(
+ const whisper_grammar_element * pos,
+ const whisper_partial_utf8 partial_utf8) {
+
+ bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
+ WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT);
+
+ uint32_t partial_value = partial_utf8.value;
+ int n_remain = partial_utf8.n_remain;
+
+ // invalid sequence or 7-bit char split across 2 bytes (overlong)
+ if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
+ return false;
+ }
+
+ // range of possible code points this partial UTF-8 sequence could complete to
+ uint32_t low = partial_value << (n_remain * 6);
+ uint32_t high = low | ((1 << (n_remain * 6)) - 1);
+
+ if (low == 0) {
+ if (n_remain == 2) {
+ low = 1 << 11;
+ } else if (n_remain == 3) {
+ low = 1 << 16;
+ }
+ }
+
+ do {
+ if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
+ // inclusive range, e.g. [a-z]
+ if (pos->value <= high && low <= pos[1].value) {
+ return is_positive_char;
+ }
+ pos += 2;
+ } else {
+ // exact char match, e.g. [a] or "a"
+ if (low <= pos->value && pos->value <= high) {
+ return is_positive_char;
+ }
+ pos += 1;
+ }
+ } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+
+ return !is_positive_char;
+}
+
+
+// transforms a grammar pushdown stack into N possible stacks, all ending
+// at a character range (terminal element)
+static void whisper_grammar_advance_stack(
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
+ const std::vector<const whisper_grammar_element *> & stack,
+ std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
+
+ if (stack.empty()) {
+ new_stacks.push_back(stack);
+ return;
+ }
+
+ const whisper_grammar_element * pos = stack.back();
+
+ switch (pos->type) {
+ case WHISPER_GRETYPE_RULE_REF: {
+ const size_t rule_id = static_cast<size_t>(pos->value);
+ const whisper_grammar_element * subpos = rules[rule_id].data();
+ do {
+ // init new stack without the top (pos)
+ std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
+ if (!whisper_grammar_is_end_of_sequence(pos + 1)) {
+ // if this rule ref is followed by another element, add that to stack
+ new_stack.push_back(pos + 1);
+ }
+ if (!whisper_grammar_is_end_of_sequence(subpos)) {
+ // if alternate is nonempty, add to stack
+ new_stack.push_back(subpos);
+ }
+ whisper_grammar_advance_stack(rules, new_stack, new_stacks);
+ while (!whisper_grammar_is_end_of_sequence(subpos)) {
+ // scan to end of alternate def
+ subpos++;
+ }
+ if (subpos->type == WHISPER_GRETYPE_ALT) {
+ // there's another alternate def of this rule to process
+ subpos++;
+ } else {
+ break;
+ }
+ } while (true);
+ break;
+ }
+ case WHISPER_GRETYPE_CHAR:
+ case WHISPER_GRETYPE_CHAR_NOT:
+ new_stacks.push_back(stack);
+ break;
+ default:
+ // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range
+ // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
+ // those
+ WHISPER_ASSERT(false);
+ }
+}
+
+// takes a set of possible pushdown stacks on a grammar, which are required to
+// be positioned at a character range (see `whisper_grammar_advance_stack`), and
+// produces the N possible stacks if the given char is accepted at those
+// positions
+static std::vector<std::vector<const whisper_grammar_element *>> whisper_grammar_accept(
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
+ const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+ const uint32_t chr) {
+
+ std::vector<std::vector<const whisper_grammar_element *>> new_stacks;
+
+ for (const auto & stack : stacks) {
+ if (stack.empty()) {
+ continue;
+ }
+
+ auto match = whisper_grammar_match_char(stack.back(), chr);
+ if (match.first) {
+ const whisper_grammar_element * pos = match.second;
+
+ // update top of stack to next element, if any
+ std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
+ if (!whisper_grammar_is_end_of_sequence(pos)) {
+ new_stack.push_back(pos);
+ }
+ whisper_grammar_advance_stack(rules, new_stack, new_stacks);
+ }
+ }
+
+ return new_stacks;
+}
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
+ const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+ const std::vector<whisper_grammar_candidate> & candidates);
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_for_stack(
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
+ const std::vector<const whisper_grammar_element *> & stack,
+ const std::vector<whisper_grammar_candidate> & candidates) {
+
+ std::vector<whisper_grammar_candidate> rejects;
+
+ if (stack.empty()) {
+ for (auto tok : candidates) {
+ if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
+ rejects.push_back(tok);
+ }
+ }
+ return rejects;
+ }
+
+ const whisper_grammar_element * stack_pos = stack.back();
+
+ std::vector<whisper_grammar_candidate> next_candidates;
+ for (auto tok : candidates) {
+ if (*tok.code_points == 0) {
+ // reached end of full codepoints in token, reject iff it ended in a partial sequence
+ // that cannot satisfy this position in grammar
+ if (tok.partial_utf8.n_remain != 0 &&
+ !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
+ rejects.push_back(tok);
+ }
+ } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) {
+ next_candidates.push_back({ tok.id, tok.code_points + 1, tok.partial_utf8 });
+ } else {
+ rejects.push_back(tok);
+ }
+ }
+
+ const auto * stack_pos_after = whisper_grammar_match_char(stack_pos, 0).second;
+
+ // update top of stack to next element, if any
+ std::vector<const whisper_grammar_element *> stack_after(stack.begin(), stack.end() - 1);
+ if (!whisper_grammar_is_end_of_sequence(stack_pos_after)) {
+ stack_after.push_back(stack_pos_after);
+ }
+ std::vector<std::vector<const whisper_grammar_element *>> next_stacks;
+ whisper_grammar_advance_stack(rules, stack_after, next_stacks);
+
+ auto next_rejects = whisper_grammar_reject_candidates(rules, next_stacks, next_candidates);
+ for (auto tok : next_rejects) {
+ rejects.push_back({ tok.id, tok.code_points - 1, tok.partial_utf8 });
+ }
+
+ return rejects;
+}
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
+ const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+ const std::vector<whisper_grammar_candidate> & candidates) {
+ if (candidates.empty() || stacks.empty()) {
+ return std::vector<whisper_grammar_candidate>();
+ }
+
+ auto rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
+
+ for (size_t i = 1, size = stacks.size(); i < size; ++i) {
+ rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
+ }
+ return rejects;
+}
+
+static struct whisper_grammar whisper_grammar_init(
+ const whisper_grammar_element ** rules,
+ size_t n_rules,
+ size_t i_start_rule) {
+ const whisper_grammar_element * pos;
+
+ // copy rule definitions into vectors
+ std::vector<std::vector<whisper_grammar_element>> vec_rules(n_rules);
+ for (size_t i = 0; i < n_rules; i++) {
+ for (pos = rules[i]; pos->type != WHISPER_GRETYPE_END; pos++) {
+ vec_rules[i].push_back(*pos);
+ }
+ vec_rules[i].push_back({WHISPER_GRETYPE_END, 0});
+ }
+
+ // loop over alternates of start rule to build initial stacks
+ std::vector<std::vector<const whisper_grammar_element *>> stacks;
+ pos = rules[i_start_rule];
+ do {
+ std::vector<const whisper_grammar_element *> stack;
+ if (!whisper_grammar_is_end_of_sequence(pos)) {
+ // if alternate is nonempty, add to stack
+ stack.push_back(pos);
+ }
+ whisper_grammar_advance_stack(vec_rules, stack, stacks);
+ while (!whisper_grammar_is_end_of_sequence(pos)) {
+ // scan to end of alternate def
+ pos++;
+ }
+ if (pos->type == WHISPER_GRETYPE_ALT) {
+ // there's another alternate def of this rule to process
+ pos++;
+ } else {
+ break;
+ }
+ } while (true);
+
+ return { std::move(vec_rules), std::move(stacks), {} };
+}
+
+static void whisper_suppress_invalid_grammar(
+ whisper_context & ctx,
+ const whisper_full_params & params,
+ std::vector<float> & logits,
+ const whisper_grammar & grammar) {
+
+ if (grammar.rules.empty() || grammar.stacks.empty()) {
+ return;
+ }
+
+ //bool allow_eot = false;
+ //for (const auto & stack : grammar.stacks) {
+ // if (stack.empty()) {
+ // allow_eot = true;
+ // break;
+ // }
+ //}
+
+ const whisper_token eot = whisper_token_eot(&ctx);
+
+ std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
+ std::vector<whisper_grammar_candidate> candidates_grammar;
+
+ for (whisper_token id = 0; id < eot; ++id) {
+ const std::string & text = ctx.vocab.id_to_token[id];
+ if (!text.empty()) {
+ candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8));
+ candidates_grammar.push_back({ id, candidates_decoded.back().first.data(), candidates_decoded.back().second });
+ }
+ }
+
+ const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
+
+ for (const auto & reject : rejects) {
+ logits[reject.id] -= params.grammar_penalty;
+ }
+
+ // when the grammar allows a continuation, we penalize the end-of-text token
+ //if (!allow_eot) {
+ // logits[eot] -= params.grammar_penalty;
+ //}
+ //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
+}
+
+static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) {
+ if (grammar.rules.empty() || grammar.stacks.empty()) {
+ return;
+ }
+
+ //fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str());
+
+ const std::string & text = ctx.vocab.id_to_token[token];
+
+ if (text.rfind("[_", 0) == 0) {
+ // fprintf(stderr, " (skipped)\n");
+ return;
+ }
+ // fprintf(stderr, "\n");
+
+ // Note terminating 0 in decoded string
+ const auto decoded = decode_utf8(text.c_str(), grammar.partial_utf8);
+ const auto & code_points = decoded.first;
+ for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+ grammar.stacks = whisper_grammar_accept(grammar.rules, grammar.stacks, *it);
+ }
+ grammar.partial_utf8 = decoded.second;
+}
+
+//////////////
+// END grammar
+//////////////
+
////////////////////////////////////////////////////////////////////////////
struct whisper_context_params * whisper_context_default_params_by_ref() {
/*.translate =*/ false,
/*.no_context =*/ true,
+ /*.no_timestamps =*/ false,
/*.single_segment =*/ false,
/*.print_special =*/ false,
/*.print_progress =*/ true,
/*.logits_filter_callback =*/ nullptr,
/*.logits_filter_callback_user_data =*/ nullptr,
+
+ /*.grammar_rules =*/ nullptr,
+ /*.n_grammar_rules =*/ 0,
+ /*.i_start_rule =*/ 0,
+ /*.grammar_penalty =*/ 100.0f,
};
switch (strategy) {
// suppress <|notimestamps|> token
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
logits[vocab.token_not] = -INFINITY;
+ if (params.no_timestamps) {
+ for (int i = vocab.token_beg; i < n_logits; ++i) {
+ logits[i] = -INFINITY;
+ }
+ }
// suppress sot and nosp tokens
logits[vocab.token_sot] = -INFINITY;
logits[vocab.token_transcribe] = -INFINITY;
logits[vocab.token_prev] = -INFINITY;
+ // suppress lang tokens
+ for (size_t i = 0; i < g_lang.size(); ++i) {
+ logits[whisper_token_lang(&ctx, i)] = -INFINITY;
+ }
+
+ // suppress prev token
+ logits[vocab.token_prev] = -INFINITY;
+
if (params.logits_filter_callback) {
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
}
//WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
if (timestamp_logprob > max_text_token_logprob) {
+ //printf("sampling timestamp\n");
for (int i = 0; i < vocab.token_beg; ++i) {
logits[i] = -INFINITY;
logprobs[i] = -INFINITY;
}
+ } else if (params.n_grammar_rules > 0) {
+ whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
+
+ // populate the logprobs array (log_softmax)
+ {
+ const float logit_max = *std::max_element(logits.begin(), logits.end());
+ float logsumexp = 0.0f;
+ for (int i = 0; i < n_logits; ++i) {
+ if (logits[i] > -INFINITY) {
+ logsumexp += expf(logits[i] - logit_max);
+ }
+ }
+ logsumexp = logf(logsumexp) + logit_max;
+
+ for (int i = 0; i < n_logits; ++i) {
+ if (logits[i] > -INFINITY) {
+ logprobs[i] = logits[i] - logsumexp;
+ } else {
+ logprobs[i] = -INFINITY;
+ }
+ }
+ }
}
}
}
#if 0
// print first 100 logits - token string : logit
- for (int i = 0; i < 100; i++) {
- const auto token = vocab.id_to_token.at(i);
- const auto prob = probs[i];
- const auto logit = logits[i];
- const auto logprob = logprobs[i];
- printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
+ //for (int i = 0; i < 10; i++) {
+ // const auto token = vocab.id_to_token.at(i);
+ // const auto prob = probs[i];
+ // const auto logit = logits[i];
+ // const auto logprob = logprobs[i];
+ // printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
+ //}
+
+ // print sorted
+ {
+ std::vector<std::pair<float, int>> pairs;
+
+ for (int i = 0; i < n_logits; ++i) {
+ pairs.push_back(std::make_pair(probs[i], i));
+ }
+
+ std::sort(pairs.begin(), pairs.end(), [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
+ return a.first > b.first;
+ });
+
+ for (int i = 0; i < 10; i++) {
+ const auto token = vocab.id_to_token.at(pairs[i].second);
+ const auto prob = pairs[i].first;
+ const auto logit = logits[pairs[i].second];
+ const auto logprob = logprobs[pairs[i].second];
+ printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str());
+ }
+
+ printf("----------------\n");
}
// "And", "and", " And", " and"
- printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
- printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
- printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
- printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
- printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
-
- printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
- printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
- printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
- printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
- printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
-
- printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
- printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
- printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
- printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
- printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
+ //printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
+ //printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
+ //printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
+ //printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
+ //printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
+
+ //printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
+ //printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
+ //printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
+ //printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
+ //printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
+
+ //printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
+ //printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
+ //printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
+ //printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
+ //printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
#endif
}
ptsum = sum_ts;
}
+ std::discrete_distribution<> dist(probs.begin(), probs.end());
+
for (int i = 0; i < k; ++i) {
- const auto id = logits_id[i].second;
+ const auto id = dist(state.rng);
+ //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
state->exp_n_audio_ctx = params.audio_ctx;
// these tokens determine the task that will be performed
- std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
+ std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx), };
if (whisper_is_multilingual(ctx)) {
const int lang_id = whisper_lang_id(params.language);
}
}
+ // distilled models require the "no_timestamps" token
{
const bool is_distil = ctx->model.hparams.n_text_layer == 2;
-
- // distilled models require the "no_timestamps" token
- // TODO: add input parameter (#1229)
- if (is_distil) {
+ if (is_distil && !params.no_timestamps) {
WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__);
- prompt_init.push_back(whisper_token_not(ctx));
+ params.no_timestamps = true;
}
}
+ if (params.no_timestamps) {
+ prompt_init.push_back(whisper_token_not(ctx));
+ }
+
int seek = seek_start;
std::vector<whisper_token> prompt;
n_decoders_cur = std::max(1, n_decoders_cur);
- WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
+ WHISPER_PRINT_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur);
// TAGS: WHISPER_DECODER_INIT
for (int j = 0; j < n_decoders_cur; ++j) {
decoder.failed = false;
decoder.completed = false;
decoder.has_ts = false;
+
+ if (params.grammar_rules != nullptr) {
+ decoder.grammar = whisper_grammar_init(
+ params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
+ } else {
+ decoder.grammar = {};
+ }
}
// init prompt and kv cache for the current iteration
continue;
}
+ if (cur_c >= beam_candidates.size()) {
+ cur_c = 0;
+ }
+
auto & cur = beam_candidates[cur_c++];
while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
has_ts = true;
}
+ whisper_grammar_accept_token(*ctx, decoder.grammar, token.id);
+
#ifdef WHISPER_DEBUG
{
const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
void (*close)(void * ctx);
} whisper_model_loader;
+ // grammar element type
+ enum whisper_gretype {
+ // end of rule definition
+ WHISPER_GRETYPE_END = 0,
+
+ // start of alternate definition for rule
+ WHISPER_GRETYPE_ALT = 1,
+
+ // non-terminal element: reference to rule
+ WHISPER_GRETYPE_RULE_REF = 2,
+
+ // terminal element: character (code point)
+ WHISPER_GRETYPE_CHAR = 3,
+
+ // inverse char(s) ([^a], [^a-b] [^abc])
+ WHISPER_GRETYPE_CHAR_NOT = 4,
+
+ // modifies a preceding WHISPER_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
+ // be an inclusive range ([a-z])
+ WHISPER_GRETYPE_CHAR_RNG_UPPER = 5,
+
+ // modifies a preceding WHISPER_GRETYPE_CHAR or
+ // WHISPER_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
+ WHISPER_GRETYPE_CHAR_ALT = 6,
+ };
+
+ typedef struct whisper_grammar_element {
+ enum whisper_gretype type;
+ uint32_t value; // Unicode code point or rule ID
+ } whisper_grammar_element;
+
// Various functions for loading a ggml whisper model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
bool translate;
bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
+ bool no_timestamps; // do not generate timestamps
bool single_segment; // force single segment output (useful for streaming)
bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
bool print_progress; // print progress information
// called by each decoder to filter obtained logits
whisper_logits_filter_callback logits_filter_callback;
void * logits_filter_callback_user_data;
+
+ const whisper_grammar_element ** grammar_rules;
+ size_t n_grammar_rules;
+ size_t i_start_rule;
+ float grammar_penalty;
};
// NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()