]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
node : add audio_ctx and audio buffer params (#2123)
authorPedro Probst <redacted>
Mon, 13 May 2024 12:22:23 +0000 (09:22 -0300)
committerGitHub <redacted>
Mon, 13 May 2024 12:22:23 +0000 (15:22 +0300)
* node : add audio_ctx param

* node : support passing audio buffer directly

* node : parse audio_ctx in index.js

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/addon.node/__test__/whisper.spec.js
examples/addon.node/addon.cpp
examples/addon.node/index.js

index 2f264fd3af5f1cf4e7b3a5f2da35fc5c654febbe..9ba86b6298542f37d3bc87acd24c5eb2a4dc533c 100644 (file)
@@ -16,6 +16,7 @@ const whisperParamsMock = {
   comma_in_time: false,
   translate: true,
   no_timestamps: false,
+  audio_ctx: 0,
 };
 
 describe("Run whisper.node", () => {
index 85576311ca939ad9f2761a3c0dfbbd35be92888d..8125e5dda4cf8ccf59582b157bc47d28a5bc8d41 100644 (file)
@@ -19,6 +19,7 @@ struct whisper_params {
     int32_t max_len      = 0;
     int32_t best_of      = 5;
     int32_t beam_size    = -1;
+    int32_t audio_ctx    = 0;
 
     float word_thold    = 0.01f;
     float entropy_thold = 2.4f;
@@ -46,6 +47,8 @@ struct whisper_params {
 
     std::vector<std::string> fname_inp = {};
     std::vector<std::string> fname_out = {};
+
+    std::vector<float> pcmf32 = {}; // mono-channel F32 PCM
 };
 
 struct whisper_print_user_data {
@@ -125,13 +128,12 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
 void cb_log_disable(enum ggml_log_level, const char *, void *) {}
 
 int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
-
     if (params.no_prints) {
         whisper_log_set(cb_log_disable, NULL);
     }
 
-    if (params.fname_inp.empty()) {
-        fprintf(stderr, "error: no input files specified\n");
+    if (params.fname_inp.empty() && params.pcmf32.empty()) {
+        fprintf(stderr, "error: no input files or audio buffer specified\n");
         return 2;
     }
 
@@ -151,6 +153,14 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
         return 3;
     }
 
+    // if params.pcmf32 is provided, set params.fname_inp to "buffer"
+    // this is simpler than further modifications in the code
+    if (!params.pcmf32.empty()) {
+        fprintf(stderr, "info: using audio buffer as input\n");
+        params.fname_inp.clear();
+        params.fname_inp.emplace_back("buffer");
+    }
+
     for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
         const auto fname_inp = params.fname_inp[f];
         const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
@@ -158,9 +168,14 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
         std::vector<float> pcmf32; // mono-channel F32 PCM
         std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
 
-        if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) {
-            fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
-            continue;
+        // read the input audio file if params.pcmf32 is not provided
+        if (params.pcmf32.empty()) {
+            if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) {
+                fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
+                continue;
+            }
+        } else {
+            pcmf32 = params.pcmf32;
         }
 
         // print system information
@@ -180,12 +195,13 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
                     fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
                 }
             }
-            fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
+            fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d, audio_ctx = %d ...\n",
                     __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
                     params.n_threads, params.n_processors,
                     params.language.c_str(),
                     params.translate ? "translate" : "transcribe",
-                    params.no_timestamps ? 0 : 1);
+                    params.no_timestamps ? 0 : 1,
+                    params.audio_ctx);
 
             fprintf(stderr, "\n");
         }
@@ -212,6 +228,7 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
             wparams.entropy_thold    = params.entropy_thold;
             wparams.logprob_thold    = params.logprob_thold;
             wparams.max_len          = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
+            wparams.audio_ctx        = params.audio_ctx;
 
             wparams.speed_up         = params.speed_up;
 
@@ -311,14 +328,28 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
   bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();
   bool no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
   bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
+  int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
   bool comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
 
+  Napi::Value pcmf32Value = whisper_params.Get("pcmf32");
+  std::vector<float> pcmf32_vec;
+  if (pcmf32Value.IsTypedArray()) {
+    Napi::Float32Array pcmf32 = pcmf32Value.As<Napi::Float32Array>();
+    size_t length = pcmf32.ElementLength();
+    pcmf32_vec.reserve(length);
+    for (size_t i = 0; i < length; i++) {
+      pcmf32_vec.push_back(pcmf32[i]);
+    }
+  }
+
   params.language = language;
   params.model = model;
   params.fname_inp.emplace_back(input);
   params.use_gpu = use_gpu;
   params.no_prints = no_prints;
   params.no_timestamps = no_timestamps;
+  params.audio_ctx = audio_ctx;
+  params.pcmf32 = pcmf32_vec;
   params.comma_in_time = comma_in_time;
 
   Napi::Function callback = info[1].As<Napi::Function>();
index 90bd6fc2ff4b71f043403f95fd41d26d491e1eff..09b33c540240ba44b6841dae8ed61f72b25867c0 100644 (file)
@@ -16,13 +16,20 @@ const whisperParams = {
   comma_in_time: false,
   translate: true,
   no_timestamps: false,
+  audio_ctx: 0,
 };
 
 const arguments = process.argv.slice(2);
 const params = Object.fromEntries(
   arguments.reduce((pre, item) => {
     if (item.startsWith("--")) {
-      return [...pre, item.slice(2).split("=")];
+      const [key, value] = item.slice(2).split("=");
+      if (key === "audio_ctx") {
+        whisperParams[key] = parseInt(value);
+      } else {
+        whisperParams[key] = value;
+      }
+      return pre;
     }
     return pre;
   }, [])