]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
main : add command-style grammar (#1998)
authorulatekh <redacted>
Thu, 28 Mar 2024 10:02:10 +0000 (03:02 -0700)
committerGitHub <redacted>
Thu, 28 Mar 2024 10:02:10 +0000 (12:02 +0200)
* Implemented command-style grammar in the main example.

Mostly just copied the relevant parts from the command example.

* main : code style

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/main/main.cpp

index 415c3b33de724836cf20f5d1f9cc00c4ab27c301..42b067e718d4d4bb116f8a4c9f3eb1f29d8281d5 100644 (file)
@@ -1,6 +1,7 @@
 #include "common.h"
 
 #include "whisper.h"
+#include "grammar-parser.h"
 
 #include <cmath>
 #include <fstream>
@@ -38,9 +39,10 @@ struct whisper_params {
     int32_t beam_size     = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
     int32_t audio_ctx     = 0;
 
-    float word_thold    =  0.01f;
-    float entropy_thold =  2.40f;
-    float logprob_thold = -1.00f;
+    float word_thold      =  0.01f;
+    float entropy_thold   =  2.40f;
+    float logprob_thold   = -1.00f;
+    float grammar_penalty = 100.0f;
 
     bool speed_up        = false;
     bool debug_mode      = false;
@@ -70,6 +72,8 @@ struct whisper_params {
     std::string prompt;
     std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
     std::string model     = "models/ggml-base.en.bin";
+    std::string grammar;
+    std::string grammar_rule;
 
     // [TDRZ] speaker turn string
     std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
@@ -80,6 +84,8 @@ struct whisper_params {
 
     std::vector<std::string> fname_inp = {};
     std::vector<std::string> fname_out = {};
+
+    grammar_parser::parse_state grammar_parsed;
 };
 
 void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -154,6 +160,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-dtw"  || arg == "--dtw")             { params.dtw             = argv[++i]; }
         else if (arg == "-ls"   || arg == "--log-score")       { params.log_score       = true; }
         else if (arg == "-ng"   || arg == "--no-gpu")          { params.use_gpu         = false; }
+        else if (                  arg == "--grammar")         { params.grammar         = argv[++i]; }
+        else if (                  arg == "--grammar-rule")    { params.grammar_rule    = argv[++i]; }
+        else if (                  arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             whisper_print_usage(argc, argv, params);
@@ -214,6 +223,9 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -dtw MODEL --dtw MODEL         [%-7s] compute token-level timestamps\n",                 params.dtw.c_str());
     fprintf(stderr, "  -ls,       --log-score         [%-7s] log best decoder scores of tokens\n",              params.log_score?"true":"false");
     fprintf(stderr, "  -ng,       --no-gpu            [%-7s] disable GPU\n",                                    params.use_gpu ? "false" : "true");
+    fprintf(stderr, "  --grammar GRAMMAR              [%-7s] GBNF grammar to guide decoding\n",                 params.grammar.c_str());
+    fprintf(stderr, "  --grammar-rule RULE            [%-7s] top-level GBNF grammar rule name\n",               params.grammar_rule.c_str());
+    fprintf(stderr, "  --grammar-penalty N            [%-7.1f] scales down logits of nongrammar tokens\n",      params.grammar_penalty);
     fprintf(stderr, "\n");
 }
 
@@ -926,6 +938,29 @@ int main(int argc, char ** argv) {
     // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
     whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
 
+    if (!params.grammar.empty()) {
+        auto & grammar = params.grammar_parsed;
+        if (is_file_exist(params.grammar.c_str())) {
+            // read grammar from file
+            std::ifstream ifs(params.grammar.c_str());
+            const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
+            grammar = grammar_parser::parse(txt.c_str());
+        } else {
+            // read grammar from string
+            grammar = grammar_parser::parse(params.grammar.c_str());
+        }
+
+        // will be empty (default) if there are parse errors
+        if (grammar.rules.empty()) {
+            fprintf(stderr, "error: failed to parse grammar \"%s\"\n", params.grammar.c_str());
+            return 4;
+        } else {
+            fprintf(stderr, "%s: grammar:\n", __func__);
+            grammar_parser::print_grammar(stderr, grammar);
+            fprintf(stderr, "\n");
+        }
+    }
+
     for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
         const auto fname_inp = params.fname_inp[f];
                const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
@@ -972,7 +1007,8 @@ int main(int argc, char ** argv) {
         {
             whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 
-            wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
+            const bool use_grammar = (!params.grammar_parsed.rules.empty() && !params.grammar_rule.empty());
+            wparams.strategy = (params.beam_size > 1 || use_grammar) ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
 
             wparams.print_realtime   = false;
             wparams.print_progress   = params.print_progress;
@@ -1010,6 +1046,20 @@ int main(int argc, char ** argv) {
 
             whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
 
+            const auto & grammar_parsed = params.grammar_parsed;
+            auto grammar_rules = grammar_parsed.c_rules();
+
+            if (use_grammar) {
+                if (grammar_parsed.symbol_ids.find(params.grammar_rule) == grammar_parsed.symbol_ids.end()) {
+                    fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, params.grammar_rule.c_str());
+                } else {
+                    wparams.grammar_rules = grammar_rules.data();
+                    wparams.n_grammar_rules = grammar_rules.size();
+                    wparams.i_start_rule = grammar_parsed.symbol_ids.at(params.grammar_rule);
+                    wparams.grammar_penalty = params.grammar_penalty;
+                }
+            }
+
             // this callback is called on each new segment
             if (!wparams.print_realtime) {
                 wparams.new_segment_callback           = whisper_print_segment_callback;