]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ruby : add VAD support, migration to Ruby's newer API (#3197)
authorKITAITI Makoto <redacted>
Wed, 28 May 2025 11:05:12 +0000 (20:05 +0900)
committerGitHub <redacted>
Wed, 28 May 2025 11:05:12 +0000 (20:05 +0900)
* Add VAD models

* Extract function to normalize model path from ruby_whisper_initialize()

* Define ruby_whisper_vad_params struct

* Add VAD-related features to Whisper::Params

* Add tests for VAD-related features

* Define Whisper::VADParams

* Add Whisper::VAD::Params attributes

* Add test suite for VAD::Params

* Make older test to follow namespace change

* Add test for transcription with VAD

* Add assertion for test_vad_params

* Add signatures for VAD-related methods

* Define VAD::Params#==

* Add test for VAD::Params#==

* Fix Params#vad_params

* Add test for Params#vad_params

* Fix signature of Params#vad_params

* Use macro to define VAD::Params params

* Define VAD::Params#initialize

* Add tests for VAD::Params#initialize

* Add signature for VAD::Params.new

* Add documentation on VAD in README

* Wrap register_callbask in prepare_transcription for clear meanings

* Set whisper_params.vad_params just before transcription

* Don't touch NULL

* Define ruby_whisper_params_type

* Use TypedData_XXX for ruby_whisper_params instead of Data_XXX

* Remove unused functions

* Define rb_whisper_model_data_type

* Use TypedData_XXX for ruby_whisper_model instead of Data_XXX

* Define ruby_whisper_segment_type

* Use TypedData_XXX for ruby_whisper_segment instead of Data_XXX

* Define ruby_whisper_type

* Use TypedData_XXX for ruby_whisper instead of Data_XXX

* Qualify with const

14 files changed:
bindings/ruby/README.md
bindings/ruby/ext/ruby_whisper.c
bindings/ruby/ext/ruby_whisper.h
bindings/ruby/ext/ruby_whisper_context.c
bindings/ruby/ext/ruby_whisper_model.c
bindings/ruby/ext/ruby_whisper_params.c
bindings/ruby/ext/ruby_whisper_segment.c
bindings/ruby/ext/ruby_whisper_transcribe.cpp
bindings/ruby/ext/ruby_whisper_vad_params.c [new file with mode: 0644]
bindings/ruby/lib/whisper/model/uri.rb
bindings/ruby/sig/whisper.rbs
bindings/ruby/tests/test_params.rb
bindings/ruby/tests/test_vad.rb [new file with mode: 0644]
bindings/ruby/tests/test_vad_params.rb [new file with mode: 0644]

index 7b1a7f29b45c755ab566db47154bb0d8ef90ea5f..208a89f32cf8b877c6fadc9733e0aecec85aed16 100644 (file)
@@ -111,6 +111,41 @@ See [models][] page for details.
 
 Currently, whisper.cpp accepts only 16-bit WAV files.
 
