mutable std::mt19937 rng; // used for sampling at t > 0.0
- int lang_id;
+ int lang_id = 0; // english by default
// [EXPERIMENTAL] token-level timestamps data
- int64_t t_beg;
- int64_t t_last;
+ int64_t t_beg = 0;
+ int64_t t_last = 0;
whisper_token tid_last;
std::vector<float> energy; // PCM signal energy
// [EXPERIMENTAL] speed-up techniques
- int32_t exp_n_audio_ctx; // 0 - use default
+ int32_t exp_n_audio_ctx = 0; // 0 - use default
void use_buf(struct ggml_context * ctx, int i) {
#if defined(WHISPER_USE_SCRATCH)
MEM_REQ_SCRATCH3.at (model.type) +
scale*MEM_REQ_MODEL.at (model.type) +
scale*MEM_REQ_KV_CROSS.at(model.type) +
- scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
+ scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
// this is the memory required by one decoder
const size_t mem_required_decoder =
/*.language =*/ "en",
/*.suppress_blank =*/ true,
- /*.suppress_non_speech_tokens =*/true,
+ /*.suppress_non_speech_tokens =*/ false,
/*.temperature =*/ 0.0f,
/*.max_initial_ts =*/ 1.0f,
/*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr,
+
+ /*.logits_filter_callback =*/ nullptr,
+ /*.logits_filter_callback_user_data =*/ nullptr,
};
switch (strategy) {
return res;
}
-static const std::vector<std::string> non_speech_tokens
-{
+static const std::vector<std::string> non_speech_tokens = {
"\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
"_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
"---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
// - applies logit filters
// - computes logprobs and probs
static void whisper_process_logits(
- const struct whisper_context & ctx,
+ struct whisper_context & ctx,
const struct whisper_full_params params,
struct whisper_decoder & decoder,
float temperature) {
logits[vocab.token_translate] = -INFINITY;
logits[vocab.token_transcribe] = -INFINITY;
+ if (params.logits_filter_callback) {
+ params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
+ }
// suppress non-speech tokens
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
- if (params.suppress_non_speech_tokens)
- {
- for (const std::string &token : non_speech_tokens)
- {
- std::string suppress_tokens[] = {token, " " + token};
- for (const std::string &suppress_token : suppress_tokens)
- {
- if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end())
- {
+ if (params.suppress_non_speech_tokens) {
+ for (const std::string & token : non_speech_tokens) {
+ const std::string suppress_tokens[] = {token, " " + token};
+ for (const std::string & suppress_token : suppress_tokens) {
+ if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) {
logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
}
}
}
+
// allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
- if (vocab.token_to_id.find(" -") != vocab.token_to_id.end())
- {
+ if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) {
logits[vocab.token_to_id.at(" -")] = -INFINITY;
}
- if (vocab.token_to_id.find(" '") != vocab.token_to_id.end())
- {
+ if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) {
logits[vocab.token_to_id.at(" '")] = -INFINITY;
}
}
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
});
- unsigned int cur_c = 0;
+ uint32_t cur_c = 0;
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
}
int whisper_full_lang_id(struct whisper_context * ctx) {
- return ctx->lang_id;
+ return ctx->lang_id;
}
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
// If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
+ // Logits filter callback
+ // Can be used to modify the logits before sampling
+ // If not NULL, called after applying temperature to logits
+ typedef void (*whisper_logits_filter_callback)(
+ struct whisper_context * ctx,
+ const whisper_token_data * tokens,
+ int n_tokens,
+ float * logits,
+ void * user_data);
+
// Parameters for the whisper_full() function
// If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
// whisper_full_default_params()
// called each time before the encoder starts
whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;
+
+ // called by each decoder to filter obtained logits
+ whisper_logits_filter_callback logits_filter_callback;
+ void * logits_filter_callback_user_data;
};
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);