]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
wip : experimental color coding of tokens based on probabilities
authorGeorgi Gerganov <redacted>
Fri, 21 Oct 2022 14:33:59 +0000 (17:33 +0300)
committerGeorgi Gerganov <redacted>
Sat, 22 Oct 2022 18:17:21 +0000 (21:17 +0300)
main.cpp
whisper.cpp
whisper.h

index c9ac6699bc9420f89bdc2cd2c676735a905336ed..cf5ca950fbd7eeeb34d4df24b61cd61313fcfa76 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -5,12 +5,20 @@
 #define DR_WAV_IMPLEMENTATION
 #include "dr_wav.h"
 
+#include <cmath>
 #include <fstream>
 #include <cstdio>
 #include <string>
 #include <thread>
 #include <vector>
 
+// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
+// Lowest is red, middle is yellow, highest is green.
+const std::vector<std::string> k_colors = {
+    "\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m",
+    "\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m",
+};
+
 //  500 -> 00:05.000
 // 6000 -> 01:00.000
 std::string to_timestamp(int64_t t) {
@@ -41,6 +49,7 @@ struct whisper_params {
     bool output_vtt           = false;
     bool output_srt           = false;
     bool print_special_tokens = false;
+    bool print_colors         = false;
     bool no_timestamps        = false;
 
     std::string language  = "en";
@@ -87,6 +96,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             params.output_srt = true;
         } else if (arg == "-ps" || arg == "--print_special") {
             params.print_special_tokens = true;
+        } else if (arg == "-pc" || arg == "--print_colors") {
+            params.print_colors = true;
         } else if (arg == "-nt" || arg == "--no_timestamps") {
             params.no_timestamps = true;
         } else if (arg == "-m" || arg == "--model") {
@@ -122,6 +133,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "  -ovtt,    --output-vtt     output result in a vtt file\n");
     fprintf(stderr, "  -osrt,    --output-srt     output result in a srt file\n");
     fprintf(stderr, "  -ps,      --print_special  print special tokens\n");
+    fprintf(stderr, "  -pc,      --print_colors   print colors\n");
     fprintf(stderr, "  -nt,      --no_timestamps  do not print timestamps\n");
     fprintf(stderr, "  -l LANG,  --language LANG  spoken language (default: %s)\n", params.language.c_str());
     fprintf(stderr, "  -m FNAME, --model FNAME    model path (default: %s)\n", params.model.c_str());
@@ -222,7 +234,7 @@ int main(int argc, char ** argv) {
         {
             whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 
-            wparams.print_realtime       = true;
+            wparams.print_realtime       = !params.print_colors;
             wparams.print_progress       = false;
             wparams.print_timestamps     = !params.no_timestamps;
             wparams.print_special_tokens = params.print_special_tokens;
@@ -242,16 +254,34 @@ int main(int argc, char ** argv) {
 
                 const int n_segments = whisper_full_n_segments(ctx);
                 for (int i = 0; i < n_segments; ++i) {
-                    const char * text = whisper_full_get_segment_text(ctx, i);
-
                     if (params.no_timestamps) {
-                        printf("%s", text);
-                        fflush(stdout);
+                        if (params.print_colors) {
+                            // TODO
+                        } else {
+                            const char * text = whisper_full_get_segment_text(ctx, i);
+                            printf("%s", text);
+                            fflush(stdout);
+                        }
                     } else {
                         const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
                         const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
-                        printf("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+                        if (params.print_colors) {
+                            printf("[%s --> %s]  ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
+                            for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+                                const char * text = whisper_full_get_token_text(ctx, i, j);
+                                const float  p    = whisper_full_get_token_p   (ctx, i, j);
+
+                                const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+
+                                printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
+                            }
+                            printf("\n");
+                        } else {
+                            const char * text = whisper_full_get_segment_text(ctx, i);
+
+                            printf("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+                        }
                     }
                 }
             }
@@ -260,7 +290,6 @@ int main(int argc, char ** argv) {
 
             // output to text file
             if (params.output_txt) {
-
                 const auto fname_txt = fname_inp + ".txt";
                 std::ofstream fout_txt(fname_txt);
                 if (!fout_txt.is_open()) {
@@ -279,7 +308,6 @@ int main(int argc, char ** argv) {
 
             // output to VTT file
             if (params.output_vtt) {
-
                 const auto fname_vtt = fname_inp + ".vtt";
                 std::ofstream fout_vtt(fname_vtt);
                 if (!fout_vtt.is_open()) {
@@ -304,7 +332,6 @@ int main(int argc, char ** argv) {
 
             // output to SRT file
             if (params.output_srt) {
-
                 const auto fname_srt = fname_inp + ".srt";
                 std::ofstream fout_srt(fname_srt);
                 if (!fout_srt.is_open()) {
index 09250c06755ba96bd786836b958f755e8a0aee1e..5c5f8bd32627e48e1708ea38e56acde033a5106c 100644 (file)
@@ -210,9 +210,12 @@ struct whisper_vocab {
     }
 };
 
-struct whisper_result {
-    int64_t t;
-    whisper_token id;
+struct whisper_token_data {
+    whisper_token id;  // token id
+    whisper_token tid; // forced timestamp token id
+
+    float p;  // probability of the token
+    float pt; // probability of the timestamp token
 };
 
 struct whisper_segment {
@@ -220,6 +223,8 @@ struct whisper_segment {
     int64_t t1;
 
     std::string text;
+
+    std::vector<whisper_token_data> tokens;
 };
 
 // medium
@@ -407,7 +412,7 @@ struct whisper_context {
     std::vector<float> probs;
     std::vector<float> logits;
 
-    std::vector<whisper_result>  result_cur;
+    std::vector<whisper_token_data> tokens_cur;
     std::vector<whisper_segment> result_all;
 
     std::vector<whisper_token> prompt_past;
@@ -1786,9 +1791,11 @@ bool whisper_decode(
 }
 
 // the most basic sampling scheme - select the top token
-whisper_vocab::id whisper_sample_best(
+whisper_token_data whisper_sample_best(
         const whisper_vocab & vocab,
         const float * probs) {
+    whisper_token_data result;
+
     int n_logits = vocab.id_to_token.size();
 
     std::vector<std::pair<double, whisper_vocab::id>> probs_id;
@@ -1798,24 +1805,33 @@ whisper_vocab::id whisper_sample_best(
         probs_id.push_back(std::make_pair(probs[i], i));
     }
 
-    double sum_ts = 0.0;
-    double max_tx = 0.0;
+    {
+        double sum_ts =  0.0;
+        double max_ts = -1.0;
+        double max_tx = -1.0;
 
-    for (int i = 0; i < vocab.token_beg; i++) {
-        max_tx = std::max(max_tx, probs_id[i].first);
-    }
+        for (int i = 0; i < vocab.token_beg; i++) {
+            max_tx = std::max(max_tx, probs_id[i].first);
+        }
 
-    for (int i = vocab.token_beg; i < n_logits; i++) {
-        sum_ts += probs_id[i].first;
-    }
+        for (int i = vocab.token_beg; i < n_logits; i++) {
+            sum_ts += probs_id[i].first;
+            if  (probs_id[i].first > max_ts) {
+                max_ts = probs_id[i].first;
+                result.tid = probs_id[i].second;
+            }
+        }
 
-    // if the probability sum of all timestamp tokesn is higher than the max probability of the text tokens - sample a
-    // timestamp token
-    if (sum_ts > max_tx) {
-        // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
-        for (int i = 0; i < vocab.token_beg; i++) {
-            probs_id[i].first = -INFINITY;
+        // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
+        // timestamp token
+        if (sum_ts > max_tx) {
+            // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
+            for (int i = 0; i < vocab.token_beg; i++) {
+                probs_id[i].first = -INFINITY;
+            }
         }
+
+        result.pt = max_ts/(sum_ts + 1e-6);
     }
 
     // find the top K tokens
@@ -1843,7 +1859,10 @@ whisper_vocab::id whisper_sample_best(
         res++;
     }
 
-    return probs_id[res].second;
+    result.id = probs_id[res].second;
+    result.p  = probs_id[res].first;
+
+    return result;
 }
 
 // samples only from the timestamps tokens
@@ -2178,7 +2197,7 @@ whisper_token whisper_sample_best(struct whisper_context * ctx) {
 
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
 
-    return res;
+    return res.id;
 }
 
 whisper_token whisper_sample_timestamp(struct whisper_context * ctx) {
@@ -2343,7 +2362,7 @@ int whisper_full(
         int n_samples) {
     // clear old results
     auto & result_all = ctx->result_all;
-    auto & result_cur = ctx->result_cur;
+    auto & tokens_cur = ctx->tokens_cur;
 
     result_all.clear();
 
@@ -2430,7 +2449,7 @@ int whisper_full(
 
         // the accumulated transcription in the current interation
         int result_len = 0;
-        result_cur.clear();
+        tokens_cur.clear();
 
         for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) {
             if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
@@ -2449,28 +2468,26 @@ int whisper_full(
             // feel free to experiment!
             //
             {
-                whisper_token id  = 0;
-                whisper_token tid = whisper_token_beg(ctx);
+                auto token = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
 
-                id = whisper_sample_best(ctx);
-                if (i > 0) {
-                    tid = whisper_sample_timestamp(ctx);
+                if (i == 0) {
+                    token.tid = whisper_token_beg(ctx);
                 }
 
-                // update sliding window
-                if (id > whisper_token_beg(ctx)) {
-                    seek_delta = 2*(id - whisper_token_beg(ctx));
+                // timestamp token - update sliding window
+                if (token.id > whisper_token_beg(ctx)) {
+                    seek_delta = 2*(token.id - whisper_token_beg(ctx));
                     result_len = i + 1;
                 }
 
                 // add it to the context
-                prompt.push_back(id);
-                result_cur.push_back({ seek + 2*(tid - whisper_token_beg(ctx)), id });
+                prompt.push_back(token.id);
+                tokens_cur.push_back(token);
 
                 //printf("%s: %s\n", __func__, ctx->vocab.id_to_token[id].c_str());
 
                 // end of text token
-                if (id == whisper_token_eot(ctx)) {
+                if (token.id == whisper_token_eot(ctx)) {
                     if (result_len == 0) {
                         if (seek + seek_delta + 100 >= whisper_n_len(ctx)) {
                             result_len = i + 1;
@@ -2494,25 +2511,30 @@ int whisper_full(
             }
         }
 
-        result_cur.resize(result_len);
+        tokens_cur.resize(result_len);
 
-        for (const auto & r : result_cur) {
+        for (const auto & r : tokens_cur) {
             prompt_past.push_back(r.id);
         }
 
         // store the text from this iteration
-        if (result_cur.size() > 0) {
-            auto t0 = result_cur.front().t;
+        if (tokens_cur.size() > 0) {
+            int  i0 = 0;
+            auto t0 = 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
 
             std::string text = "";
 
-            for (int i = 0; i < (int) result_cur.size(); i++) {
-                if (params.print_special_tokens == false && result_cur[i].id >= whisper_token_eot(ctx)) {
+            for (int i = 0; i < (int) tokens_cur.size(); i++) {
+                //printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
+                //        ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
+                //        ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
+
+                if (params.print_special_tokens == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
                 } else {
-                    text += whisper_token_to_str(ctx, result_cur[i].id);
+                    text += whisper_token_to_str(ctx, tokens_cur[i].id);
                 }
-                if (result_cur[i].id > whisper_token_beg(ctx)) {
-                    const auto t1 = result_cur[i].t;
+                if (tokens_cur[i].id > whisper_token_beg(ctx)) {
+                    const auto t1 = 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
                     if (!text.empty()) {
                         if (params.print_realtime) {
                             if (params.print_timestamps) {
@@ -2523,14 +2545,18 @@ int whisper_full(
                             }
                         }
 
-                        result_all.push_back({ t0, t1, text });
+                        result_all.push_back({ t0, t1, text, {} });
+                        for (int j = i0; j <= i; j++) {
+                            result_all.back().tokens.push_back(tokens_cur[j]);
+                        }
                     }
                     text = "";
-                    while (i < (int) result_cur.size() && result_cur[i].id > whisper_token_beg(ctx)) {
+                    while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
                         i++;
                     }
                     i--;
-                    t0 = result_cur[i].t;
+                    t0 = t1;
+                    i0 = i + 1;
                 }
             }
 
@@ -2546,7 +2572,10 @@ int whisper_full(
                     }
                 }
 
-                result_all.push_back({ t0, t1, text });
+                result_all.push_back({ t0, t1, text, {} });
+                for (int j = i0; j < (int) tokens_cur.size(); j++) {
+                    result_all.back().tokens.push_back(tokens_cur[j]);
+                }
             }
         }
 
@@ -2571,3 +2600,15 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)
 const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
     return ctx->result_all[i_segment].text.c_str();
 }
+
+int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
+    return ctx->result_all[i_segment].tokens.size();
+}
+
+const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
+    return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
+}
+
+float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
+    return ctx->result_all[i_segment].tokens[i_token].p;
+}
index 4423674d1d28695b77ef8dfebbba44a8c28889c0..3435cd7744d404e7e2c5e216319ef0a07cf348cc 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -207,6 +207,15 @@ extern "C" {
     // Get the text of the specified segment.
     WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment);
 
+    // Get number of tokens in the specified segment.
+    WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment);
+
+    // Get the token text of the specified token in the specified segment.
+    WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
+
+    // Get the probability of the specified token in the specified segment.
+    WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
+
 #ifdef __cplusplus
 }
 #endif