+### Voice Activity Detection (VAD) ###
+
+Support for Voice Activity Detection (VAD) can be enabled by setting `Whisper::Params`'s `vad` argument to `true` and specifying VAD model:
+
+```ruby
+Whisper::Params.new(
+  vad: true,
+  vad_model_path: "silero-v5.1.2",
+  # other arguments...
+)
+```
+
+When you pass the model name (`"silero-v5.1.2"`) or URI (`https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v5.1.2.bin`), it will be downloaded automatically.
+Currently, "silero-v5.1.2" is registered as pre-converted model like ASR models. You also specify file path or URI of model.
+
+If you need configure VAD behavior, pass params for that:
+
+```ruby
+Whisper::Params.new(
+  vad: true,
+  vad_model_path: "silero-v5.1.2",
+  vad_params: Whisper::VAD::Params.new(
+    threshold: 1.0, # defaults to 0.5
+    min_speech_duration_ms: 500, # defaults to 250
+    min_silence_duration_ms: 200, # defaults to 100
+    max_speech_duration_s: 30000, # default is FLT_MAX,
+    speech_pad_ms: 50, # defaults to 30
+    samples_overlap: 0.5 # defaults to 0.1
+  ),
+  # other arguments...
+)
+```
+
+For details on VAD, see [whisper.cpp's README](https://github.com/ggml-org/whisper.cpp?tab=readme-ov-file#voice-activity-detection-vad).
+
 API
 ---
 
index 4322778663bbbe912c4ccb070c71f089ab79bcb3..4a83aac9a96c7befa92382b285fd344be8675780 100644 (file)
@@ -3,8 +3,10 @@
 #include "ruby_whisper.h"
 
 VALUE mWhisper;
+VALUE mVAD;
 VALUE cContext;
 VALUE cParams;
+VALUE cVADParams;
 VALUE eError;
 
 VALUE cSegment;
@@ -31,6 +33,7 @@ extern void init_ruby_whisper_params(VALUE *mWhisper);
 extern void init_ruby_whisper_error(VALUE *mWhisper);
 extern void init_ruby_whisper_segment(VALUE *mWhisper, VALUE *cSegment);
 extern void init_ruby_whisper_model(VALUE *mWhisper);
+extern void init_ruby_whisper_vad_params(VALUE *mVAD);
 extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);
 
 /*
@@ -116,16 +119,6 @@ static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_d
   return Qnil;
 }
 
-static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
-  rb_gc_mark(rwm->context);
-}
-
-static VALUE ruby_whisper_model_allocate(VALUE klass) {
-  ruby_whisper_model *rwm;
-  rwm = ALLOC(ruby_whisper_model);
-  return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
-}
-
 void Init_whisper() {
   id_to_s = rb_intern("to_s");
   id_call = rb_intern("call");
@@ -139,6 +132,7 @@ void Init_whisper() {
   id_pre_converted_models = rb_intern("pre_converted_models");
 
   mWhisper = rb_define_module("Whisper");
+  mVAD = rb_define_module_under(mWhisper, "VAD");
 
   rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
   rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO));
@@ -159,6 +153,7 @@ void Init_whisper() {
   init_ruby_whisper_error(&mWhisper);
   init_ruby_whisper_segment(&mWhisper, &cContext);
   init_ruby_whisper_model(&mWhisper);
+  init_ruby_whisper_vad_params(&mVAD);
 
   rb_require("whisper/model/uri");
 }
index 6111a151784141797307ecad27ceca21a42bbb58..65b88122ccf29b4ed591198c1187b0cbfb6a7b3e 100644 (file)
@@ -21,8 +21,13 @@ typedef struct {
   ruby_whisper_callback_container *progress_callback_container;
   ruby_whisper_callback_container *encoder_begin_callback_container;
   ruby_whisper_callback_container *abort_callback_container;
+  VALUE vad_params;
 } ruby_whisper_params;
 
+typedef struct {
+  struct whisper_vad_params params;
+} ruby_whisper_vad_params;
+
 typedef struct {
   VALUE context;
   int index;
index df375218dfd9784af53a7557a2cbb18d5d92aa77..c498184e411d22f74ea75e6358afb0005180424f 100644 (file)
@@ -16,10 +16,11 @@ extern VALUE cContext;
 extern VALUE eError;
 extern VALUE cModel;
 
+extern const rb_data_type_t ruby_whisper_params_type;
 extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self);
 extern VALUE rb_whisper_model_initialize(VALUE context);
 extern VALUE rb_whisper_segment_initialize(VALUE context, int index);
-extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);
+extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context);
 
 static void
 ruby_whisper_free(ruby_whisper *rw)
@@ -37,19 +38,64 @@ rb_whisper_mark(ruby_whisper *rw)
 }
 
 void
-rb_whisper_free(ruby_whisper *rw)
+rb_whisper_free(void *p)
 {
+  ruby_whisper *rw = (ruby_whisper *)p;
   ruby_whisper_free(rw);
   free(rw);
 }
 
+static size_t
+ruby_whisper_memsize(const void *p)
+{
+  const ruby_whisper *rw = (const ruby_whisper *)p;
+  size_t size = sizeof(rw);
+  if (!rw) {
+    return 0;
+  }
+  return size;
+}
+
+const rb_data_type_t ruby_whisper_type = {
+  "ruby_whisper",
+  {0, rb_whisper_free, ruby_whisper_memsize,},
+  0, 0,
+  0
+};
+
 static VALUE
 ruby_whisper_allocate(VALUE klass)
 {
   ruby_whisper *rw;
-  rw = ALLOC(ruby_whisper);
+  VALUE obj = TypedData_Make_Struct(klass, ruby_whisper, &ruby_whisper_type, rw);
   rw->context = NULL;
-  return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
+  return obj;
+}
+
+VALUE
+ruby_whisper_normalize_model_path(VALUE model_path)
+{
+  VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
+  VALUE pre_converted_model = rb_hash_aref(pre_converted_models, model_path);
+  if (!NIL_P(pre_converted_model)) {
+    model_path = pre_converted_model;
+  }
+  else if (TYPE(model_path) == T_STRING) {
+    const char * model_path_str = StringValueCStr(model_path);
+    if (strncmp("http://", model_path_str, 7) == 0 || strncmp("https://", model_path_str, 8) == 0) {
+      VALUE uri_class = rb_const_get(cModel, id_URI);
+      model_path = rb_class_new_instance(1, &model_path, uri_class);
+    }
+  }
+  else if (rb_obj_is_kind_of(model_path, rb_path2class("URI::HTTP"))) {
+    VALUE uri_class = rb_const_get(cModel, id_URI);
+    model_path = rb_class_new_instance(1, &model_path, uri_class);
+  }
+  if (rb_respond_to(model_path, id_to_path)) {
+    model_path = rb_funcall(model_path, id_to_path, 0);
+  }
+
+  return model_path;
 }
 
 /*
@@ -66,27 +112,9 @@ ruby_whisper_initialize(int argc, VALUE *argv, VALUE self)
 
   // TODO: we can support init from buffer here too maybe another ruby object to expose
   rb_scan_args(argc, argv, "01", &whisper_model_file_path);
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
 
-  VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
-  VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path);
-  if (!NIL_P(pre_converted_model)) {
-    whisper_model_file_path = pre_converted_model;
-  }
-  if (TYPE(whisper_model_file_path) == T_STRING) {
-    const char * whisper_model_file_path_str = StringValueCStr(whisper_model_file_path);
-    if (strncmp("http://", whisper_model_file_path_str, 7) == 0 || strncmp("https://", whisper_model_file_path_str, 8) == 0) {
-      VALUE uri_class = rb_const_get(cModel, id_URI);
-      whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
-    }
-  }
-  if (rb_obj_is_kind_of(whisper_model_file_path, rb_path2class("URI::HTTP"))) {
-    VALUE uri_class = rb_const_get(cModel, id_URI);
-    whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
-  }
-  if (rb_respond_to(whisper_model_file_path, id_to_path)) {
-    whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0);
-  }
+  whisper_model_file_path = ruby_whisper_normalize_model_path(whisper_model_file_path);
   if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
     rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
   }
@@ -104,7 +132,7 @@ ruby_whisper_initialize(int argc, VALUE *argv, VALUE self)
 VALUE ruby_whisper_model_n_vocab(VALUE self)
 {
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_vocab(rw->context));
 }
 
@@ -115,7 +143,7 @@ VALUE ruby_whisper_model_n_vocab(VALUE self)
 VALUE ruby_whisper_model_n_audio_ctx(VALUE self)
 {
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_audio_ctx(rw->context));
 }
 
@@ -126,7 +154,7 @@ VALUE ruby_whisper_model_n_audio_ctx(VALUE self)
 VALUE ruby_whisper_model_n_audio_state(VALUE self)
 {
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_audio_state(rw->context));
 }
 
@@ -137,7 +165,7 @@ VALUE ruby_whisper_model_n_audio_state(VALUE self)
 VALUE ruby_whisper_model_n_audio_head(VALUE self)
 {
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_audio_head(rw->context));
 }
 
@@ -148,7 +176,7 @@ VALUE ruby_whisper_model_n_audio_head(VALUE self)
 VALUE ruby_whisper_model_n_audio_layer(VALUE self)
 {
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_audio_layer(rw->context));
 }
 
@@ -159,7 +187,7 @@ VALUE ruby_whisper_model_n_audio_layer(VALUE self)
 VALUE ruby_whisper_model_n_text_ctx(VALUE self)
 {
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_text_ctx(rw->context));
 }
 
@@ -170,7 +198,7 @@ VALUE ruby_whisper_model_n_text_ctx(VALUE self)
 VALUE ruby_whisper_model_n_text_state(VALUE self)
 {
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_text_state(rw->context));
 }
 
@@ -181,7 +209,7 @@ VALUE ruby_whisper_model_n_text_state(VALUE self)
 VALUE ruby_whisper_model_n_text_head(VALUE self)
 {
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_text_head(rw->context));
 }
 
@@ -192,7 +220,7 @@ VALUE ruby_whisper_model_n_text_head(VALUE self)
 VALUE ruby_whisper_model_n_text_layer(VALUE self)
 {
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_text_layer(rw->context));
 }
 
@@ -203,7 +231,7 @@ VALUE ruby_whisper_model_n_text_layer(VALUE self)
 VALUE ruby_whisper_model_n_mels(VALUE self)
 {
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_mels(rw->context));
 }
 
@@ -214,7 +242,7 @@ VALUE ruby_whisper_model_n_mels(VALUE self)
 VALUE ruby_whisper_model_ftype(VALUE self)
 {
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_ftype(rw->context));
 }
 
@@ -225,7 +253,7 @@ VALUE ruby_whisper_model_ftype(VALUE self)
 VALUE ruby_whisper_model_type(VALUE self)
 {
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   return rb_str_new2(whisper_model_type_readable(rw->context));
 }
 
@@ -248,9 +276,9 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
 
   ruby_whisper *rw;
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   VALUE params = argv[0];
-  Data_Get_Struct(params, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   VALUE samples = argv[1];
   int n_samples;
   rb_memory_view_t view;
@@ -296,7 +324,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
       }
     }
   }
-  register_callbacks(rwp, &self);
+  prepare_transcription(rwp, &self);
   const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
   if (0 == result) {
     return self;
@@ -327,9 +355,9 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
 
   ruby_whisper *rw;
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   VALUE params = argv[0];
-  Data_Get_Struct(params, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   VALUE samples = argv[1];
   int n_samples;
   int n_processors;
@@ -387,7 +415,7 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
       }
     }
   }
-  register_callbacks(rwp, &self);
+  prepare_transcription(rwp, &self);
   const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
   if (0 == result) {
     return self;
@@ -406,7 +434,7 @@ static VALUE
 ruby_whisper_full_n_segments(VALUE self)
 {
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_full_n_segments(rw->context));
 }
 
@@ -420,7 +448,7 @@ static VALUE
 ruby_whisper_full_lang_id(VALUE self)
 {
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_full_lang_id(rw->context));
 }
 
@@ -445,7 +473,7 @@ static VALUE
 ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment)
 {
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
   const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment);
   return INT2NUM(t0);
@@ -463,7 +491,7 @@ static VALUE
 ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment)
 {
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
   const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment);
   return INT2NUM(t1);
@@ -481,7 +509,7 @@ static VALUE
 ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment)
 {
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
   const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment);
   return speaker_turn_next ? Qtrue : Qfalse;
@@ -499,7 +527,7 @@ static VALUE
 ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment)
 {
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
   const char * text = whisper_full_get_segment_text(rw->context, c_i_segment);
   return rb_str_new2(text);
@@ -513,7 +541,7 @@ static VALUE
 ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment)
 {
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
   const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment);
   return DBL2NUM(no_speech_prob);
@@ -554,7 +582,7 @@ ruby_whisper_each_segment(VALUE self)
   }
 
   ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
 
   const int n_segments = whisper_full_n_segments(rw->context);
   for (int i = 0; i < n_segments; ++i) {
index 1e0648fd54bfed385385ec6a9e05b9cdb7c08947..54763c92da0d575a1355429a62c53bc759a22df2 100644 (file)
@@ -1,22 +1,44 @@
 #include <ruby.h>
 #include "ruby_whisper.h"
 
+extern const rb_data_type_t ruby_whisper_type;
+
 extern VALUE cModel;
 
-static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
-  rb_gc_mark(rwm->context);
+static void rb_whisper_model_mark(void *p) {
+  ruby_whisper_model *rwm = (ruby_whisper_model *)p;
+  if (rwm->context) {
+    rb_gc_mark(rwm->context);
+  }
+}
+
+static size_t
+ruby_whisper_model_memsize(const void *p)
+{
+  const ruby_whisper_model *rwm = (const ruby_whisper_model *)p;
+  size_t size = sizeof(rwm);
+  if (!rwm) {
+    return 0;
+  }
+  return size;
 }
 
+static const rb_data_type_t rb_whisper_model_type = {
+  "ruby_whisper_model",
+  {rb_whisper_model_mark, RUBY_DEFAULT_FREE, ruby_whisper_model_memsize,},
+  0, 0,
+  0
+};
+
 static VALUE ruby_whisper_model_allocate(VALUE klass) {
   ruby_whisper_model *rwm;
-  rwm = ALLOC(ruby_whisper_model);
-  return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
+  return TypedData_Make_Struct(klass, ruby_whisper_model, &rb_whisper_model_type, rwm);
 }
 
 VALUE rb_whisper_model_initialize(VALUE context) {
   ruby_whisper_model *rwm;
   const VALUE model = ruby_whisper_model_allocate(cModel);
-  Data_Get_Struct(model, ruby_whisper_model, rwm);
+  TypedData_Get_Struct(model, ruby_whisper_model, &rb_whisper_model_type, rwm);
   rwm->context = context;
   return model;
 };
@@ -29,9 +51,9 @@ static VALUE
 ruby_whisper_model_n_vocab(VALUE self)
 {
   ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
   ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_vocab(rw->context));
 }
 
@@ -43,9 +65,9 @@ static VALUE
 ruby_whisper_model_n_audio_ctx(VALUE self)
 {
   ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
   ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_audio_ctx(rw->context));
 }
 
@@ -57,9 +79,9 @@ static VALUE
 ruby_whisper_model_n_audio_state(VALUE self)
 {
   ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
   ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_audio_state(rw->context));
 }
 
@@ -71,9 +93,9 @@ static VALUE
 ruby_whisper_model_n_audio_head(VALUE self)
 {
   ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
   ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_audio_head(rw->context));
 }
 
@@ -85,9 +107,9 @@ static VALUE
 ruby_whisper_model_n_audio_layer(VALUE self)
 {
   ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
   ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_audio_layer(rw->context));
 }
 
@@ -99,9 +121,9 @@ static VALUE
 ruby_whisper_model_n_text_ctx(VALUE self)
 {
   ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
   ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_text_ctx(rw->context));
 }
 
@@ -113,9 +135,9 @@ static VALUE
 ruby_whisper_model_n_text_state(VALUE self)
 {
   ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
   ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_text_state(rw->context));
 }
 
@@ -127,9 +149,9 @@ static VALUE
 ruby_whisper_model_n_text_head(VALUE self)
 {
   ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
   ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_text_head(rw->context));
 }
 
@@ -141,9 +163,9 @@ static VALUE
 ruby_whisper_model_n_text_layer(VALUE self)
 {
   ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
   ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_text_layer(rw->context));
 }
 
@@ -155,9 +177,9 @@ static VALUE
 ruby_whisper_model_n_mels(VALUE self)
 {
   ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
   ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_n_mels(rw->context));
 }
 
@@ -169,9 +191,9 @@ static VALUE
 ruby_whisper_model_ftype(VALUE self)
 {
   ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
   ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
   return INT2NUM(whisper_model_ftype(rw->context));
 }
 
@@ -183,9 +205,9 @@ static VALUE
 ruby_whisper_model_type(VALUE self)
 {
   ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
   ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
   return rb_str_new2(whisper_model_type_readable(rw->context));
 }
 
index c07f2372f16722b0694a89388605d5889d842ac8..4a65c92a80d4c9884b880ac78d643ad4bcf7c765 100644 (file)
@@ -3,7 +3,7 @@
 
 #define BOOL_PARAMS_SETTER(self, prop, value) \
   ruby_whisper_params *rwp; \
-  Data_Get_Struct(self, ruby_whisper_params, rwp); \
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); \
   if (value == Qfalse || value == Qnil) { \
     rwp->params.prop = false; \
   } else { \
@@ -13,7 +13,7 @@
 
 #define BOOL_PARAMS_GETTER(self,  prop) \
   ruby_whisper_params *rwp; \
-  Data_Get_Struct(self, ruby_whisper_params, rwp); \
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); \
   if (rwp->params.prop) { \
     return Qtrue; \
   } else { \
   rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
   rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
 
-#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 32
+#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 35
 
 extern VALUE cParams;
+extern VALUE cVADParams;
 
 extern ID id_call;
 
+extern VALUE ruby_whisper_normalize_model_path(VALUE model_path);
 extern VALUE rb_whisper_segment_initialize(VALUE context, int index);
+extern const rb_data_type_t ruby_whisper_vad_params_type;
 
 static ID param_names[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT];
 static ID id_language;
@@ -67,6 +70,9 @@ static ID id_encoder_begin_callback;
 static ID id_encoder_begin_callback_user_data;
 static ID id_abort_callback;
 static ID id_abort_callback_user_data;
+static ID id_vad;
+static ID id_vad_model_path;
+static ID id_vad_params;
 
 static void
 rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc)
@@ -177,7 +183,7 @@ static bool abort_callback(void * user_data) {
   return false;
 }
 
-void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
+static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
   if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
     rwp->new_segment_callback_container->context = context;
     rwp->params.new_segment_callback = new_segment_callback;
@@ -203,13 +209,29 @@ void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
   }
 }
 
+static void set_vad_params(ruby_whisper_params *rwp)
+{
+  ruby_whisper_vad_params * rwvp;
+  TypedData_Get_Struct(rwp->vad_params, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
+  rwp->params.vad_params = rwvp->params;
+}
+
+void
+prepare_transcription(ruby_whisper_params *rwp, VALUE *context)
+{
+  register_callbacks(rwp, context);
+  set_vad_params(rwp);
+}
+
 void
-rb_whisper_params_mark(ruby_whisper_params *rwp)
+rb_whisper_params_mark(void *p)
 {
+  ruby_whisper_params *rwp = (ruby_whisper_params *)p;
   rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
   rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
   rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container);
   rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
+  rb_gc_mark(rwp->vad_params);
 }
 
 void
@@ -218,25 +240,46 @@ ruby_whisper_params_free(ruby_whisper_params *rwp)
 }
 
 void
-rb_whisper_params_free(ruby_whisper_params *rwp)
+rb_whisper_params_free(void *p)
 {
+  ruby_whisper_params *rwp = (ruby_whisper_params *)p;
   // How to free user_data and callback only when not referred to by others?
   ruby_whisper_params_free(rwp);
   free(rwp);
 }
 
+static size_t
+ruby_whisper_params_memsize(const void *p)
+{
+  const ruby_whisper_params *rwp = (const ruby_whisper_params *)p;
+
+  return sizeof(ruby_whisper_params) + sizeof(rwp->params) + sizeof(rwp->vad_params);
+}
+
+const rb_data_type_t ruby_whisper_params_type = {
+  "ruby_whisper_params",
+  {
+    rb_whisper_params_mark,
+    rb_whisper_params_free,
+    ruby_whisper_params_memsize,
+  },
+  0, 0,
+  0
+};
+
 static VALUE
 ruby_whisper_params_allocate(VALUE klass)
 {
   ruby_whisper_params *rwp;
-  rwp = ALLOC(ruby_whisper_params);
+  VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
   rwp->diarize = false;
+  rwp->vad_params = TypedData_Wrap_Struct(cVADParams, &ruby_whisper_vad_params_type, (void *)&rwp->params.vad_params);
   rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
   rwp->progress_callback_container = rb_whisper_callback_container_allocate();
   rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate();
   rwp->abort_callback_container = rb_whisper_callback_container_allocate();
-  return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
+  return obj;
 }
 
 /*
@@ -249,7 +292,7 @@ static VALUE
 ruby_whisper_params_set_language(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   if (value == Qfalse || value == Qnil) {
     rwp->params.language = "auto";
   } else {
@@ -265,7 +308,7 @@ static VALUE
 ruby_whisper_params_get_language(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   if (rwp->params.language) {
     return rb_str_new2(rwp->params.language);
   } else {
@@ -502,7 +545,7 @@ static VALUE
 ruby_whisper_params_get_initial_prompt(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   return rwp->params.initial_prompt == NULL ? Qnil : rb_str_new2(rwp->params.initial_prompt);
 }
 /*
@@ -513,7 +556,7 @@ static VALUE
 ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->params.initial_prompt = StringValueCStr(value);
   return value;
 }
@@ -527,7 +570,7 @@ static VALUE
 ruby_whisper_params_get_diarize(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   if (rwp->diarize) {
     return Qtrue;
   } else {
@@ -542,7 +585,7 @@ static VALUE
 ruby_whisper_params_set_diarize(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   if (value == Qfalse || value == Qnil) {
     rwp->diarize = false;
   } else {
@@ -561,7 +604,7 @@ static VALUE
 ruby_whisper_params_get_offset(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   return INT2NUM(rwp->params.offset_ms);
 }
 /*
@@ -572,7 +615,7 @@ static VALUE
 ruby_whisper_params_set_offset(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->params.offset_ms = NUM2INT(value);
   return value;
 }
@@ -586,7 +629,7 @@ static VALUE
 ruby_whisper_params_get_duration(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   return INT2NUM(rwp->params.duration_ms);
 }
 /*
@@ -597,7 +640,7 @@ static VALUE
 ruby_whisper_params_set_duration(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->params.duration_ms = NUM2INT(value);
   return value;
 }
@@ -612,7 +655,7 @@ static VALUE
 ruby_whisper_params_get_max_text_tokens(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   return INT2NUM(rwp->params.n_max_text_ctx);
 }
 /*
@@ -623,7 +666,7 @@ static VALUE
 ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->params.n_max_text_ctx = NUM2INT(value);
   return value;
 }
@@ -635,7 +678,7 @@ static VALUE
 ruby_whisper_params_get_temperature(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   return DBL2NUM(rwp->params.temperature);
 }
 /*
@@ -646,7 +689,7 @@ static VALUE
 ruby_whisper_params_set_temperature(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->params.temperature = RFLOAT_VALUE(value);
   return value;
 }
@@ -660,7 +703,7 @@ static VALUE
 ruby_whisper_params_get_max_initial_ts(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   return DBL2NUM(rwp->params.max_initial_ts);
 }
 /*
@@ -671,7 +714,7 @@ static VALUE
 ruby_whisper_params_set_max_initial_ts(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->params.max_initial_ts = RFLOAT_VALUE(value);
   return value;
 }
@@ -683,7 +726,7 @@ static VALUE
 ruby_whisper_params_get_length_penalty(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   return DBL2NUM(rwp->params.length_penalty);
 }
 /*
@@ -694,7 +737,7 @@ static VALUE
 ruby_whisper_params_set_length_penalty(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->params.length_penalty = RFLOAT_VALUE(value);
   return value;
 }
@@ -706,7 +749,7 @@ static VALUE
 ruby_whisper_params_get_temperature_inc(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   return DBL2NUM(rwp->params.temperature_inc);
 }
 /*
@@ -717,7 +760,7 @@ static VALUE
 ruby_whisper_params_set_temperature_inc(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->params.temperature_inc = RFLOAT_VALUE(value);
   return value;
 }
@@ -731,7 +774,7 @@ static VALUE
 ruby_whisper_params_get_entropy_thold(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   return DBL2NUM(rwp->params.entropy_thold);
 }
 /*
@@ -742,7 +785,7 @@ static VALUE
 ruby_whisper_params_set_entropy_thold(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->params.entropy_thold = RFLOAT_VALUE(value);
   return value;
 }
@@ -754,7 +797,7 @@ static VALUE
 ruby_whisper_params_get_logprob_thold(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   return DBL2NUM(rwp->params.logprob_thold);
 }
 /*
@@ -765,7 +808,7 @@ static VALUE
 ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->params.logprob_thold = RFLOAT_VALUE(value);
   return value;
 }
@@ -777,7 +820,7 @@ static VALUE
 ruby_whisper_params_get_no_speech_thold(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   return DBL2NUM(rwp->params.no_speech_thold);
 }
 /*
@@ -788,7 +831,7 @@ static VALUE
 ruby_whisper_params_set_no_speech_thold(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->params.no_speech_thold = RFLOAT_VALUE(value);
   return value;
 }
@@ -796,7 +839,7 @@ static VALUE
 ruby_whisper_params_get_new_segment_callback(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   return rwp->new_segment_callback_container->callback;
 }
 /*
@@ -813,7 +856,7 @@ static VALUE
 ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->new_segment_callback_container->callback = value;
   return value;
 }
@@ -821,7 +864,7 @@ static VALUE
 ruby_whisper_params_get_new_segment_callback_user_data(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   return rwp->new_segment_callback_container->user_data;
 }
 /*
@@ -834,7 +877,7 @@ static VALUE
 ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->new_segment_callback_container->user_data = value;
   return value;
 }
@@ -842,7 +885,7 @@ static VALUE
 ruby_whisper_params_get_progress_callback(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   return rwp->progress_callback_container->callback;
 }
 /*
@@ -861,7 +904,7 @@ static VALUE
 ruby_whisper_params_set_progress_callback(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->progress_callback_container->callback = value;
   return value;
 }
@@ -869,7 +912,7 @@ static VALUE
 ruby_whisper_params_get_progress_callback_user_data(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   return rwp->progress_callback_container->user_data;
 }
 /*
@@ -882,7 +925,7 @@ static VALUE
 ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->progress_callback_container->user_data = value;
   return value;
 }
@@ -891,7 +934,7 @@ static VALUE
 ruby_whisper_params_get_encoder_begin_callback(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   return rwp->encoder_begin_callback_container->callback;
 }
 
@@ -909,7 +952,7 @@ static VALUE
 ruby_whisper_params_set_encoder_begin_callback(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->encoder_begin_callback_container->callback = value;
   return value;
 }
@@ -918,7 +961,7 @@ static VALUE
 ruby_whisper_params_get_encoder_begin_callback_user_data(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   return rwp->encoder_begin_callback_container->user_data;
 }
 
@@ -932,7 +975,7 @@ static VALUE
 ruby_whisper_params_set_encoder_begin_callback_user_data(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->encoder_begin_callback_container->user_data = value;
   return value;
 }
@@ -941,7 +984,7 @@ static VALUE
 ruby_whisper_params_get_abort_callback(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   return rwp->abort_callback_container->callback;
 }
 /*
@@ -958,7 +1001,7 @@ static VALUE
 ruby_whisper_params_set_abort_callback(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->abort_callback_container->callback = value;
   return value;
 }
@@ -966,7 +1009,7 @@ static VALUE
 ruby_whisper_params_get_abort_callback_user_data(VALUE self)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   return rwp->abort_callback_container->user_data;
 }
 /*
@@ -979,11 +1022,74 @@ static VALUE
 ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value)
 {
   ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   rwp->abort_callback_container->user_data = value;
   return value;
 }
 
+/*
+ * call-seq:
+ *   vad = use_vad -> use_vad
+ */
+static VALUE
+ruby_whisper_params_get_vad(VALUE self)
+{
+  BOOL_PARAMS_GETTER(self, vad)
+}
+
+static VALUE
+ruby_whisper_params_set_vad(VALUE self, VALUE value)
+{
+  BOOL_PARAMS_SETTER(self, vad, value)
+}
+
+/*
+ * call-seq:
+ *   vad_model_path = model_path -> model_path
+ */
+static VALUE
+ruby_whisper_params_set_vad_model_path(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
+  if (NIL_P(value)) {
+    rwp->params.vad_model_path = NULL;
+    return value;
+  }
+  VALUE path = ruby_whisper_normalize_model_path(value);
+  rwp->params.vad_model_path = StringValueCStr(path);
+  return value;
+}
+
+static VALUE
+ruby_whisper_params_get_vad_model_path(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
+  return rwp->params.vad_model_path == NULL ? Qnil : rb_str_new2(rwp->params.vad_model_path);
+}
+
+/*
+ * call-seq:
+ *   vad_params = params -> params
+ */
+static VALUE
+ruby_whisper_params_set_vad_params(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
+  rwp->vad_params = value;
+  return value;
+}
+
+static VALUE
+ruby_whisper_params_get_vad_params(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
+  return rwp->vad_params;
+}
+
 #define SET_PARAM_IF_SAME(param_name) \
   if (id == id_ ## param_name) { \
     ruby_whisper_params_set_ ## param_name(self, value); \
@@ -993,7 +1099,6 @@ ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value)
 static VALUE
 ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
 {
-
   VALUE kw_hash;
   VALUE values[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT] = {Qundef};
   VALUE value;
@@ -1007,7 +1112,7 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
   }
 
   rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT, values);
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
 
   for (i = 0; i < RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT; i++) {
     id = param_names[i];
@@ -1050,6 +1155,9 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
       SET_PARAM_IF_SAME(encoder_begin_callback_user_data)
       SET_PARAM_IF_SAME(abort_callback)
       SET_PARAM_IF_SAME(abort_callback_user_data)
+      SET_PARAM_IF_SAME(vad)
+      SET_PARAM_IF_SAME(vad_model_path)
+      SET_PARAM_IF_SAME(vad_params)
     }
   }
 
