]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add new-segment callback
authorGeorgi Gerganov <redacted>
Sat, 22 Oct 2022 18:06:50 +0000 (21:06 +0300)
committerGeorgi Gerganov <redacted>
Sat, 22 Oct 2022 18:17:21 +0000 (21:17 +0300)
Can be used to process new segments as they are being generated.
Sample usage in main, for printing the resulting segments during the
inference.

main.cpp
whisper.cpp
whisper.h

index 1fcc106617993562e1ff97b5a9ba9e9bd7e053f7..b06486019fe6f6e41d5e351132ee4f71c8690165 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -141,6 +141,55 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "\n");
 }
 
+void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) {
+    const whisper_params & params = *(whisper_params *) user_data;
+
+    const int n_segments = whisper_full_n_segments(ctx);
+
+    // print the last segment
+    const int i = n_segments - 1;
+    if (i == 0) {
+        printf("\n");
+    }
+
+    if (params.no_timestamps) {
+        if (params.print_colors) {
+            // TODO
+        } else {
+            const char * text = whisper_full_get_segment_text(ctx, i);
+            printf("%s", text);
+            fflush(stdout);
+        }
+    } else {
+        const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+        const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+        if (params.print_colors) {
+            printf("[%s --> %s]  ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
+            for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+                if (params.print_special_tokens == false) {
+                    const whisper_token id = whisper_full_get_token_id(ctx, i, j);
+                    if (id >= whisper_token_eot(ctx)) {
+                        continue;
+                    }
+                }
+
+                const char * text = whisper_full_get_token_text(ctx, i, j);
+                const float  p    = whisper_full_get_token_p   (ctx, i, j);
+
+                const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+
+                printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
+            }
+            printf("\n");
+        } else {
+            const char * text = whisper_full_get_segment_text(ctx, i);
+
+            printf("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+        }
+    }
+}
+
 bool output_txt(struct whisper_context * ctx, const char * fname) {
     std::ofstream fout(fname);
     if (!fout.is_open()) {
@@ -294,7 +343,7 @@ int main(int argc, char ** argv) {
         {
             whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 
-            wparams.print_realtime       = !params.print_colors;
+            wparams.print_realtime       = false;
             wparams.print_progress       = false;
             wparams.print_timestamps     = !params.no_timestamps;
             wparams.print_special_tokens = params.print_special_tokens;
@@ -303,49 +352,17 @@ int main(int argc, char ** argv) {
             wparams.n_threads            = params.n_threads;
             wparams.offset_ms            = params.offset_t_ms;
 
+            // this callback is called on each new segment
+            if (!wparams.print_realtime) {
+                wparams.new_segment_callback           = whisper_print_segment_callback;
+                wparams.new_segment_callback_user_data = &params;
+            }
+
             if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
                 fprintf(stderr, "%s: failed to process audio\n", argv[0]);
                 return 7;
             }
 
-            // print result
-            if (!wparams.print_realtime) {
-                printf("\n");
-
-                const int n_segments = whisper_full_n_segments(ctx);
-                for (int i = 0; i < n_segments; ++i) {
-                    if (params.no_timestamps) {
-                        if (params.print_colors) {
-                            // TODO
-                        } else {
-                            const char * text = whisper_full_get_segment_text(ctx, i);
-                            printf("%s", text);
-                            fflush(stdout);
-                        }
-                    } else {
-                        const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
-                        const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
-
-                        if (params.print_colors) {
-                            printf("[%s --> %s]  ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
-                            for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
-                                const char * text = whisper_full_get_token_text(ctx, i, j);
-                                const float  p    = whisper_full_get_token_p   (ctx, i, j);
-
-                                const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
-
-                                printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
-                            }
-                            printf("\n");
-                        } else {
-                            const char * text = whisper_full_get_segment_text(ctx, i);
-
-                            printf("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
-                        }
-                    }
-                }
-            }
-
             printf("\n");
 
             // output to text file
index 5c5f8bd32627e48e1708ea38e56acde033a5106c..01f6b00bed6c845e8be6ddb93b68ce7d5e732128 100644 (file)
@@ -2320,6 +2320,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                         /*.beam_width =*/ -1,
                         /*.n_best     =*/ -1,
                     },
+
+                    /*.new_segment_callback =*/ nullptr,
+                    /*.new_segment_callback_user_data =*/ nullptr,
                 };
             } break;
         case WHISPER_SAMPLING_BEAM_SEARCH:
@@ -2348,6 +2351,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                         /*.beam_width =*/ 10,
                         /*.n_best     =*/ 5,
                     },
+
+                    /*.new_segment_callback =*/ nullptr,
+                    /*.new_segment_callback_user_data =*/ nullptr,
                 };
             } break;
     }
@@ -2549,6 +2555,9 @@ int whisper_full(
                         for (int j = i0; j <= i; j++) {
                             result_all.back().tokens.push_back(tokens_cur[j]);
                         }
+                        if (params.new_segment_callback) {
+                            params.new_segment_callback(ctx, params.new_segment_callback_user_data);
+                        }
                     }
                     text = "";
                     while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
@@ -2576,6 +2585,9 @@ int whisper_full(
                 for (int j = i0; j < (int) tokens_cur.size(); j++) {
                     result_all.back().tokens.push_back(tokens_cur[j]);
                 }
+                if (params.new_segment_callback) {
+                    params.new_segment_callback(ctx, params.new_segment_callback_user_data);
+                }
             }
         }
 
@@ -2609,6 +2621,10 @@ const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_seg
     return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
 }
 
+whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
+    return ctx->result_all[i_segment].tokens[i_token].id;
+}
+
 float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
     return ctx->result_all[i_segment].tokens[i_token].p;
 }
index 3435cd7744d404e7e2c5e216319ef0a07cf348cc..53b0041055a3013bf6ed8939dcfa8cd63ba107c1 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -160,6 +160,11 @@ extern "C" {
         WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
     };
 
+    // Text segment callback
+    // Called on every newly generated text segment
+    // Use the whisper_full_...() functions to obtain the text segments
+    typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data);
+
     struct whisper_full_params {
         enum whisper_sampling_strategy strategy;
 
@@ -184,6 +189,9 @@ extern "C" {
             int beam_width;
             int n_best;
         } beam_search;
+
+        whisper_new_segment_callback new_segment_callback;
+        void * new_segment_callback_user_data;
     };
 
     WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
@@ -212,6 +220,7 @@ extern "C" {
 
     // Get the token text of the specified token in the specified segment.
     WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
 
     // Get the probability of the specified token in the specified segment.
     WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);