MEM_REQ_SCRATCH3.at (model.type) +
scale*MEM_REQ_MODEL.at (model.type) +
scale*MEM_REQ_KV_CROSS.at(model.type) +
- scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
+ scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
// this is the memory required by one decoder
const size_t mem_required_decoder =
/*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr,
+
+ /*.logits_filter_callback =*/ nullptr,
+ /*.logits_filter_callback_user_data =*/ nullptr,
};
switch (strategy) {
// - applies logit filters
// - computes logprobs and probs
static void whisper_process_logits(
- const struct whisper_context & ctx,
+ struct whisper_context & ctx,
const struct whisper_full_params params,
struct whisper_decoder & decoder,
float temperature) {
logits[vocab.token_translate] = -INFINITY;
logits[vocab.token_transcribe] = -INFINITY;
+ if (params.logits_filter_callback) {
+ params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
+ }
// suppress non-speech tokens
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
});
- unsigned int cur_c = 0;
+ uint32_t cur_c = 0;
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
// If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
+ // Logits filter callback
+ // Can be used to modify the logits before sampling
+ // If not NULL, called after applying temperature to logits
+ typedef void (*whisper_logits_filter_callback)(
+ struct whisper_context * ctx, // the whisper context the calling decoder step operates on
+ const whisper_token_data * tokens, // read-only token array supplied by the decoder (presumably the tokens decoded so far - confirm)
+ int n_tokens, // number of entries in tokens
+ float * logits, // writable logits buffer to filter in place; its length is not passed (presumably the vocabulary size - confirm)
+ void * user_data); // opaque pointer forwarded unchanged from logits_filter_callback_user_data
+
// Parameters for the whisper_full() function
// If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
// whisper_full_default_params()
// called each time before the encoder starts
whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;
+
+ // called by each decoder to filter obtained logits
+ whisper_logits_filter_callback logits_filter_callback;
+ void * logits_filter_callback_user_data;
};
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);