@@ -1071,10 +1179,10 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
 static VALUE
 ruby_whisper_params_on_new_segment(VALUE self)
 {
-  ruby_whisper_params *rws;
-  Data_Get_Struct(self, ruby_whisper_params, rws);
+  ruby_whisper_params *rwp;
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   const VALUE blk = rb_block_proc();
-  rb_ary_push(rws->new_segment_callback_container->callbacks, blk);
+  rb_ary_push(rwp->new_segment_callback_container->callbacks, blk);
   return Qnil;
 }
 
@@ -1091,10 +1199,10 @@ ruby_whisper_params_on_new_segment(VALUE self)
 static VALUE
 ruby_whisper_params_on_progress(VALUE self)
 {
-  ruby_whisper_params *rws;
-  Data_Get_Struct(self, ruby_whisper_params, rws);
+  ruby_whisper_params *rwp;
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   const VALUE blk = rb_block_proc();
-  rb_ary_push(rws->progress_callback_container->callbacks, blk);
+  rb_ary_push(rwp->progress_callback_container->callbacks, blk);
   return Qnil;
 }
 
@@ -1111,10 +1219,10 @@ ruby_whisper_params_on_progress(VALUE self)
 static VALUE
 ruby_whisper_params_on_encoder_begin(VALUE self)
 {
-  ruby_whisper_params *rws;
-  Data_Get_Struct(self, ruby_whisper_params, rws);
+  ruby_whisper_params *rwp;
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   const VALUE blk = rb_block_proc();
-  rb_ary_push(rws->encoder_begin_callback_container->callbacks, blk);
+  rb_ary_push(rwp->encoder_begin_callback_container->callbacks, blk);
   return Qnil;
 }
 
