]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
main : make whisper_print_segment_callback() more readable (close #371)
authorGeorgi Gerganov <redacted>
Thu, 5 Jan 2023 19:45:05 +0000 (21:45 +0200)
committerGeorgi Gerganov <redacted>
Thu, 5 Jan 2023 19:45:05 +0000 (21:45 +0200)
examples/main/main.cpp

index 9310894752ffbc6d4a904f13cf7ecca563aa945c..ef903aa71a660cb14026dc84fc5ab604bda3713f 100644 (file)
@@ -176,90 +176,82 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
 
     const int n_segments = whisper_full_n_segments(ctx);
 
+    std::string speaker = "";
+
+    int64_t t0;
+    int64_t t1;
+
     // print the last n_new segments
     const int s0 = n_segments - n_new;
+
     if (s0 == 0) {
         printf("\n");
     }
 
     for (int i = s0; i < n_segments; i++) {
-        if (params.no_timestamps) {
-            if (params.print_colors) {
-                for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
-                    if (params.print_special == false) {
-                        const whisper_token id = whisper_full_get_token_id(ctx, i, j);
-                        if (id >= whisper_token_eot(ctx)) {
-                            continue;
-                        }
-                    }
+        if (!params.no_timestamps || params.diarize) {
+            t0 = whisper_full_get_segment_t0(ctx, i);
+            t1 = whisper_full_get_segment_t1(ctx, i);
+        }
 
-                    const char * text = whisper_full_get_token_text(ctx, i, j);
-                    const float  p    = whisper_full_get_token_p   (ctx, i, j);
+        if (!params.no_timestamps) {
+            printf("[%s --> %s]  ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
+        }
 
-                    const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+        if (params.diarize && pcmf32s.size() == 2) {
 
-                    printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
-                }
-            } else {
-                const char * text = whisper_full_get_segment_text(ctx, i);
-                printf("%s", text);
-            }
-            fflush(stdout);
-        } else {
-            const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
-            const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+            const int64_t n_samples = pcmf32s[0].size();
 
-            std::string speaker;
+            const int64_t is0 = timestamp_to_sample(t0, n_samples);
+            const int64_t is1 = timestamp_to_sample(t1, n_samples);
 
-            if (params.diarize && pcmf32s.size() == 2) {
-                const int64_t n_samples = pcmf32s[0].size();
+            double energy0 = 0.0f;
+            double energy1 = 0.0f;
 
-                const int64_t is0 = timestamp_to_sample(t0, n_samples);
-                const int64_t is1 = timestamp_to_sample(t1, n_samples);
+            for (int64_t j = is0; j < is1; j++) {
+                energy0 += fabs(pcmf32s[0][j]);
+                energy1 += fabs(pcmf32s[1][j]);
+            }
 
-                double energy0 = 0.0f;
-                double energy1 = 0.0f;
+            if (energy0 > 1.1*energy1) {
+                speaker = "(speaker 0)";
+            } else if (energy1 > 1.1*energy0) {
+                speaker = "(speaker 1)";
+            } else {
+                speaker = "(speaker ?)";
+            }
 
-                for (int64_t j = is0; j < is1; j++) {
-                    energy0 += fabs(pcmf32s[0][j]);
-                    energy1 += fabs(pcmf32s[1][j]);
-                }
+            //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
+        }
 
-                if (energy0 > 1.1*energy1) {
-                    speaker = "(speaker 0)";
-                } else if (energy1 > 1.1*energy0) {
-                    speaker = "(speaker 1)";
-                } else {
-                    speaker = "(speaker ?)";
+        if (params.print_colors) {
+            for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+                if (params.print_special == false) {
+                    const whisper_token id = whisper_full_get_token_id(ctx, i, j);
+                    if (id >= whisper_token_eot(ctx)) {
+                        continue;
+                    }
                 }
 
-                //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
-            }
-
-            if (params.print_colors) {
-                printf("[%s --> %s]  ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
-                for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
-                    if (params.print_special == false) {
-                        const whisper_token id = whisper_full_get_token_id(ctx, i, j);
-                        if (id >= whisper_token_eot(ctx)) {
-                            continue;
-                        }
-                    }
+                const char * text = whisper_full_get_token_text(ctx, i, j);
+                const float  p    = whisper_full_get_token_p   (ctx, i, j);
 
-                    const char * text = whisper_full_get_token_text(ctx, i, j);
-                    const float  p    = whisper_full_get_token_p   (ctx, i, j);
+                const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
 
-                    const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+                printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
+            }
+        } else {
+            const char * text = whisper_full_get_segment_text(ctx, i);
 
-                    printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
-                }
-                printf("\n");
-            } else {
-                const char * text = whisper_full_get_segment_text(ctx, i);
+            printf("%s%s", speaker.c_str(), text);
+        }
 
-                printf("[%s --> %s]  %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
-            }
+        // with timestamps or speakers: each segment on new line
+        if (!params.no_timestamps || params.diarize) {
+            printf("\n");
         }
+
+        fflush(stdout);
     }
 }