whisper_token tid_last;
std::vector<float> energy; // PCM signal energy
+ float no_speech_prob = 0.0f; // probability of the no-speech token, computed after the first decode (see below)
// [EXPERIMENTAL] Token-level timestamps with DTW
whisper_aheads_masks aheads_masks;
"♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
};
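+// compute the log-softmax of the logits in a numerically stable way:
+// logprobs[i] = logits[i] - log(sum_j(exp(logits[j]))), subtracting the max logit before exponentiating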
+static void whisper_compute_logprobs(
+ const std::vector<float> & logits,
+ const int n_logits,
+ std::vector<float> & logprobs) {
+ const float logit_max = *std::max_element(logits.begin(), logits.end());
+ float logsumexp = 0.0f;
+ for (int i = 0; i < n_logits; ++i) {
+ if (logits[i] > -INFINITY) {
+ logsumexp += expf(logits[i] - logit_max);
+ }
+ }
+ logsumexp = logf(logsumexp) + logit_max;
+
+ for (int i = 0; i < n_logits; ++i) {
+ if (logits[i] > -INFINITY) {
+ logprobs[i] = logits[i] - logsumexp;
+ } else {
+ logprobs[i] = -INFINITY;
+ }
+ }
+}
+
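+// convert logprobs to probabilities; tokens whose logits were suppressed (-INFINITY) get probability 0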
+static void whisper_compute_probs(
+ const std::vector<float> & logits,
+ const int n_logits,
+ const std::vector<float> & logprobs,
+ std::vector<float> & probs) {
+ for (int i = 0; i < n_logits; ++i) {
+ if (logits[i] == -INFINITY) {
+ probs[i] = 0.0f;
+ } else {
+ probs[i] = expf(logprobs[i]);
+ }
+ }
+}
+
// process the logits for the selected decoder
// - applies logit filters
// - computes logprobs and probs
// suppress sot and nosp tokens
logits[vocab.token_sot] = -INFINITY;
- logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now
+ logits[vocab.token_nosp] = -INFINITY;
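+ // note: no_speech_prob is computed from the unfiltered logits after the first decode, so this suppression does not affect it (see below)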
// [TDRZ] when tinydiarize is disabled, suppress solm token
if (params.tdrz_enable == false) {
}
// populate the logprobs array (log_softmax)
- {
- const float logit_max = *std::max_element(logits.begin(), logits.end());
- float logsumexp = 0.0f;
- for (int i = 0; i < n_logits; ++i) {
- if (logits[i] > -INFINITY) {
- logsumexp += expf(logits[i] - logit_max);
- }
- }
- logsumexp = logf(logsumexp) + logit_max;
-
- for (int i = 0; i < n_logits; ++i) {
- if (logits[i] > -INFINITY) {
- logprobs[i] = logits[i] - logsumexp;
- } else {
- logprobs[i] = -INFINITY;
- }
- }
- }
+ whisper_compute_logprobs(logits, n_logits, logprobs);
// if sum of probability over timestamps is above any other token, sample timestamp
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
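// i.e. if the logsumexp of the timestamp-token logprobs exceeds the highest logprob of any text token, restrict sampling to timestamp tokens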
}
// compute probs
- {
- for (int i = 0; i < n_logits; ++i) {
- if (logits[i] == -INFINITY) {
- probs[i] = 0.0f;
- } else {
- probs[i] = expf(logprobs[i]);
- }
- }
- }
+ whisper_compute_probs(logits, n_logits, logprobs, probs);
#if 0
// print first 100 logits - token string : logit
return -8;
}
+ // Calculate the no_speech probability after the first decode.
+ // This has to be done before any logit filtering, so we cannot reuse the probs computed in whisper_process_logits.
+ {
+ const int n_logits = ctx->vocab.id_to_token.size();
+ std::vector<float> logprobs(n_logits);
+ std::vector<float> probs(n_logits);
+
+ whisper_compute_logprobs(state->logits, n_logits, logprobs);
+ whisper_compute_probs(state->logits, n_logits, logprobs, probs);
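+ // softmax probability of the no-speech token at the first decoded position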
+ state->no_speech_prob = probs[whisper_token_nosp(ctx)];
+ }
+
{
const int64_t t_start_sample_us = ggml_time_us();
if (it != (int) temperatures.size() - 1) {
const auto & decoder = state->decoders[best_decoder_id];
- if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
- WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
+ if (decoder.failed ||
+ (decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) {
+ WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f and no_speech_prob %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold, state->no_speech_prob, params.no_speech_thold);
success = false;
state->n_fail_p++;
}
// [EXPERIMENTAL] Token-level timestamps with DTW
const auto n_segments_before = state->result_all.size();
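+ // treat the segment as silence when the no-speech probability is high and the average logprob is low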
+ const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold &&
+ best_decoder.sequence.avg_logprobs < params.logprob_thold);
+
//WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
// update prompt_past
prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
}
- for (int i = 0; i < result_len; ++i) {
+ for (int i = 0; i < result_len && !is_no_speech; ++i) {
prompt_past.push_back(tokens_cur[i].id);
}
- if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
+ if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {
int i0 = 0;
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));