@@ -1135,10 +1243,10 @@ ruby_whisper_params_on_encoder_begin(VALUE self)
 static VALUE
 ruby_whisper_params_abort_on(VALUE self)
 {
-  ruby_whisper_params *rws;
-  Data_Get_Struct(self, ruby_whisper_params, rws);
+  ruby_whisper_params *rwp;
+  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
   const VALUE blk = rb_block_proc();
-  rb_ary_push(rws->abort_callback_container->callbacks, blk);
+  rb_ary_push(rwp->abort_callback_container->callbacks, blk);
   return Qnil;
 }
 
@@ -1182,6 +1290,9 @@ init_ruby_whisper_params(VALUE *mWhisper)
   DEFINE_PARAM(encoder_begin_callback_user_data, 29)
   DEFINE_PARAM(abort_callback, 30)
   DEFINE_PARAM(abort_callback_user_data, 31)
+  DEFINE_PARAM(vad, 32)
+  DEFINE_PARAM(vad_model_path, 33)
+  DEFINE_PARAM(vad_params, 34)
 
   rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
   rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
index 3440ff95fba2cfca5705106ff364b22902415d2f..9399f2863d6e8e50bd940ef01ceee1b8162d6f6f 100644 (file)
@@ -1,20 +1,40 @@
 #include <ruby.h>
 #include "ruby_whisper.h"
 
