]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ref #22 : add "duration" option
authorGeorgi Gerganov <redacted>
Mon, 7 Nov 2022 18:14:52 +0000 (20:14 +0200)
committerGeorgi Gerganov <redacted>
Mon, 7 Nov 2022 18:14:52 +0000 (20:14 +0200)
Can be used to partially process a recording

examples/main/main.cpp
whisper.cpp
whisper.h

index 5907e0b90136a84b74dc9d3051a7c6910a548689..0bac6da4596442382303fa9d193070a32e080d4d 100644 (file)
@@ -53,6 +53,7 @@ struct whisper_params {
     int32_t n_processors = 1;
     int32_t offset_t_ms  = 0;
     int32_t offset_n     = 0;
+    int32_t duration_ms  = 0;
     int32_t max_context  = -1;
     int32_t max_len      = 0;
 
@@ -95,6 +96,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             params.offset_t_ms = std::stoi(argv[++i]);
         } else if (arg == "-on" || arg == "--offset-n") {
             params.offset_n = std::stoi(argv[++i]);
+        } else if (arg == "-d" || arg == "--duration") {
+            params.duration_ms = std::stoi(argv[++i]);
         } else if (arg == "-mc" || arg == "--max-context") {
             params.max_context = std::stoi(argv[++i]);
         } else if (arg == "-ml" || arg == "--max-len") {
@@ -154,6 +157,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "  -p N,     --processors N   number of processors to use during computation (default: %d)\n", params.n_processors);
     fprintf(stderr, "  -ot N,    --offset-t N     time offset in milliseconds (default: %d)\n", params.offset_t_ms);
     fprintf(stderr, "  -on N,    --offset-n N     segment index offset (default: %d)\n", params.offset_n);
+    fprintf(stderr, "  -d  N,    --duration N     duration of audio to process in milliseconds (default: %d)\n", params.duration_ms);
     fprintf(stderr, "  -mc N,    --max-context N  maximum number of text context tokens to store (default: max)\n");
     fprintf(stderr, "  -ml N,    --max-len N      maximum segment length in characters (default: %d)\n", params.max_len);
     fprintf(stderr, "  -wt N,    --word-thold N   word timestamp probability threshold (default: %f)\n", params.word_thold);
@@ -532,6 +536,7 @@ int main(int argc, char ** argv) {
             wparams.n_threads            = params.n_threads;
             wparams.n_max_text_ctx       = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
             wparams.offset_ms            = params.offset_t_ms;
+            wparams.duration_ms          = params.duration_ms;
 
             wparams.token_timestamps     = params.output_wts || params.max_len > 0;
             wparams.thold_pt             = params.word_thold;
index 02ab5cbc8a63e451527f2b1902dc0457a52d9059..7078863aa3e28e68c2bdda802aa43d038bd31081 100644 (file)
@@ -2339,6 +2339,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                     /*.n_threads            =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
                     /*.n_max_text_ctx       =*/ 16384,
                     /*.offset_ms            =*/ 0,
+                    /*.duration_ms          =*/ 0,
 
                     /*.translate            =*/ false,
                     /*.no_context           =*/ false,
@@ -2376,6 +2377,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                     /*.n_threads            =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
                     /*.n_max_text_ctx       =*/ 16384,
                     /*.offset_ms            =*/ 0,
+                    /*.duration_ms          =*/ 0,
 
                     /*.translate            =*/ false,
                     /*.no_context           =*/ false,
@@ -2496,11 +2498,12 @@ int whisper_full(
     }
 
     const int seek_start = params.offset_ms/10;
+    const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10);
 
     // if length of spectrogram is less than 1s (100 samples), then return
     // basically don't process anything that is less than 1s
     // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
-    if (whisper_n_len(ctx) < 100 + seek_start) {
+    if (seek_end < 100 + seek_start) {
         return 0;
     }
 
@@ -2533,7 +2536,7 @@ int whisper_full(
     // main loop
     int seek = seek_start;
     while (true) {
-        int progress_cur = (100*seek)/whisper_n_len(ctx);
+        const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
         while (progress_cur >= progress_prev + progress_step) {
             progress_prev += progress_step;
             if (params.print_progress) {
@@ -2541,7 +2544,7 @@ int whisper_full(
             }
         }
 
-        if (seek + 100 >= whisper_n_len(ctx)) {
+        if (seek + 100 >= seek_end) {
             break;
         }
 
@@ -2622,7 +2625,7 @@ int whisper_full(
                 // end of text token
                 if (token.id == whisper_token_eot(ctx)) {
                     if (result_len == 0) {
-                        if (seek + seek_delta + 100 >= whisper_n_len(ctx)) {
+                        if (seek + seek_delta + 100 >= seek_end) {
                             result_len = i + 1;
                         } else {
                             // TODO: figure out how to resolve this
index 57ea5db8bf39578a43798fcbc13159bb93d7df3b..4c112f49c0ca8ca6fa8653ae9f3ff0ee852867a3 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -186,7 +186,8 @@ extern "C" {
 
         int n_threads;
         int n_max_text_ctx;
-        int offset_ms;
+        int offset_ms;      // start offset in ms
+        int duration_ms;    // audio duration to process in ms
 
         bool translate;
         bool no_context;