// the most basic sampling scheme - select the top token
whisper_vocab::id whisper_sample_best(
const whisper_vocab & vocab,
- const float * probs, bool need_timestamp) {
+ const float * probs) {
int n_logits = vocab.id_to_token.size();
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
probs_id.push_back(std::make_pair(probs[i], i));
}
- const int top_k = 4;
+ double sum_ts = 0.0;
+ double max_tx = 0.0;
+
+ for (int i = 0; i < vocab.token_beg; i++) {
+ max_tx = std::max(max_tx, probs_id[i].first);
+ }
+
+ for (int i = vocab.token_beg; i < n_logits; i++) {
+ sum_ts += probs_id[i].first;
+ }
+
+ // if the probability sum of all timestamp tokesn is higher than the max probability of the text tokens - sample a
+ // timestamp token
+ if (sum_ts > max_tx) {
+ // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
+ for (int i = 0; i < vocab.token_beg; i++) {
+ probs_id[i].first = -INFINITY;
+ }
+ }
// find the top K tokens
+ const int top_k = 4;
+
std::partial_sort(
probs_id.begin(),
probs_id.begin() + top_k, probs_id.end(),
// printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
//}
- if (need_timestamp) {
- // at the end of the 30-second audio segment, we start giving preference to time tokens
- for (int i = 0; i < top_k; i++) {
- if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > 0.01*probs_id[0].first) {
- return probs_id[i].second;
- }
- }
- }
-
int res = 0;
while ((probs_id[res].second == vocab.token_sot ||
probs_id[res].second == vocab.token_solm ||
return 0;
}
-whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp) {
+whisper_token whisper_sample_best(struct whisper_context * ctx) {
const int64_t t_start_sample_us = ggml_time_us();
// TODO: simplify
- auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), need_timestamp);
+ auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
whisper_token id = 0;
whisper_token tid = whisper_token_beg(ctx);
- id = whisper_sample_best(ctx, result_len == 0);
+ id = whisper_sample_best(ctx);
if (i > 0) {
tid = whisper_sample_timestamp(ctx);
}
// You can also implement your own sampling method using the whisper_get_probs() function.
// whisper_sample_best() returns the token with the highest probability
// whisper_sample_timestamp() returns the most probable timestamp token
- WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp);
+ WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
// Return the id of the specified language, returns -1 if not found