]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sync : latest whisper.cpp
authorGeorgi Gerganov <redacted>
Sun, 26 Feb 2023 19:10:50 +0000 (21:10 +0200)
committerGeorgi Gerganov <redacted>
Sun, 26 Feb 2023 19:10:50 +0000 (21:10 +0200)
examples/whisper/main.cpp
examples/whisper/whisper.cpp
examples/whisper/whisper.h

index 5bd7e424c61c85e4514c475b895d7c18b8f5b3b5..b8366b79f4581cc474930a99080d562f4374eb75 100644 (file)
@@ -91,12 +91,12 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
 bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
     for (int i = 1; i < argc; i++) {
         std::string arg = argv[i];
-           
+
         if (arg == "-"){
             params.fname_inp.push_back(arg);
             continue;
         }
-       
+
         if (arg[0] != '-') {
             params.fname_inp.push_back(arg);
             continue;
index 331d4084c6b7ed7a17bfb6462c2bb8fc2f735055..3a21581c682c641ac20ccc2de7e7c25470f0df64 100644 (file)
@@ -592,16 +592,16 @@ struct whisper_context {
 
     mutable std::mt19937 rng; // used for sampling at t > 0.0
 
-    int lang_id;
+    int lang_id = 0; // english by default
 
     // [EXPERIMENTAL] token-level timestamps data
-    int64_t t_beg;
-    int64_t t_last;
+    int64_t t_beg = 0;
+    int64_t t_last = 0;
     whisper_token tid_last;
     std::vector<float> energy; // PCM signal energy
 
     // [EXPERIMENTAL] speed-up techniques
-    int32_t exp_n_audio_ctx; // 0 - use default
+    int32_t exp_n_audio_ctx = 0; // 0 - use default
 
     void use_buf(struct ggml_context * ctx, int i) {
 #if defined(WHISPER_USE_SCRATCH)
@@ -805,7 +805,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                      MEM_REQ_SCRATCH3.at (model.type) +
                 scale*MEM_REQ_MODEL.at   (model.type) +
                 scale*MEM_REQ_KV_CROSS.at(model.type) +
-                scale*std::max(MEM_REQ_ENCODE.at(model.type),       MEM_REQ_DECODE.at(model.type));
+                scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
 
             // this is the memory required by one decoder
             const size_t mem_required_decoder =
@@ -2936,7 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.language         =*/ "en",
 
         /*.suppress_blank   =*/ true,
-        /*.suppress_non_speech_tokens =*/true,
+        /*.suppress_non_speech_tokens =*/ false,
 
         /*.temperature      =*/  0.0f,
         /*.max_initial_ts   =*/  1.0f,
@@ -2962,6 +2962,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 
         /*.encoder_begin_callback           =*/ nullptr,
         /*.encoder_begin_callback_user_data =*/ nullptr,
+
+        /*.logits_filter_callback           =*/ nullptr,
+        /*.logits_filter_callback_user_data =*/ nullptr,
     };
 
     switch (strategy) {
@@ -3078,8 +3081,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
     return res;
 }
 
-static const std::vector<std::string> non_speech_tokens
-{
+static const std::vector<std::string> non_speech_tokens = {
     "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
     "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
     "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
@@ -3090,7 +3092,7 @@ static const std::vector<std::string> non_speech_tokens
 // - applies logit filters
 // - computes logprobs and probs
 static void whisper_process_logits(
-        const struct whisper_context & ctx,
+              struct whisper_context & ctx,
     const struct whisper_full_params   params,
               struct whisper_decoder & decoder,
                                float   temperature) {
@@ -3146,29 +3148,27 @@ static void whisper_process_logits(
         logits[vocab.token_translate]  = -INFINITY;
         logits[vocab.token_transcribe] = -INFINITY;
 
+        if (params.logits_filter_callback) {
+            params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
+        }
 
         // suppress non-speech tokens
         // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
-        if (params.suppress_non_speech_tokens)
-        {
-            for (const std::string &token : non_speech_tokens)
-            {
-                std::string suppress_tokens[] = {token, " " + token};
-                for (const std::string &suppress_token : suppress_tokens)
-                {
-                    if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end())
-                    {
+        if (params.suppress_non_speech_tokens) {
+            for (const std::string & token : non_speech_tokens) {
+                const std::string suppress_tokens[] = {token, " " + token};
+                for (const std::string & suppress_token : suppress_tokens) {
+                    if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) {
                         logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
                     }
                 }
             }
+
             // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
-            if (vocab.token_to_id.find(" -") != vocab.token_to_id.end())
-            {
+            if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) {
                 logits[vocab.token_to_id.at(" -")] = -INFINITY;
             }
-            if (vocab.token_to_id.find(" '") != vocab.token_to_id.end())
-            {
+            if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) {
                 logits[vocab.token_to_id.at(" '")] = -INFINITY;
             }
         }
@@ -3854,7 +3854,7 @@ int whisper_full(
                         return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
                     });
 
-                    unsigned int cur_c = 0;
+                    uint32_t cur_c = 0;
 
                     for (int j = 0; j < n_decoders_cur; ++j) {
                         auto & decoder = ctx->decoders[j];
@@ -4339,7 +4339,7 @@ int whisper_full_n_segments(struct whisper_context * ctx) {
 }
 
 int whisper_full_lang_id(struct whisper_context * ctx) {
-    return ctx->lang_id; 
+    return ctx->lang_id;
 }
 
 int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
index 7eece797c16b84f31116289aaa50e84afb7c4fa6..3eb8d08420948ca35533d19ffbd9047a55a78493 100644 (file)
@@ -243,6 +243,16 @@ extern "C" {
     // If it returns false, the computation is aborted
     typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
 
+    // Logits filter callback
+    // Can be used to modify the logits before sampling
+    // If not NULL, called after applying temperature to logits
+    typedef void (*whisper_logits_filter_callback)(
+            struct whisper_context * ctx,
+          const whisper_token_data * tokens,
+                               int   n_tokens,
+                             float * logits,
+                              void * user_data);
+
     // Parameters for the whisper_full() function
     // If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp:
     // whisper_full_default_params()
@@ -315,6 +325,10 @@ extern "C" {
         // called each time before the encoder starts
         whisper_encoder_begin_callback encoder_begin_callback;
         void * encoder_begin_callback_user_data;
+
+        // called by each decoder to filter obtained logits
+        whisper_logits_filter_callback logits_filter_callback;
+        void * logits_filter_callback_user_data;
     };
 
     WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);