]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
examples : Implement JSON output for Token-Level data in main (#1358)
authorAarni Koskela <redacted>
Tue, 31 Oct 2023 19:54:52 +0000 (21:54 +0200)
committerGitHub <redacted>
Tue, 31 Oct 2023 19:54:52 +0000 (19:54 +0000)
examples/main/main.cpp

index cdd16ac73a3b5971dc6100d23e835a611a3dd394..bed0789f9ad06ad4b4cd04d2b8bdff586abb1c98 100644 (file)
@@ -83,6 +83,7 @@ struct whisper_params {
     bool output_wts      = false;
     bool output_csv      = false;
     bool output_jsn      = false;
+    bool output_jsn_full = false;
     bool output_lrc      = false;
     bool print_special   = false;
     bool print_colors    = false;
@@ -151,6 +152,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-fp"   || arg == "--font-path")       { params.font_path       = argv[++i]; }
         else if (arg == "-ocsv" || arg == "--output-csv")      { params.output_csv      = true; }
         else if (arg == "-oj"   || arg == "--output-json")     { params.output_jsn      = true; }
+        else if (arg == "-ojf"  || arg == "--output-json-full"){ params.output_jsn_full = params.output_jsn = true; }
         else if (arg == "-of"   || arg == "--output-file")     { params.fname_out.emplace_back(argv[++i]); }
         else if (arg == "-ps"   || arg == "--print-special")   { params.print_special   = true; }
         else if (arg == "-pc"   || arg == "--print-colors")    { params.print_colors    = true; }
@@ -206,6 +208,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -fp,       --font-path         [%-7s] path to a monospace font for karaoke video\n",     params.font_path.c_str());
     fprintf(stderr, "  -ocsv,     --output-csv        [%-7s] output result in a CSV file\n",                    params.output_csv ? "true" : "false");
     fprintf(stderr, "  -oj,       --output-json       [%-7s] output result in a JSON file\n",                   params.output_jsn ? "true" : "false");
+    fprintf(stderr, "  -ojf,      --output-json-full  [%-7s] include more information in the JSON file\n",      params.output_jsn_full ? "true" : "false");
     fprintf(stderr, "  -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n",      "");
     fprintf(stderr, "  -ps,       --print-special     [%-7s] print special tokens\n",                           params.print_special ? "true" : "false");
     fprintf(stderr, "  -pc,       --print-colors      [%-7s] print colors\n",                                   params.print_colors ? "true" : "false");
@@ -511,7 +514,12 @@ bool output_score(struct whisper_context * ctx, const char * fname, const whispe
     return true;
 }
 
-bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
+bool output_json(
+             struct whisper_context * ctx,
+                         const char * fname,
+               const whisper_params & params,
+    std::vector<std::vector<float>>   pcmf32s,
+                               bool   full) {
     std::ofstream fout(fname);
     int indent = 0;
 
@@ -528,7 +536,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
     auto end_arr = [&](bool end) {
         indent--;
         doindent();
-        fout << (end ? "]\n" : "},\n");
+        fout << (end ? "]\n" : "],\n");
     };
 
     auto start_obj = [&](const char *name) {
@@ -569,12 +577,29 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
         end_value(end);
     };
 
+    auto value_f = [&](const char *name, const float val, bool end) {
+        start_value(name);
+        fout << val;
+        end_value(end);
+    };
+
     auto value_b = [&](const char *name, const bool val, bool end) {
         start_value(name);
         fout << (val ? "true" : "false");
         end_value(end);
     };
 
+    auto times_o = [&](int64_t t0, int64_t t1, bool end) {
+        start_obj("timestamps");
+        value_s("from", to_timestamp(t0, true).c_str(), false);
+        value_s("to", to_timestamp(t1, true).c_str(), true);
+        end_obj(false);
+        start_obj("offsets");
+        value_i("from", t0 * 10, false);
+        value_i("to", t1 * 10, true);
+        end_obj(end);
+    };
+
     if (!fout.is_open()) {
         fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
         return false;
@@ -620,15 +645,26 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
                 const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
                 start_obj(nullptr);
-                    start_obj("timestamps");
-                        value_s("from", to_timestamp(t0, true).c_str(), false);
-                        value_s("to", to_timestamp(t1, true).c_str(), true);
-                    end_obj(false);
-                    start_obj("offsets");
-                        value_i("from", t0 * 10, false);
-                        value_i("to", t1 * 10, true);
-                    end_obj(false);
-                    value_s("text", text, !params.diarize && !params.tinydiarize);
+                    times_o(t0, t1, false);
+                    value_s("text", text, !params.diarize && !params.tinydiarize && !full);
+
+                    if (full) {
+                        start_arr("tokens");
+                        const int n = whisper_full_n_tokens(ctx, i);
+                        for (int j = 0; j < n; ++j) {
+                            auto token = whisper_full_get_token_data(ctx, i, j);
+                            start_obj(nullptr);
+                                value_s("text", whisper_token_to_str(ctx, token.id), false);
+                                if(token.t0 > -1 && token.t1 > -1) {
+                                    // If we have per-token timestamps, write them out
+                                    times_o(token.t0, token.t1, false);
+                                }
+                                value_i("id", token.id, false);
+                                value_f("p", token.p, true);
+                            end_obj(j == (n - 1));
+                        }
+                        end_arr(!params.diarize && !params.tinydiarize);
+                    }
 
                     if (params.diarize && pcmf32s.size() == 2) {
                         value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
@@ -912,7 +948,7 @@ int main(int argc, char ** argv) {
             wparams.offset_ms        = params.offset_t_ms;
             wparams.duration_ms      = params.duration_ms;
 
-            wparams.token_timestamps = params.output_wts || params.max_len > 0;
+            wparams.token_timestamps = params.output_wts || params.output_jsn_full || params.max_len > 0;
             wparams.thold_pt         = params.word_thold;
             wparams.max_len          = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
             wparams.split_on_word    = params.split_on_word;
@@ -1012,7 +1048,7 @@ int main(int argc, char ** argv) {
             // output to JSON file
             if (params.output_jsn) {
                 const auto fname_jsn = fname_out + ".json";
-                output_json(ctx, fname_jsn.c_str(), params, pcmf32s);
+                output_json(ctx, fname_jsn.c_str(), params, pcmf32s, params.output_jsn_full);
             }
 
             // output to LRC file