+extern const rb_data_type_t ruby_whisper_type;
+
 extern VALUE cSegment;
 
 static void
-rb_whisper_segment_mark(ruby_whisper_segment *rws)
+rb_whisper_segment_mark(void *p)
 {
+  ruby_whisper_segment *rws = (ruby_whisper_segment *)p;
   rb_gc_mark(rws->context);
 }
 
+static size_t
+ruby_whisper_segment_memsize(const void *p)
+{
+  const ruby_whisper_segment *rws = (const ruby_whisper_segment *)p;
+  size_t size = sizeof(rws);
+  if (!rws) {
+    return 0;
+  }
+  return size;
+}
+
+static const rb_data_type_t ruby_whisper_segment_type = {
+  "ruby_whisper_segment",
+  {rb_whisper_segment_mark, RUBY_DEFAULT_FREE, ruby_whisper_segment_memsize,},
+  0, 0,
+  0
+};
+
 VALUE
 ruby_whisper_segment_allocate(VALUE klass)
 {
   ruby_whisper_segment *rws;
-  rws = ALLOC(ruby_whisper_segment);
-  return Data_Wrap_Struct(klass, rb_whisper_segment_mark, RUBY_DEFAULT_FREE, rws);
+  return TypedData_Make_Struct(klass, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
 }
 
 VALUE
@@ -22,7 +42,7 @@ rb_whisper_segment_initialize(VALUE context, int index)
 {
   ruby_whisper_segment *rws;
   const VALUE segment = ruby_whisper_segment_allocate(cSegment);
-  Data_Get_Struct(segment, ruby_whisper_segment, rws);
+  TypedData_Get_Struct(segment, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
   rws->context = context;
   rws->index = index;
   return segment;
@@ -38,9 +58,9 @@ static VALUE
 ruby_whisper_segment_get_start_time(VALUE self)
 {
   ruby_whisper_segment *rws;
-  Data_Get_Struct(self, ruby_whisper_segment, rws);
+  TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
   ruby_whisper *rw;
-  Data_Get_Struct(rws->context, ruby_whisper, rw);
+  TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
   const int64_t t0 = whisper_full_get_segment_t0(rw->context, rws->index);
   // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
   return INT2NUM(t0 * 10);
@@ -56,9 +76,9 @@ static VALUE
 ruby_whisper_segment_get_end_time(VALUE self)
 {
   ruby_whisper_segment *rws;
-  Data_Get_Struct(self, ruby_whisper_segment, rws);
+  TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
   ruby_whisper *rw;
-  Data_Get_Struct(rws->context, ruby_whisper, rw);
+  TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
   const int64_t t1 = whisper_full_get_segment_t1(rw->context, rws->index);
   // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
   return INT2NUM(t1 * 10);
@@ -74,9 +94,9 @@ static VALUE
 ruby_whisper_segment_get_speaker_turn_next(VALUE self)
 {
   ruby_whisper_segment *rws;
-  Data_Get_Struct(self, ruby_whisper_segment, rws);
+  TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
   ruby_whisper *rw;
-  Data_Get_Struct(rws->context, ruby_whisper, rw);
+  TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
   return whisper_full_get_segment_speaker_turn_next(rw->context, rws->index) ? Qtrue : Qfalse;
 }
 
@@ -88,9 +108,9 @@ static VALUE
 ruby_whisper_segment_get_text(VALUE self)
 {
   ruby_whisper_segment *rws;
-  Data_Get_Struct(self, ruby_whisper_segment, rws);
+  TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
   ruby_whisper *rw;
-  Data_Get_Struct(rws->context, ruby_whisper, rw);
+  TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
   const char * text = whisper_full_get_segment_text(rw->context, rws->index);
   return rb_str_new2(text);
 }
@@ -103,9 +123,9 @@ static VALUE
 ruby_whisper_segment_get_no_speech_prob(VALUE self)
 {
   ruby_whisper_segment *rws;
-  Data_Get_Struct(self, ruby_whisper_segment, rws);
+  TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
   ruby_whisper *rw;
-  Data_Get_Struct(rws->context, ruby_whisper, rw);
+  TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
   return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index));
 }
 
index ef3c0780f45c8a5dbb4d3297ef63bf7a9a4c1068..d12d2de96fe63490de4928df04f366702e063a13 100644 (file)
@@ -8,11 +8,14 @@
 extern "C" {
 #endif
 
+extern const rb_data_type_t ruby_whisper_type;
+extern const rb_data_type_t ruby_whisper_params_type;
+
 extern ID id_to_s;
 extern ID id_call;
 
 extern void
-register_callbacks(ruby_whisper_params * rwp, VALUE * self);
+prepare_transcription(ruby_whisper_params * rwp, VALUE * self);
 
 /*
  * transcribe a single file
@@ -34,8 +37,8 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
   VALUE wave_file_path, blk, params;
 
   rb_scan_args(argc, argv, "02&", &wave_file_path, &params, &blk);
-  Data_Get_Struct(self, ruby_whisper, rw);
-  Data_Get_Struct(params, ruby_whisper_params, rwp);
+  TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
+  TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
 
   if (!rb_respond_to(wave_file_path, id_to_s)) {
     rb_raise(rb_eRuntimeError, "Expected file path to wave file");
@@ -61,7 +64,7 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
   //   rwp->params.encoder_begin_callback_user_data = &is_aborted;
   // }
 
-  register_callbacks(rwp, &self);
+  prepare_transcription(rwp, &self);
 
   if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
     fprintf(stderr, "failed to process audio\n");
diff --git a/bindings/ruby/ext/ruby_whisper_vad_params.c b/bindings/ruby/ext/ruby_whisper_vad_params.c
new file mode 100644 (file)
index 0000000..be7bc46
--- /dev/null
@@ -0,0 +1,288 @@
+#include <ruby.h>
+#include "ruby_whisper.h"
+
+#define DEFINE_PARAM(param_name, nth) \
+  id_ ## param_name = rb_intern(#param_name); \
+  param_names[nth] = id_ ## param_name; \
+  rb_define_method(cVADParams, #param_name, ruby_whisper_vad_params_get_ ## param_name, 0); \
+  rb_define_method(cVADParams, #param_name "=", ruby_whisper_vad_params_set_ ## param_name, 1);
+
+#define NUM_PARAMS 6
+
+extern VALUE cVADParams;
+
+static size_t
+ruby_whisper_vad_params_memsize(const void *p)
+{
+  const struct ruby_whisper_vad_params *params = p;
+  size_t size = sizeof(params);
+  if (!params) {
+    return 0;
+  }
+  return size;
+}
+
+static ID param_names[NUM_PARAMS];
+static ID id_threshold;
+static ID id_min_speech_duration_ms;
+static ID id_min_silence_duration_ms;
+static ID id_max_speech_duration_s;
+static ID id_speech_pad_ms;
+static ID id_samples_overlap;
+
+const rb_data_type_t ruby_whisper_vad_params_type = {
+  "ruby_whisper_vad_params",
+  {0, 0, ruby_whisper_vad_params_memsize,},
+  0, 0,
+  0
+};
+
+static VALUE
+ruby_whisper_vad_params_s_allocate(VALUE klass)
+{
+  ruby_whisper_vad_params *rwvp;
+  VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
+  rwvp->params = whisper_vad_default_params();
+  return obj;
+}
+
+/*
+ * Probability threshold to consider as speech.
+ *
+ * call-seq:
+ *   threshold = th -> th
+ */
+static VALUE
+ruby_whisper_vad_params_set_threshold(VALUE self, VALUE value)
+{
+  ruby_whisper_vad_params *rwvp;
+  TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
+  rwvp->params.threshold = RFLOAT_VALUE(value);
+  return value;
+}
+
+static VALUE
+ruby_whisper_vad_params_get_threshold(VALUE self)
+{
+  ruby_whisper_vad_params *rwvp;
+  TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
+  return DBL2NUM(rwvp->params.threshold);
+}
+
+/*
+ * Min duration for a valid speech segment.
+ *
+ * call-seq:
+ *   min_speech_duration_ms = duration_ms -> duration_ms
+ */
+static VALUE
+ruby_whisper_vad_params_set_min_speech_duration_ms(VALUE self, VALUE value)
+{
+  ruby_whisper_vad_params *rwvp;
+  TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
+  rwvp->params.min_speech_duration_ms = NUM2INT(value);
+  return value;
+}
+
+static VALUE
+ruby_whisper_vad_params_get_min_speech_duration_ms(VALUE self)
+{
+  ruby_whisper_vad_params *rwvp;
+  TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
+  return INT2NUM(rwvp->params.min_speech_duration_ms);
+}
+
+/*
+ * Min silence duration to consider speech as ended.
+ *
+ * call-seq:
+ *   min_silence_duration_ms = duration_ms -> duration_ms
+ */
+static VALUE
+ruby_whisper_vad_params_set_min_silence_duration_ms(VALUE self, VALUE value)
+{
+  ruby_whisper_vad_params *rwvp;
+  TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
+  rwvp->params.min_silence_duration_ms = NUM2INT(value);
+  return value;
+}
+
+static VALUE
+ruby_whisper_vad_params_get_min_silence_duration_ms(VALUE self)
+{
+  ruby_whisper_vad_params *rwvp;
+  TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
+  return INT2NUM(rwvp->params.min_silence_duration_ms);
+}
+
+/*
+ * Max duration of a speech segment before forcing a new segment.
+ *
+ * call-seq:
+ *   max_speech_duration_s = duration_s -> duration_s
+ */
+static VALUE
+ruby_whisper_vad_params_set_max_speech_duration_s(VALUE self, VALUE value)
+{
+  ruby_whisper_vad_params *rwvp;
+  TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
+  rwvp->params.max_speech_duration_s = RFLOAT_VALUE(value);
+  return value;
+}
+
+static VALUE
+ruby_whisper_vad_params_get_max_speech_duration_s(VALUE self)
+{
+  ruby_whisper_vad_params *rwvp;
+  TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
+  return DBL2NUM(rwvp->params.max_speech_duration_s);
+}
+
+/*
+ * Padding added before and after speech segments.
+ *
+ * call-seq:
+ *   speech_pad_ms = pad_ms -> pad_ms
+ */
+static VALUE
+ruby_whisper_vad_params_set_speech_pad_ms(VALUE self, VALUE value)
+{
+  ruby_whisper_vad_params *rwvp;
+  TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
+  rwvp->params.speech_pad_ms = NUM2INT(value);
+  return value;
+}
+
+static VALUE
+ruby_whisper_vad_params_get_speech_pad_ms(VALUE self)
+{
+  ruby_whisper_vad_params *rwvp;
+  TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
+  return INT2NUM(rwvp->params.speech_pad_ms);
+}
+
+/*
+ * Overlap in seconds when copying audio samples from speech segment.
+ *
+ * call-seq:
+ *   samples_overlap = overlap -> overlap
+ */
+static VALUE
+ruby_whisper_vad_params_set_samples_overlap(VALUE self, VALUE value)
+{
+  ruby_whisper_vad_params *rwvp;
+  TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
+  rwvp->params.samples_overlap = RFLOAT_VALUE(value);
+  return value;
+}
+
+static VALUE
+ruby_whisper_vad_params_get_samples_overlap(VALUE self)
+{
+  ruby_whisper_vad_params *rwvp;
+  TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
+  return DBL2NUM(rwvp->params.samples_overlap);
+}
+
+static VALUE
+ruby_whisper_vad_params_equal(VALUE self, VALUE other)
+{
+  ruby_whisper_vad_params *rwvp1;
+  ruby_whisper_vad_params *rwvp2;
+
+  if (self == other) {
+    return Qtrue;
+  }
+
+  if (!rb_obj_is_kind_of(other, cVADParams)) {
+    return Qfalse;
+  }
+
+  TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp1);
+  TypedData_Get_Struct(other, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp2);
+
+  if (rwvp1->params.threshold != rwvp2->params.threshold) {
+    return Qfalse;
+  }
+  if (rwvp1->params.min_speech_duration_ms != rwvp2->params.min_speech_duration_ms) {
+    return Qfalse;
+  }
+  if (rwvp1->params.min_silence_duration_ms != rwvp2->params.min_silence_duration_ms) {
+    return Qfalse;
+  }
+  if (rwvp1->params.max_speech_duration_s != rwvp2->params.max_speech_duration_s) {
+    return Qfalse;
+  }
+  if (rwvp1->params.speech_pad_ms != rwvp2->params.speech_pad_ms) {
+    return Qfalse;
+  }
+  if (rwvp1->params.samples_overlap != rwvp2->params.samples_overlap) {
+    return Qfalse;
+  }
+
+  return Qtrue;
+}
+
+#define SET_PARAM_IF_SAME(param_name) \
+  if (id == id_ ## param_name) { \
+    ruby_whisper_vad_params_set_ ## param_name(self, value); \
+    continue; \
+  }
+
+VALUE
+ruby_whisper_vad_params_initialize(int argc, VALUE *argv, VALUE self)
+{
+  VALUE kw_hash;
+  VALUE values[NUM_PARAMS] = {Qundef};
+  VALUE value;
+  ruby_whisper_vad_params *rwvp;
+  ID id;
+  int i;
+
+  TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
+
+  rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
+  if (NIL_P(kw_hash)) {
+    return self;
+  }
+
+  rb_get_kwargs(kw_hash, param_names, 0, NUM_PARAMS, values);
+
+  for (i = 0; i < NUM_PARAMS; i++) {
+    id= param_names[i];
+    value = values[i];
+    if (value == Qundef) {
+      continue;
+    }
+    SET_PARAM_IF_SAME(threshold)
+    SET_PARAM_IF_SAME(min_speech_duration_ms)
+    SET_PARAM_IF_SAME(min_silence_duration_ms)
+    SET_PARAM_IF_SAME(max_speech_duration_s)
+    SET_PARAM_IF_SAME(speech_pad_ms)
+    SET_PARAM_IF_SAME(samples_overlap)
+  }
+
+  return self;
+}
+
+#undef SET_PARAM_IF_SAME
+
+void
+init_ruby_whisper_vad_params(VALUE *mVAD)
+{
+  cVADParams = rb_define_class_under(*mVAD, "Params", rb_cObject);
+  rb_define_alloc_func(cVADParams, ruby_whisper_vad_params_s_allocate);
+  rb_define_method(cVADParams, "initialize", ruby_whisper_vad_params_initialize, -1);
+
+  DEFINE_PARAM(threshold, 0)
+  DEFINE_PARAM(min_speech_duration_ms, 1)
+  DEFINE_PARAM(min_silence_duration_ms, 2)
+  DEFINE_PARAM(max_speech_duration_s, 3)
+  DEFINE_PARAM(speech_pad_ms, 4)
+  DEFINE_PARAM(samples_overlap, 5)
+
+  rb_define_method(cVADParams, "==", ruby_whisper_vad_params_equal, 1);
+}
+
+#undef DEFINE_PARAM
+#undef NUM_PARAMS
index 06e7a263570da440a3f5bb95bdcb96c6513a704c..fb3ee5db0a4dac43d5a6bf1d5f48b20b9b4d7fcf 100644 (file)
@@ -165,6 +165,12 @@ module Whisper
       models[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin")
     }
 
+    %w[
+      silero-v5.1.2
+    ].each do |name|
+      @pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-#{name}.bin")
+    end
+
     class << self
       attr_reader :pre_converted_models
     end
index a3ce94b8fde0d9acfdf900aff74e8b8584d538cb..c1373c878f2ef4f70a96009fb069947fd739920e 100644 (file)
@@ -150,7 +150,10 @@ module Whisper
       ?encoder_begin_callback: encoder_begin_callback,
       ?encoder_begin_callback_user_data: Object,
       ?abort_callback: abort_callback,
-      ?abort_callback_user_data: Object
+      ?abort_callback_user_data: Object,
+      ?vad: boolish,
+      ?vad_model_path: path | URI,
+      ?vad_params: Whisper::VAD::Params
     ) -> instance
 
     # params.language = "auto" | "en", etc...
@@ -338,6 +341,20 @@ module Whisper
 
     def abort_callback_user_data: () -> Object
 
+    # Enable VAD
+    #
+    def vad=: (boolish) -> boolish
+
+    def vad: () -> (true | false)
+
+    # Path to the VAD model
+    def vad_model_path=: (path | URI | nil) -> (path | URI | nil)
+
+    def vad_model_path: () -> (String | nil)
+
+    def vad_params=: (Whisper::VAD::Params) -> Whisper::VAD::Params
+    def vad_params: () -> (Whisper::VAD::Params)
+
     # Hook called on new segment. Yields each Whisper::Segment.
     #
     #   whisper.on_new_segment do |segment|
@@ -406,6 +423,55 @@ module Whisper
     def no_speech_prob: () -> Float
   end
 
+  module VAD
+    class Params
+      def self.new: (
+        ?threshold: Float,
+        ?min_speech_duration_ms: Integer,
+        ?min_silence_duration_ms: Integer,
+        ?max_speech_duration_s: Float,
+        ?speech_pad_ms: Integer,
+        ?samples_overlap: Float
+      ) -> instance
+
+      # Probability threshold to consider as speech.
+      #
+      def threshold=: (Float) -> Float
+
+      def threshold: () -> Float
+
+      # Min duration for a valid speech segment.
+      #
+      def min_speech_duration_ms=: (Integer) -> Integer
+
+      def min_speech_duration_ms: () -> Integer
+
+      # Min silence duration to consider speech as ended.
+      #
+      def min_silence_duration_ms=: (Integer) -> Integer
+
+      def min_silence_duration_ms: () -> Integer
+
+      # Max duration of a speech segment before forcing a new segment.
+      def max_speech_duration_s=: (Float) -> Float
+
+      def max_speech_duration_s: () -> Float
+
+      # Padding added before and after speech segments.
+      #
+      def speech_pad_ms=: (Integer) -> Integer
+
+      def speech_pad_ms: () -> Integer
+
+      # Overlap in seconds when copying audio samples from speech segment.
+      #
+      def samples_overlap=: (Float) -> Float
+
+      def samples_overlap: () -> Float
+      def ==: (Params) -> (true | false)
+    end
+  end
+
   class Error < StandardError
     attr_reader code: Integer
 
index 5f7fc387aa99d506984b41ae3d7a790f6e83d914..9a9535799b78afb9db844f0eac446c354ab03d85 100644 (file)
@@ -32,6 +32,9 @@ class TestParams < TestBase
     :progress_callback_user_data,
     :abort_callback,
     :abort_callback_user_data,
+    :vad,
+    :vad_model_path,
+    :vad_params,
   ]
 
   def setup
@@ -191,6 +194,50 @@ class TestParams < TestBase
     assert_in_delta 0.2, @params.no_speech_thold
   end
 
+  def test_vad
+    assert_false @params.vad
+    @params.vad = true
+    assert_true @params.vad
+  end
+
+  def test_vad_model_path
+    assert_nil @params.vad_model_path
+    @params.vad_model_path = "silero-v5.1.2"
+    assert_equal Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path, @params.vad_model_path
+  end
+
+  def test_vad_model_path_with_nil
+    @params.vad_model_path = "silero-v5.1.2"
+    @params.vad_model_path = nil
+    assert_nil @params.vad_model_path
+  end
+
+  def test_vad_model_path_with_invalid
+    assert_raise TypeError do
+      @params.vad_model_path = Object.new
+    end
+  end
+
+  def test_vad_model_path_with_URI_string
+    @params.vad_model_path = "https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v5.1.2.bin"
+    assert_equal @params.vad_model_path, Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path
+  end
+
+  def test_vad_model_path_with_URI
+    @params.vad_model_path = URI("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v5.1.2.bin")
+    assert_equal @params.vad_model_path, Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path
+  end
+
+  def test_vad_params
+    assert_kind_of Whisper::VAD::Params, @params.vad_params
+    default_params = @params.vad_params
+    assert_same default_params, @params.vad_params
+    assert_equal 0.5, default_params.threshold
+    new_params = Whisper::VAD::Params.new
+    @params.vad_params = new_params
+    assert_same new_params, @params.vad_params
+  end
+
   def test_new_with_kw_args
     params = Whisper::Params.new(language: "es")
     assert_equal "es", params.language
@@ -225,6 +272,10 @@ class TestParams < TestBase
               proc {}
             in [/_user_data\Z/, *]
               Object.new
+            in [:vad_model_path, *]
+              Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path
+            in [:vad_params, *]
+              Whisper::VAD::Params.new
             end
     params = Whisper::Params.new(param => value)
     if Float === value
diff --git a/bindings/ruby/tests/test_vad.rb b/bindings/ruby/tests/test_vad.rb
new file mode 100644 (file)
index 0000000..cb5e3c7
--- /dev/null
@@ -0,0 +1,19 @@
+require_relative "helper"
+
+class TestVAD < TestBase
+  def setup
+    @whisper = Whisper::Context.new("base.en")
+    vad_params = Whisper::VAD::Params.new
+    @params = Whisper::Params.new(
+      vad: true,
+      vad_model_path: "silero-v5.1.2",
+      vad_params:
+    )
+  end
+
+  def test_transcribe
+    @whisper.transcribe(TestBase::AUDIO, @params) do |text|
+      assert_match(/ask not what your country can do for you[,.] ask what you can do for your country/i, text)
+    end
+  end
+end
diff --git a/bindings/ruby/tests/test_vad_params.rb b/bindings/ruby/tests/test_vad_params.rb
new file mode 100644 (file)
index 0000000..add4899
--- /dev/null
@@ -0,0 +1,103 @@
+require_relative "helper"
+
+class TestVADParams < TestBase
+  PARAM_NAMES = [
+    :threshold,
+    :min_speech_duration_ms,
+    :min_silence_duration_ms,
+    :max_speech_duration_s,
+    :speech_pad_ms,
+    :samples_overlap
+  ]
+
+  def setup
+    @params = Whisper::VAD::Params.new
+  end
+
+  def test_new
+    params = Whisper::VAD::Params.new
+    assert_kind_of Whisper::VAD::Params, params
+  end
+
+  def test_threshold
+    assert_in_delta @params.threshold, 0.5
+    @params.threshold = 0.7
+    assert_in_delta @params.threshold, 0.7
+  end
+
+  def test_min_speech_duration
+    pend
+  end
+
+  def test_min_speech_duration_ms
+    assert_equal 250, @params.min_speech_duration_ms
+    @params.min_speech_duration_ms = 500
+    assert_equal 500, @params.min_speech_duration_ms
+  end
+
+  def test_min_silence_duration_ms
+    assert_equal 100, @params.min_silence_duration_ms
+    @params.min_silence_duration_ms = 200
+    assert_equal 200, @params.min_silence_duration_ms
+  end
+
+  def test_max_speech_duration
+    pend
+  end
+
+  def test_max_speech_duration_s
+    assert @params.max_speech_duration_s >= 10e37 # Defaults to FLT_MAX
+    @params.max_speech_duration_s = 60.0
+    assert_equal 60.0, @params.max_speech_duration_s
+  end
+
+  def test_speech_pad_ms
+    assert_equal 30, @params.speech_pad_ms
+    @params.speech_pad_ms = 50
+    assert_equal 50, @params.speech_pad_ms
+  end
+
+  def test_samples_overlap
+    assert_in_delta @params.samples_overlap, 0.1
+    @params.samples_overlap = 0.5
+    assert_in_delta @params.samples_overlap, 0.5
+  end
+
+  def test_equal
+    assert_equal @params, Whisper::VAD::Params.new
+  end
+
+  def test_new_with_kw_args
+    params = Whisper::VAD::Params.new(threshold: 0.7)
+    assert_in_delta params.threshold, 0.7
+    assert_equal 250, params.min_speech_duration_ms
+  end
+
+  def test_new_with_kw_args_non_existent
+    assert_raise ArgumentError do
+      Whisper::VAD::Params.new(non_existent: "value")
+    end
+  end
+
+  data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
+  def test_new_with_kw_args_default_values(param)
+    default_value = @params.send(param)
+    value = default_value + 1
+    params = Whisper::VAD::Params.new(param => value)
+    if Float === value
+      assert_in_delta value, params.send(param)
+    else
+      assert_equal value, params.send(param)
+    end
+
+    PARAM_NAMES.reject {|name| name == param}.each do |name|
+      expected = @params.send(name)
+      actual = params.send(name)
+      if Float === expected
+        assert_in_delta expected, actual
+      else
+        assert_equal expected, actual
+      end
+    end
+  end
+end