]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
main : log probs to text file (#1205)
authorYunès <redacted>
Sun, 27 Aug 2023 16:09:06 +0000 (18:09 +0200)
committerGitHub <redacted>
Sun, 27 Aug 2023 16:09:06 +0000 (19:09 +0300)
* token/probability file generated with -ls

* code comment cleaning

* main : indentations

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/main/main.cpp

index 4fbc3f69ad22811a046c52714900da3f7fb1ed36..0affdab911678bf98fa803d0921216ff921979dd 100644 (file)
@@ -87,6 +87,7 @@ struct whisper_params {
     bool print_colors    = false;
     bool print_progress  = false;
     bool no_timestamps   = false;
+    bool log_score       = false;
 
     std::string language  = "en";
     std::string prompt;
@@ -159,6 +160,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-m"    || arg == "--model")           { params.model           = argv[++i]; }
         else if (arg == "-f"    || arg == "--file")            { params.fname_inp.emplace_back(argv[++i]); }
         else if (arg == "-oved" || arg == "--ov-e-device")     { params.openvino_encode_device = argv[++i]; }
+        else if (arg == "-ls"   || arg == "--log-score")       { params.log_score = true; }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             whisper_print_usage(argc, argv, params);
@@ -212,6 +214,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -m FNAME,  --model FNAME       [%-7s] model path\n",                                     params.model.c_str());
     fprintf(stderr, "  -f FNAME,  --file FNAME        [%-7s] input WAV file path\n",                            "");
     fprintf(stderr, "  -oved D,   --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n",  params.openvino_encode_device.c_str());
+    fprintf(stderr, "  -ls,       --log-score         [%-7s] log best decoder scores of tokens\n",              params.log_score?"true":"false");
     fprintf(stderr, "\n");
 }
 
@@ -486,6 +489,25 @@ bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_
     return true;
 }
 
+bool output_score(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
+    std::ofstream fout(fname);
+    fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
+
+    const int n_segments = whisper_full_n_segments(ctx);
+    // fprintf(stderr,"segments: %d\n",n_segments);
+    for (int i = 0; i < n_segments; ++i) {
+        const int n_tokens = whisper_full_n_tokens(ctx, i);
+        // fprintf(stderr,"tokens: %d\n",n_tokens);
+        for (int j = 0; j < n_tokens; j++) {
+            auto token = whisper_full_get_token_text(ctx, i, j);
+            auto probability = whisper_full_get_token_p(ctx, i, j);
+            fout << token << '\t' << probability << std::endl;
+            // fprintf(stderr,"token: %s %f\n",token,probability);
+           }
+    }
+    return true;
+}
+
 bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
     std::ofstream fout(fname);
     int indent = 0;
@@ -982,6 +1004,12 @@ int main(int argc, char ** argv) {
                 const auto fname_lrc = fname_out + ".lrc";
                 output_lrc(ctx, fname_lrc.c_str(), params, pcmf32s);
             }
+
+            // output to score file
+            if (params.log_score) {
+                const auto fname_score = fname_out + ".score.txt";
+                output_score(ctx, fname_score.c_str(), params, pcmf32s);
+            }
         }
     }