]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
whisper : fix timestamp sampling
authorGeorgi Gerganov <redacted>
Tue, 18 Oct 2022 18:14:27 +0000 (21:14 +0300)
committerGeorgi Gerganov <redacted>
Tue, 18 Oct 2022 18:14:27 +0000 (21:14 +0300)
examples/whisper/whisper.cpp
examples/whisper/whisper.h
src/ggml.c

index 236fcf1dba26d37eb81b6a2485f118270b99f36e..2d2b8cedcb173644b4153dba698911dfac216a75 100644 (file)
@@ -1784,7 +1784,7 @@ bool whisper_decode(
 // the most basic sampling scheme - select the top token
 whisper_vocab::id whisper_sample_best(
         const whisper_vocab & vocab,
-        const float * probs, bool need_timestamp) {
+        const float * probs) {
     int n_logits = vocab.id_to_token.size();
 
     std::vector<std::pair<double, whisper_vocab::id>> probs_id;
@@ -1794,9 +1794,29 @@ whisper_vocab::id whisper_sample_best(
         probs_id.push_back(std::make_pair(probs[i], i));
     }
 
-    const int top_k = 4;
+    double sum_ts = 0.0;
+    double max_tx = 0.0;
+
+    for (int i = 0; i < vocab.token_beg; i++) {
+        max_tx = std::max(max_tx, probs_id[i].first);
+    }
+
+    for (int i = vocab.token_beg; i < n_logits; i++) {
+        sum_ts += probs_id[i].first;
+    }
+
+    // if the probability sum of all timestamp tokesn is higher than the max probability of the text tokens - sample a
+    // timestamp token
+    if (sum_ts > max_tx) {
+        // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
+        for (int i = 0; i < vocab.token_beg; i++) {
+            probs_id[i].first = -INFINITY;
+        }
+    }
 
     // find the top K tokens
+    const int top_k = 4;
+
     std::partial_sort(
             probs_id.begin(),
             probs_id.begin() + top_k, probs_id.end(),
@@ -1811,15 +1831,6 @@ whisper_vocab::id whisper_sample_best(
     //    printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
     //}
 
-    if (need_timestamp) {
-        // at the end of the 30-second audio segment, we start giving preference to time tokens
-        for (int i = 0; i < top_k; i++) {
-            if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > 0.01*probs_id[0].first) {
-                return probs_id[i].second;
-            }
-        }
-    }
-
     int res = 0;
     while ((probs_id[res].second == vocab.token_sot ||
             probs_id[res].second == vocab.token_solm ||
@@ -2155,11 +2166,11 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
     return 0;
 }
 
-whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp) {
+whisper_token whisper_sample_best(struct whisper_context * ctx) {
     const int64_t t_start_sample_us = ggml_time_us();
 
     // TODO: simplify
-    auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), need_timestamp);
+    auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
 
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
 
@@ -2437,7 +2448,7 @@ int whisper_full(
                 whisper_token id  = 0;
                 whisper_token tid = whisper_token_beg(ctx);
 
-                id = whisper_sample_best(ctx, result_len == 0);
+                id = whisper_sample_best(ctx);
                 if (i > 0) {
                     tid = whisper_sample_timestamp(ctx);
                 }
index 45faa5b2f220a2d662ff643f38f4afd53e97f546..4423674d1d28695b77ef8dfebbba44a8c28889c0 100644 (file)
@@ -120,7 +120,7 @@ extern "C" {
     // You can also implement your own sampling method using the whisper_get_probs() function.
     // whisper_sample_best() returns the token with the highest probability
     // whisper_sample_timestamp() returns the most probable timestamp token
-    WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp);
+    WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
 
     // Return the id of the specified language, returns -1 if not found
index 4861f24925396e85661b3c0b2b8515a49abb6f74..115e619b081f48927a90e0c37322508316fba308 100644 (file)
@@ -75,6 +75,9 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
 
 #include <immintrin.h>
 
+// FP16 <-> FP32
+// ref: https://github.com/Maratyszcza/FP16
+
 static inline float fp32_from_bits(uint32_t w) {
     union {
         uint32_t as_bits;