]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add API for applying custom logits filters during decoding
authorGeorgi Gerganov <redacted>
Sun, 19 Feb 2023 16:35:01 +0000 (18:35 +0200)
committerGeorgi Gerganov <redacted>
Sun, 19 Feb 2023 16:35:01 +0000 (18:35 +0200)
whisper.cpp
whisper.h

index 04cbc36b2ca909c778cdeaef8968ba396aeb629c..9a9e562a6c9124c6e4cdfb1714ffa57b3912061a 100644 (file)
@@ -805,7 +805,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                      MEM_REQ_SCRATCH3.at (model.type) +
                 scale*MEM_REQ_MODEL.at   (model.type) +
                 scale*MEM_REQ_KV_CROSS.at(model.type) +
-                scale*std::max(MEM_REQ_ENCODE.at(model.type),       MEM_REQ_DECODE.at(model.type));
+                scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
 
             // this is the memory required by one decoder
             const size_t mem_required_decoder =
@@ -2962,6 +2962,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 
         /*.encoder_begin_callback           =*/ nullptr,
         /*.encoder_begin_callback_user_data =*/ nullptr,
+
+        /*.logits_filter_callback           =*/ nullptr,
+        /*.logits_filter_callback_user_data =*/ nullptr,
     };
 
     switch (strategy) {
@@ -3089,7 +3092,7 @@ static const std::vector<std::string> non_speech_tokens = {
 // - applies logit filters
 // - computes logprobs and probs
 static void whisper_process_logits(
-        const struct whisper_context & ctx,
+              struct whisper_context & ctx,
     const struct whisper_full_params   params,
               struct whisper_decoder & decoder,
                                float   temperature) {
@@ -3145,6 +3148,9 @@ static void whisper_process_logits(
         logits[vocab.token_translate]  = -INFINITY;
         logits[vocab.token_transcribe] = -INFINITY;
 
+        if (params.logits_filter_callback) {
+            params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
+        }
 
         // suppress non-speech tokens
         // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
@@ -3848,7 +3854,7 @@ int whisper_full(
                         return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
                     });
 
-                    unsigned int cur_c = 0;
+                    uint32_t cur_c = 0;
 
                     for (int j = 0; j < n_decoders_cur; ++j) {
                         auto & decoder = ctx->decoders[j];
index 7eece797c16b84f31116289aaa50e84afb7c4fa6..3eb8d08420948ca35533d19ffbd9047a55a78493 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -243,6 +243,16 @@ extern "C" {
     // If it returns false, the computation is aborted
     typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
 
+    // Logits filter callback
+    // Can be used to modify the logits before sampling
+    // If not NULL, called after applying temperature to logits
+    typedef void (*whisper_logits_filter_callback)(
+            struct whisper_context * ctx,
+          const whisper_token_data * tokens,
+                               int   n_tokens,
+                             float * logits,
+                              void * user_data);
+
     // Parameters for the whisper_full() function
     // If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp:
     // whisper_full_default_params()
@@ -315,6 +325,10 @@ extern "C" {
         // called each time before the encoder starts
         whisper_encoder_begin_callback encoder_begin_callback;
         void * encoder_begin_callback_user_data;
+
+        // called by each decoder to filter obtained logits
+        whisper_logits_filter_callback logits_filter_callback;
+        void * logits_filter_callback_user_data;
     };
 
     WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);