]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
node : add flash_attn param (#2170)
authorPedro Probst <redacted>
Mon, 20 May 2024 06:08:48 +0000 (03:08 -0300)
committerGitHub <redacted>
Mon, 20 May 2024 06:08:48 +0000 (09:08 +0300)
examples/addon.node/__test__/whisper.spec.js
examples/addon.node/addon.cpp
examples/addon.node/index.js

index 9ba86b6298542f37d3bc87acd24c5eb2a4dc533c..1ee888a1e009c862ce19d79f770378699a471d49 100644 (file)
@@ -12,6 +12,7 @@ const whisperParamsMock = {
   model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
   fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
   use_gpu: true,
+  flash_attn: false,
   no_prints: true,
   comma_in_time: false,
   translate: true,
index 8125e5dda4cf8ccf59582b157bc47d28a5bc8d41..53bf1abb5a3ed07704fd5a8a3374712788c3968a 100644 (file)
@@ -39,6 +39,7 @@ struct whisper_params {
     bool no_timestamps  = false;
     bool no_prints      = false;
     bool use_gpu        = true;
+    bool flash_attn     = false;
     bool comma_in_time  = true;
 
     std::string language = "en";
@@ -146,6 +147,7 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
 
     struct whisper_context_params cparams = whisper_context_default_params();
     cparams.use_gpu = params.use_gpu;
+    cparams.flash_attn = params.flash_attn;
     struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
 
     if (ctx == nullptr) {
@@ -326,6 +328,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
   std::string model = whisper_params.Get("model").As<Napi::String>();
   std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
   bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();
+  bool flash_attn = whisper_params.Get("flash_attn").As<Napi::Boolean>();
   bool no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
   bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
   int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
@@ -346,6 +349,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
   params.model = model;
   params.fname_inp.emplace_back(input);
   params.use_gpu = use_gpu;
+  params.flash_attn = flash_attn;
   params.no_prints = no_prints;
   params.no_timestamps = no_timestamps;
   params.audio_ctx = audio_ctx;
index 09b33c540240ba44b6841dae8ed61f72b25867c0..643ee756452de62f347eeace89a7ffbc4d8c7b3f 100644 (file)
@@ -12,6 +12,7 @@ const whisperParams = {
   model: path.join(__dirname, "../../models/ggml-base.en.bin"),
   fname_inp: path.join(__dirname, "../../samples/jfk.wav"),
   use_gpu: true,
+  flash_attn: false,
   no_prints: true,
   comma_in_time: false,
   translate: true,