]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ruby : Add parallel transcription support (#3222)
authorKITAITI Makoto <redacted>
Wed, 4 Jun 2025 05:50:18 +0000 (14:50 +0900)
committerGitHub <redacted>
Wed, 4 Jun 2025 05:50:18 +0000 (14:50 +0900)
* Fix indentation of code sample in document comment

* Make Whisper::Context#transcribe able to run non-parallel

* Add test for Whisper::Context#transcribe with parallel option

* Follow signature API change of Context#transcribe

* Remove useless variable assignment

* Move simple usage up in README

* Add need help section in README

* Add document on Context#transcribe's parallel option in README

* Update date

* Fix signature of Context.new

* Make Context#transcribe accept n_processors option

* Make test follow #transcribe's change

* Make RBS follow #transcribe's change

* Add document for #transcribe's n_processors option

* Rename test directory so that Rake tasks' default setting is used

33 files changed:
bindings/ruby/README.md
bindings/ruby/Rakefile
bindings/ruby/ext/ruby_whisper.c
bindings/ruby/ext/ruby_whisper_context.c
bindings/ruby/ext/ruby_whisper_transcribe.cpp
bindings/ruby/sig/whisper.rbs
bindings/ruby/test/helper.rb [new file with mode: 0644]
bindings/ruby/test/jfk_reader/.gitignore [new file with mode: 0644]
bindings/ruby/test/jfk_reader/extconf.rb [new file with mode: 0644]
bindings/ruby/test/jfk_reader/jfk_reader.c [new file with mode: 0644]
bindings/ruby/test/test_callback.rb [new file with mode: 0644]
bindings/ruby/test/test_error.rb [new file with mode: 0644]
bindings/ruby/test/test_model.rb [new file with mode: 0644]
bindings/ruby/test/test_package.rb [new file with mode: 0644]
bindings/ruby/test/test_params.rb [new file with mode: 0644]
bindings/ruby/test/test_segment.rb [new file with mode: 0644]
bindings/ruby/test/test_vad.rb [new file with mode: 0644]
bindings/ruby/test/test_vad_params.rb [new file with mode: 0644]
bindings/ruby/test/test_whisper.rb [new file with mode: 0644]
bindings/ruby/tests/helper.rb [deleted file]
bindings/ruby/tests/jfk_reader/.gitignore [deleted file]
bindings/ruby/tests/jfk_reader/extconf.rb [deleted file]
bindings/ruby/tests/jfk_reader/jfk_reader.c [deleted file]
bindings/ruby/tests/test_callback.rb [deleted file]
bindings/ruby/tests/test_error.rb [deleted file]
bindings/ruby/tests/test_model.rb [deleted file]
bindings/ruby/tests/test_package.rb [deleted file]
bindings/ruby/tests/test_params.rb [deleted file]
bindings/ruby/tests/test_segment.rb [deleted file]
bindings/ruby/tests/test_vad.rb [deleted file]
bindings/ruby/tests/test_vad_params.rb [deleted file]
bindings/ruby/tests/test_whisper.rb [deleted file]
bindings/ruby/whispercpp.gemspec

index 6de00fb275be082918c29e620cbd701821f78b99..5ba88e6f8d5bc8afc192b42ad335ed392ff61022 100644 (file)
@@ -70,17 +70,6 @@ end
 
 Some models are prepared up-front:
 
-```ruby
-base_en = Whisper::Model.pre_converted_models["base.en"]
-whisper = Whisper::Context.new(base_en)
-```
-
-At first time you use a model, it is downloaded automatically. After that, downloaded cached file is used. To clear cache, call `#clear_cache`:
-
-```ruby
-Whisper::Model.pre_converted_models["base"].clear_cache
-```
-
 You also can use shorthand for pre-converted models:
 
 ```ruby
@@ -105,6 +94,19 @@ puts Whisper::Model.pre_converted_models.keys
 #   :
 ```
 
+You can also retrieve each model:
+
+```ruby
+base_en = Whisper::Model.pre_converted_models["base.en"]
+whisper = Whisper::Context.new(base_en)
+```
+
+The first time you use a model, it is downloaded automatically. After that, the cached file is used. To clear the cache, call `#clear_cache`:
+
+```ruby
+Whisper::Model.pre_converted_models["base"].clear_cache
+```
+
 You can also use local model files you prepared:
 
 ```ruby
@@ -163,6 +165,16 @@ For details on VAD, see [whisper.cpp's README](https://github.com/ggml-org/whisp
 API
 ---
 
+### Transcription ###
+
+By default, `Whisper::Context#transcribe` works in a single thread. You can make it work in parallel by passing the `n_processors` option:
+
+```ruby
+whisper.transcribe("path/to/audio.wav", params, n_processors: Etc.nprocessors)
+```
+
+Note that transcription accuracy may occasionally be lower when it runs in parallel.
+
 ### Segments ###
 
 Once `Whisper::Context#transcribe` called, you can retrieve segments by `#each_segment`:
@@ -297,6 +309,11 @@ First call of `rake test` builds an extension and downloads a model for testing.
 
 If something seems wrong on build, running `rake clean` solves some cases.
 
+### Need help ###
+
+* Windows support
+* Refinement of C/C++ code, especially memory management
+
 License
 -------
 
index bc6f843369b8d0a56aaaa59ed6269f739e49c3d6..08a2312a551cc48e5ab2869a32693845facfd347 100644 (file)
@@ -67,17 +67,15 @@ file LIB_FILE => [SO_FILE, "lib"] do |t|
 end
 CLEAN.include LIB_FILE
 
-Rake::TestTask.new do |t|
-  t.test_files = FileList["tests/test_*.rb"]
-end
+Rake::TestTask.new
 
-TEST_MEMORY_VIEW = "tests/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
-file TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c" do |t|
-  chdir "tests/jfk_reader" do
+TEST_MEMORY_VIEW = "test/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
+file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t|
+  chdir "test/jfk_reader" do
     ruby "extconf.rb"
     sh "make"
   end
 end
-CLEAN.include "tests/jfk_reader/jfk_reader.{o,#{RbConfig::CONFIG['DLEXT']}}"
+CLEAN.include "test/jfk_reader/jfk_reader.{o,#{RbConfig::CONFIG['DLEXT']}}"
 
 task test: [LIB_FILE, TEST_MEMORY_VIEW]
index e88aa29c05d9c38526e5779a1b406bc5162da9b3..a1c2c520512e1b3c363027267c817b388c34db3c 100644 (file)
@@ -24,6 +24,7 @@ ID id_URI;
 ID id_pre_converted_models;
 ID id_coreml_compiled_models;
 ID id_cache;
+ID id_n_processors;
 
 static bool is_log_callback_finalized = false;
 
@@ -142,6 +143,7 @@ void Init_whisper() {
   id_pre_converted_models = rb_intern("pre_converted_models");
   id_coreml_compiled_models = rb_intern("coreml_compiled_models");
   id_cache = rb_intern("cache");
+  id_n_processors = rb_intern("n_processors");
 
   mWhisper = rb_define_module("Whisper");
   mVAD = rb_define_module_under(mWhisper, "VAD");
index 75aa8dc906519d27037358b1e4cd281279447c28..cb58c8d4877908112bc210ec33bfedc598c653c8 100644 (file)
@@ -13,6 +13,7 @@ extern ID id_URI;
 extern ID id_pre_converted_models;
 extern ID id_coreml_compiled_models;
 extern ID id_cache;
+extern ID id_n_processors;
 
 extern VALUE cContext;
 extern VALUE eError;
@@ -24,6 +25,8 @@ extern VALUE rb_whisper_model_s_new(VALUE context);
 extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
 extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context);
 
+ID transcribe_option_names[1];
+
 static void
 ruby_whisper_free(ruby_whisper *rw)
 {
@@ -633,6 +636,8 @@ init_ruby_whisper_context(VALUE *mWhisper)
 {
   cContext = rb_define_class_under(*mWhisper, "Context", rb_cObject);
 
+  transcribe_option_names[0] = id_n_processors;
+
   rb_define_alloc_func(cContext, ruby_whisper_allocate);
   rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
 
index d12d2de96fe63490de4928df04f366702e063a13..71c4b49b45a34369644ab340449f1e0f99cf6764 100644 (file)
@@ -13,6 +13,7 @@ extern const rb_data_type_t ruby_whisper_params_type;
 
 extern ID id_to_s;
 extern ID id_call;
+extern ID transcribe_option_names[1];
 
 extern void
 prepare_transcription(ruby_whisper_params * rwp, VALUE * self);
@@ -34,9 +35,14 @@ VALUE
 ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
   ruby_whisper *rw;
   ruby_whisper_params *rwp;
-  VALUE wave_file_path, blk, params;
+  VALUE wave_file_path, blk, params, kws;
+  VALUE opts[1];
+
+  rb_scan_args_kw(RB_SCAN_ARGS_LAST_HASH_KEYWORDS, argc, argv, "2:&", &wave_file_path, &params, &kws, &blk);
+  rb_get_kwargs(kws, transcribe_option_names, 0, 1, opts);
+
+  int n_processors = opts[0] == Qundef ? 1 : NUM2INT(opts[0]);
 
-  rb_scan_args(argc, argv, "02&", &wave_file_path, &params, &blk);
   TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
   TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
 
@@ -66,7 +72,7 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
 
   prepare_transcription(rwp, &self);
 
-  if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
+  if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), n_processors) != 0) {
     fprintf(stderr, "failed to process audio\n");
     return self;
   }
@@ -76,9 +82,8 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
     const char * text = whisper_full_get_segment_text(rw->context, i);
     output = rb_str_concat(output, rb_str_new2(text));
   }
-  VALUE idCall = id_call;
   if (blk != Qnil) {
-    rb_funcall(blk, idCall, 1, output);
+    rb_funcall(blk, id_call, 1, output);
   }
   return self;
 }
index 6f8be29a66bfa5de22a9e2bdd993b4aeba3e1e60..f9d09631509fc9352f10a09772e5e883cac5b982 100644 (file)
@@ -25,19 +25,19 @@ module Whisper
   def self.system_info_str: () -> String
 
   class Context
-    def self.new: (path | ::URI::HTTP) -> instance
+    def self.new: (String | path | ::URI::HTTP) -> instance
 
     # transcribe a single file
     # can emit to a block results
     #
-    #   params = Whisper::Params.new
-    #   params.duration = 60_000
-    #   whisper.transcribe "path/to/audio.wav", params do |text|
-    #     puts text
-    #   end
+    #     params = Whisper::Params.new
+    #     params.duration = 60_000
+    #     whisper.transcribe "path/to/audio.wav", params do |text|
+    #       puts text
+    #     end
     #
-    def transcribe: (string, Params) -> self
-                  | (string, Params) { (String) -> void } -> self
+    def transcribe: (string, Params, ?n_processors: Integer) -> self
+                  | (string, Params, ?n_processors: Integer) { (String) -> void } -> self
 
     def model_n_vocab: () -> Integer
     def model_n_audio_ctx: () -> Integer
@@ -50,16 +50,16 @@ module Whisper
 
     # Yields each Whisper::Segment:
     #
-    #   whisper.transcribe("path/to/audio.wav", params)
-    #   whisper.each_segment do |segment|
-    #     puts segment.text
-    #   end
+    #     whisper.transcribe("path/to/audio.wav", params)
+    #     whisper.each_segment do |segment|
+    #       puts segment.text
+    #     end
     #
     # Returns an Enumerator if no block given:
     #
-    #   whisper.transcribe("path/to/audio.wav", params)
-    #   enum = whisper.each_segment
-    #   enum.to_a # => [#<Whisper::Segment>, ...]
+    #     whisper.transcribe("path/to/audio.wav", params)
+    #     enum = whisper.each_segment
+    #     enum.to_a # => [#<Whisper::Segment>, ...]
     #
     def each_segment: { (Segment) -> void } -> void
                     | () -> Enumerator[Segment]
@@ -74,25 +74,25 @@ module Whisper
 
     # Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
     #
-    #   full_get_segment_t0(3) # => 1668 (16680 ms)
+    #     full_get_segment_t0(3) # => 1668 (16680 ms)
     #
     def full_get_segment_t0: (Integer) -> Integer
 
     # End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
     #
-    #   full_get_segment_t1(3) # => 1668 (16680 ms)
+    #     full_get_segment_t1(3) # => 1668 (16680 ms)
     #
     def full_get_segment_t1: (Integer) -> Integer
 
     # Whether the next segment indexed by +segment_index+ is predicated as a speaker turn.
     #
-    #   full_get_segment_speacker_turn_next(3) # => true
+    #     full_get_segment_speaker_turn_next(3) # => true
     #
     def full_get_segment_speaker_turn_next: (Integer) -> (true | false)
 
     # Text of a segment indexed by +segment_index+.
     #
-    #   full_get_segment_text(3) # => "ask not what your country can do for you, ..."
+    #     full_get_segment_text(3) # => "ask not what your country can do for you, ..."
     #
     def full_get_segment_text: (Integer) -> String
 
@@ -282,9 +282,9 @@ module Whisper
 
     # Sets new segment callback, called for every newly generated text segment.
     #
-    #   params.new_segment_callback = ->(context, _, n_new, user_data) {
-    #     # ...
-    #   }
+    #     params.new_segment_callback = ->(context, _, n_new, user_data) {
+    #       # ...
+    #     }
     #
     def new_segment_callback=: (new_segment_callback) -> new_segment_callback
     def new_segment_callback: () -> (new_segment_callback | nil)
@@ -297,9 +297,9 @@ module Whisper
 
     # Sets progress callback, called on each progress update.
     #
-    #   params.new_segment_callback = ->(context, _, progress, user_data) {
-    #     # ...
-    #   }
+    #     params.progress_callback = ->(context, _, progress, user_data) {
+    #       # ...
+    #     }
     #
     # +progress+ is an Integer between 0 and 100.
     #
@@ -327,9 +327,9 @@ module Whisper
 
     # Sets abort callback, called to check if the process should be aborted.
     #
-    #   params.abort_callback = ->(user_data) {
-    #     # ...
-    #   }
+    #     params.abort_callback = ->(user_data) {
+    #       # ...
+    #     }
     #
     #
     def abort_callback=: (abort_callback) -> abort_callback
@@ -358,9 +358,9 @@ module Whisper
 
     # Hook called on new segment. Yields each Whisper::Segment.
     #
-    #   whisper.on_new_segment do |segment|
-    #     # ...
-    #   end
+    #     whisper.on_new_segment do |segment|
+    #       # ...
+    #     end
     #
     def on_new_segment: { (Segment) -> void } -> void
 
@@ -374,13 +374,13 @@ module Whisper
 
     # Call block to determine whether abort or not. Return +true+ when you want to abort.
     #
-    #   params.abort_on do
-    #     if some_condition
-    #       true # abort
-    #     else
-    #       false # continue
+    #     params.abort_on do
+    #       if some_condition
+    #         true # abort
+    #       else
+    #         false # continue
+    #       end
     #     end
-    #   end
     #
     def abort_on: { (Object user_data) -> boolish } -> void
   end
diff --git a/bindings/ruby/test/helper.rb b/bindings/ruby/test/helper.rb
new file mode 100644 (file)
index 0000000..389e15c
--- /dev/null
@@ -0,0 +1,24 @@
+require "test/unit"
+require "whisper"
+require_relative "jfk_reader/jfk_reader"
+
+class TestBase < Test::Unit::TestCase
+  AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
+
+  class << self
+    def whisper
+      return @whisper if @whisper
+
+      @whisper = Whisper::Context.new("base.en")
+      params = Whisper::Params.new
+      params.print_timestamps = false
+      @whisper.transcribe(TestBase::AUDIO, params)
+    end
+  end
+
+  private
+
+  def whisper
+    self.class.whisper
+  end
+end
diff --git a/bindings/ruby/test/jfk_reader/.gitignore b/bindings/ruby/test/jfk_reader/.gitignore
new file mode 100644 (file)
index 0000000..656da8d
--- /dev/null
@@ -0,0 +1,5 @@
+Makefile
+jfk_reader.o
+jfk_reader.so
+jfk_reader.bundle
+jfk_reader.dll
diff --git a/bindings/ruby/test/jfk_reader/extconf.rb b/bindings/ruby/test/jfk_reader/extconf.rb
new file mode 100644 (file)
index 0000000..0d842d0
--- /dev/null
@@ -0,0 +1,3 @@
+require "mkmf"
+
+create_makefile("jfk_reader")
diff --git a/bindings/ruby/test/jfk_reader/jfk_reader.c b/bindings/ruby/test/jfk_reader/jfk_reader.c
new file mode 100644 (file)
index 0000000..6657176
--- /dev/null
@@ -0,0 +1,68 @@
+#include <ruby.h>
+#include <ruby/memory_view.h>
+#include <ruby/encoding.h>
+
+static VALUE
+jfk_reader_initialize(VALUE self, VALUE audio_path)
+{
+  rb_iv_set(self, "audio_path", audio_path);
+  return Qnil;
+}
+
+static bool
+jfk_reader_get_memory_view(const VALUE obj, rb_memory_view_t *view, int flags)
+{
+  VALUE audio_path = rb_iv_get(obj, "audio_path");
+  const char *audio_path_str = StringValueCStr(audio_path);
+  const int n_samples = 176000;
+  float *data = (float *)malloc(n_samples * sizeof(float));
+  short *samples = (short *)malloc(n_samples * sizeof(short));
+  FILE *file = fopen(audio_path_str, "rb");
+
+  fseek(file, 78, SEEK_SET);
+  fread(samples, sizeof(short), n_samples, file);
+  fclose(file);
+  for (int i = 0; i < n_samples; i++) {
+    data[i] = samples[i]/32768.0;
+  }
+
+  view->obj = obj;
+  view->data = (void *)data;
+  view->byte_size = sizeof(float) * n_samples;
+  view->readonly = true;
+  view->format = "f";
+  view->item_size = sizeof(float);
+  view->item_desc.components = NULL;
+  view->item_desc.length = 0;
+  view->ndim = 1;
+  view->shape = NULL;
+  view->sub_offsets = NULL;
+  view->private_data = NULL;
+
+  return true;
+}
+
+static bool
+jfk_reader_release_memory_view(const VALUE obj, rb_memory_view_t *view)
+{
+  return true;
+}
+
+static bool
+jfk_reader_memory_view_available_p(const VALUE obj)
+{
+  return true;
+}
+
+static const rb_memory_view_entry_t jfk_reader_view_entry = {
+  jfk_reader_get_memory_view,
+  jfk_reader_release_memory_view,
+  jfk_reader_memory_view_available_p
+};
+
+void Init_jfk_reader(void)
+{
+  VALUE cJFKReader = rb_define_class("JFKReader", rb_cObject);
+  rb_memory_view_register(cJFKReader, &jfk_reader_view_entry);
+  rb_define_method(cJFKReader, "initialize", jfk_reader_initialize, 1);
+}
diff --git a/bindings/ruby/test/test_callback.rb b/bindings/ruby/test/test_callback.rb
new file mode 100644 (file)
index 0000000..a7f4924
--- /dev/null
@@ -0,0 +1,202 @@
+require_relative "helper"
+
+class TestCallback < TestBase
+  def setup
+    GC.start
+    @params = Whisper::Params.new
+    @whisper = Whisper::Context.new("base.en")
+    @audio = File.join(AUDIO)
+  end
+
+  def test_new_segment_callback
+    @params.new_segment_callback = ->(context, state, n_new, user_data) {
+      assert_kind_of Integer, n_new
+      assert n_new > 0
+      assert_same @whisper, context
+
+      n_segments = context.full_n_segments
+      n_new.times do |i|
+        i_segment = n_segments - 1 + i
+        start_time = context.full_get_segment_t0(i_segment) * 10
+        end_time = context.full_get_segment_t1(i_segment) * 10
+        text = context.full_get_segment_text(i_segment)
+
+        assert_kind_of Integer, start_time
+        assert start_time >= 0
+        assert_kind_of Integer, end_time
+        assert end_time > 0
+        assert_match(/ask not what your country can do for you, ask what you can do for your country/, text) if i_segment == 0
+      end
+    }
+
+    @whisper.transcribe(@audio, @params)
+  end
+
+  def test_new_segment_callback_closure
+    search_word = "what"
+    @params.new_segment_callback = ->(context, state, n_new, user_data) {
+      n_segments = context.full_n_segments
+      n_new.times do |i|
+        i_segment = n_segments - 1 + i
+        text = context.full_get_segment_text(i_segment)
+        if text.include?(search_word)
+          t0 = context.full_get_segment_t0(i_segment)
+          t1 = context.full_get_segment_t1(i_segment)
+          raise "search word '#{search_word}' found at between #{t0} and #{t1}"
+        end
+      end
+    }
+
+    assert_raise RuntimeError do
+      @whisper.transcribe(@audio, @params)
+    end
+  end
+
+  def test_new_segment_callback_user_data
+    udata = Object.new
+    @params.new_segment_callback_user_data = udata
+    @params.new_segment_callback = ->(context, state, n_new, user_data) {
+      assert_same udata, user_data
+    }
+
+    @whisper.transcribe(@audio, @params)
+  end
+
+  def test_new_segment_callback_user_data_gc
+    @params.new_segment_callback_user_data = "My user data"
+    @params.new_segment_callback = ->(context, state, n_new, user_data) {
+      assert_equal "My user data", user_data
+    }
+    GC.start
+
+    assert_same @whisper, @whisper.transcribe(@audio, @params)
+  end
+
+  def test_progress_callback
+    first = nil
+    last = nil
+    @params.progress_callback = ->(context, state, progress, user_data) {
+      assert_kind_of Integer, progress
+      assert 0 <= progress && progress <= 100
+      assert_same @whisper, context
+      first = progress if first.nil?
+      last = progress
+    }
+    @whisper.transcribe(@audio, @params)
+    assert_equal 0, first
+    assert_equal 100, last
+  end
+
+  def test_progress_callback_user_data
+    udata = Object.new
+    @params.progress_callback_user_data = udata
+    @params.progress_callback = ->(context, state, n_new, user_data) {
+      assert_same udata, user_data
+    }
+
+    @whisper.transcribe(@audio, @params)
+  end
+
+  def test_on_progress
+    first = nil
+    last = nil
+    @params.on_progress do |progress|
+      assert_kind_of Integer, progress
+      assert 0 <= progress && progress <= 100
+      first = progress if first.nil?
+      last = progress
+    end
+    @whisper.transcribe(@audio, @params)
+    assert_equal 0, first
+    assert_equal 100, last
+  end
+
+  def test_encoder_begin_callback
+    i = 0
+    @params.encoder_begin_callback = ->(context, state, user_data) {
+      i += 1
+    }
+    @whisper.transcribe(@audio, @params)
+    assert i > 0
+  end
+
+  def test_encoder_begin_callback_abort
+    logs = []
+    Whisper.log_set -> (level, buffer, user_data) {
+      logs << buffer if level == Whisper::LOG_LEVEL_ERROR
+    }, logs
+    @params.encoder_begin_callback = ->(context, state, user_data) {
+      return false
+    }
+    @whisper.transcribe(@audio, @params)
+    assert_match(/encoder_begin_callback returned false - aborting/, logs.join)
+    Whisper.log_set ->(level, buffer, user_data) {}, nil
+  end
+
+  def test_encoder_begin_callback_user_data
+    udata = Object.new
+    @params.encoder_begin_callback_user_data = udata
+    yielded = nil
+    @params.encoder_begin_callback = ->(context, state, user_data) {
+      yielded = user_data
+    }
+    @whisper.transcribe(@audio, @params)
+    assert_same udata, yielded
+  end
+
+  def test_on_encoder_begin
+    i = 0
+    @params.on_encoder_begin do
+      i += 1
+    end
+    @whisper.transcribe(@audio, @params)
+    assert i > 0
+  end
+
+  def test_abort_callback
+    i = 0
+    @params.abort_callback = ->(user_data) {
+      assert_nil user_data
+      i += 1
+      return false
+    }
+    @whisper.transcribe(@audio, @params)
+    assert i > 0
+  end
+
+  def test_abort_callback_abort
+    i = 0
+    @params.abort_callback = ->(user_data) {
+      i += 1
+      return i == 3
+    }
+    @whisper.transcribe(@audio, @params)
+    assert_equal 3, i
+  end
+
+  def test_abort_callback_user_data
+    udata = Object.new
+    @params.abort_callback_user_data = udata
+    yielded = nil
+    @params.abort_callback = ->(user_data) {
+      yielded = user_data
+    }
+    @whisper.transcribe(@audio, @params)
+    assert_same udata, yielded
+  end
+
+  def test_abort_on
+    do_abort = false
+    _aborted_from_callback = false
+    @params.on_new_segment do |segment|
+      do_abort = true if segment.text.match?(/ask/)
+    end
+    i = 0
+    @params.abort_on do
+      i += 1
+      do_abort
+    end
+    @whisper.transcribe(@audio, @params)
+    assert i > 0
+  end
+end
diff --git a/bindings/ruby/test/test_error.rb b/bindings/ruby/test/test_error.rb
new file mode 100644 (file)
index 0000000..2f28849
--- /dev/null
@@ -0,0 +1,20 @@
+require_relative "helper"
+
+class TestError < TestBase
+  def test_error
+    error = Whisper::Error.new(-2)
+    assert_equal "failed to compute log mel spectrogram", error.message
+    assert_equal(-2, error.code)
+  end
+
+  def test_unknown_error
+    error = Whisper::Error.new(-20)
+    assert_equal "unknown error", error.message
+  end
+
+  def test_non_int_code
+    assert_raise TypeError do
+      _error = Whisper::Error.new("non int")
+    end
+  end
+end
diff --git a/bindings/ruby/test/test_model.rb b/bindings/ruby/test/test_model.rb
new file mode 100644 (file)
index 0000000..5648fc3
--- /dev/null
@@ -0,0 +1,118 @@
+require_relative "helper"
+require "pathname"
+
+class TestModel < TestBase
+  def test_model
+    whisper = Whisper::Context.new("base.en")
+    assert_instance_of Whisper::Model, whisper.model
+  end
+
+  def test_attributes
+    whisper = Whisper::Context.new("base.en")
+    model = whisper.model
+
+    assert_equal 51864, model.n_vocab
+    assert_equal 1500, model.n_audio_ctx
+    assert_equal 512, model.n_audio_state
+    assert_equal 8, model.n_audio_head
+    assert_equal 6, model.n_audio_layer
+    assert_equal 448, model.n_text_ctx
+    assert_equal 512, model.n_text_state
+    assert_equal 8, model.n_text_head
+    assert_equal 6, model.n_text_layer
+    assert_equal 80, model.n_mels
+    assert_equal 1, model.ftype
+    assert_equal "base", model.type
+  end
+
+  def test_gc
+    model = Whisper::Context.new("base.en").model
+    GC.start
+
+    assert_equal 51864, model.n_vocab
+    assert_equal 1500, model.n_audio_ctx
+    assert_equal 512, model.n_audio_state
+    assert_equal 8, model.n_audio_head
+    assert_equal 6, model.n_audio_layer
+    assert_equal 448, model.n_text_ctx
+    assert_equal 512, model.n_text_state
+    assert_equal 8, model.n_text_head
+    assert_equal 6, model.n_text_layer
+    assert_equal 80, model.n_mels
+    assert_equal 1, model.ftype
+    assert_equal "base", model.type
+  end
+
+  def test_pathname
+    path = Pathname(Whisper::Model.pre_converted_models["base.en"].to_path)
+    whisper = Whisper::Context.new(path)
+    model = whisper.model
+
+    assert_equal 51864, model.n_vocab
+    assert_equal 1500, model.n_audio_ctx
+    assert_equal 512, model.n_audio_state
+    assert_equal 8, model.n_audio_head
+    assert_equal 6, model.n_audio_layer
+    assert_equal 448, model.n_text_ctx
+    assert_equal 512, model.n_text_state
+    assert_equal 8, model.n_text_head
+    assert_equal 6, model.n_text_layer
+    assert_equal 80, model.n_mels
+    assert_equal 1, model.ftype
+    assert_equal "base", model.type
+  end
+
+  def test_auto_download
+    path = Whisper::Model.pre_converted_models["base.en"].to_path
+
+    assert_path_exist path
+    assert_equal 147964211, File.size(path)
+  end
+
+  def test_uri_string
+    path = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin"
+    whisper = Whisper::Context.new(path)
+    model = whisper.model
+
+    assert_equal 51864, model.n_vocab
+    assert_equal 1500, model.n_audio_ctx
+    assert_equal 512, model.n_audio_state
+    assert_equal 8, model.n_audio_head
+    assert_equal 6, model.n_audio_layer
+    assert_equal 448, model.n_text_ctx
+    assert_equal 512, model.n_text_state
+    assert_equal 8, model.n_text_head
+    assert_equal 6, model.n_text_layer
+    assert_equal 80, model.n_mels
+    assert_equal 1, model.ftype
+    assert_equal "base", model.type
+  end
+
+  def test_uri
+    path = URI("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin")
+    whisper = Whisper::Context.new(path)
+    model = whisper.model
+
+    assert_equal 51864, model.n_vocab
+    assert_equal 1500, model.n_audio_ctx
+    assert_equal 512, model.n_audio_state
+    assert_equal 8, model.n_audio_head
+    assert_equal 6, model.n_audio_layer
+    assert_equal 448, model.n_text_ctx
+    assert_equal 512, model.n_text_state
+    assert_equal 8, model.n_text_head
+    assert_equal 6, model.n_text_layer
+    assert_equal 80, model.n_mels
+    assert_equal 1, model.ftype
+    assert_equal "base", model.type
+  end
+
+  def test_coreml_model_auto_download
+    uri = Whisper::Model.coreml_compiled_models[Whisper::Model.pre_converted_models["tiny"]]
+    model_path = Pathname(uri.to_path).sub_ext("")
+    model_path.rmtree if model_path.exist?
+
+    uri.cache
+    assert_path_exist model_path
+  end
+end
diff --git a/bindings/ruby/test/test_package.rb b/bindings/ruby/test/test_package.rb
new file mode 100644 (file)
index 0000000..33cd2a3
--- /dev/null
@@ -0,0 +1,50 @@
+require_relative "helper"
+require 'tempfile'
+require 'tmpdir'
+require 'shellwords'
+
+class TestPackage < TestBase
+  def test_build
+    Tempfile.create do |file|
+      assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
+      assert file.size > 0
+      assert_path_exist file.to_path
+    end
+  end
+
+  sub_test_case "Building binary on installation" do
+    def setup
+      system "rake", "build", exception: true
+    end
+
+    def test_install
+      gemspec = Gem::Specification.load("whispercpp.gemspec")
+      Dir.mktmpdir do |dir|
+        system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{gemspec.file_name.shellescape}", exception: true
+        assert_installed dir, gemspec.version
+      end
+    end
+
+    def test_install_with_coreml
+      omit_unless RUBY_PLATFORM.match?(/darwin/) do
+        gemspec = Gem::Specification.load("whispercpp.gemspec")
+        Dir.mktmpdir do |dir|
+          system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{gemspec.file_name.shellescape}", "--", "--enable-whisper-coreml", exception: true
+          assert_installed dir, gemspec.version
+          assert_nothing_raised do
+            libdir = File.join(dir, "gems", "#{gemspec.name}-#{gemspec.version}", "lib")
+            system "ruby", "-I", libdir, "-r", "whisper", "-e", "Whisper::Context.new('tiny')", exception: true
+          end
+        end
+      end
+    end
+
+    private
+
+    def assert_installed(dir, version)
+      assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", "whisper.#{RbConfig::CONFIG["DLEXT"]}")
+      assert_path_exist File.join(dir, "gems/whispercpp-#{version}/LICENSE")
+      assert_path_not_exist File.join(dir, "gems/whispercpp-#{version}/ext/build")
+    end
+  end
+end
diff --git a/bindings/ruby/test/test_params.rb b/bindings/ruby/test/test_params.rb
new file mode 100644 (file)
index 0000000..9a95357
--- /dev/null
@@ -0,0 +1,297 @@
+require_relative "helper"
+
+class TestParams < TestBase
+  PARAM_NAMES = [
+    :language,
+    :translate,
+    :no_context,
+    :single_segment,
+    :print_special,
+    :print_progress,
+    :print_realtime,
+    :print_timestamps,
+    :suppress_blank,
+    :suppress_nst,
+    :token_timestamps,
+    :split_on_word,
+    :initial_prompt,
+    :diarize,
+    :offset,
+    :duration,
+    :max_text_tokens,
+    :temperature,
+    :max_initial_ts,
+    :length_penalty,
+    :temperature_inc,
+    :entropy_thold,
+    :logprob_thold,
+    :no_speech_thold,
+    :new_segment_callback,
+    :new_segment_callback_user_data,
+    :progress_callback,
+    :progress_callback_user_data,
+    :abort_callback,
+    :abort_callback_user_data,
+    :vad,
+    :vad_model_path,
+    :vad_params,
+  ]
+
+  def setup
+    @params  = Whisper::Params.new
+  end
+
+  def test_language
+    @params.language = "en"
+    assert_equal @params.language, "en"
+    @params.language = "auto"
+    assert_equal @params.language, "auto"
+  end
+
+  def test_offset
+    @params.offset = 10_000
+    assert_equal @params.offset, 10_000
+    @params.offset = 0
+    assert_equal @params.offset, 0
+  end
+
+  def test_duration
+    @params.duration = 60_000
+    assert_equal @params.duration, 60_000
+    @params.duration = 0
+    assert_equal @params.duration, 0
+  end
+
+  def test_max_text_tokens
+    @params.max_text_tokens = 300
+    assert_equal @params.max_text_tokens, 300
+    @params.max_text_tokens = 0
+    assert_equal @params.max_text_tokens, 0
+  end
+
+  def test_translate
+    @params.translate = true
+    assert @params.translate
+    @params.translate = false
+    assert !@params.translate
+  end
+
+  def test_no_context
+    @params.no_context = true
+    assert @params.no_context
+    @params.no_context = false
+    assert !@params.no_context
+  end
+
+  def test_single_segment
+    @params.single_segment = true
+    assert @params.single_segment
+    @params.single_segment = false
+    assert !@params.single_segment
+  end
+
+  def test_print_special
+    @params.print_special = true
+    assert @params.print_special
+    @params.print_special = false
+    assert !@params.print_special
+  end
+
+  def test_print_progress
+    @params.print_progress = true
+    assert @params.print_progress
+    @params.print_progress = false
+    assert !@params.print_progress
+  end
+
+  def test_print_realtime
+    @params.print_realtime = true
+    assert @params.print_realtime
+    @params.print_realtime = false
+    assert !@params.print_realtime
+  end
+
+  def test_print_timestamps
+    @params.print_timestamps = true
+    assert @params.print_timestamps
+    @params.print_timestamps = false
+    assert !@params.print_timestamps
+  end
+
+  def test_suppress_blank
+    @params.suppress_blank = true
+    assert @params.suppress_blank
+    @params.suppress_blank = false
+    assert !@params.suppress_blank
+  end
+
+  def test_suppress_nst
+    @params.suppress_nst = true
+    assert @params.suppress_nst
+    @params.suppress_nst = false
+    assert !@params.suppress_nst
+  end
+
+  def test_token_timestamps
+    @params.token_timestamps = true
+    assert @params.token_timestamps
+    @params.token_timestamps = false
+    assert !@params.token_timestamps
+  end
+
+  def test_split_on_word
+    @params.split_on_word = true
+    assert @params.split_on_word
+    @params.split_on_word = false
+    assert !@params.split_on_word
+  end
+
+  def test_initial_prompt
+    assert_nil @params.initial_prompt
+    @params.initial_prompt = "You are a polite person."
+    assert_equal "You are a polite person.", @params.initial_prompt
+  end
+
+  def test_temperature
+    assert_equal 0.0, @params.temperature
+    @params.temperature = 0.5
+    assert_equal 0.5, @params.temperature
+  end
+
+  def test_max_initial_ts
+    assert_equal 1.0, @params.max_initial_ts
+    @params.max_initial_ts = 600.0
+    assert_equal 600.0, @params.max_initial_ts
+  end
+
+  def test_length_penalty
+    assert_equal(-1.0, @params.length_penalty)
+    @params.length_penalty = 0.5
+    assert_equal 0.5, @params.length_penalty
+  end
+
+  def test_temperature_inc
+    assert_in_delta 0.2, @params.temperature_inc
+    @params.temperature_inc = 0.5
+    assert_in_delta 0.5, @params.temperature_inc
+  end
+
+  def test_entropy_thold
+    assert_in_delta 2.4, @params.entropy_thold
+    @params.entropy_thold = 3.0
+    assert_in_delta 3.0, @params.entropy_thold
+  end
+
+  def test_logprob_thold
+    assert_in_delta(-1.0, @params.logprob_thold)
+    @params.logprob_thold = -0.5
+    assert_in_delta(-0.5, @params.logprob_thold)
+  end
+
+  def test_no_speech_thold
+    assert_in_delta 0.6, @params.no_speech_thold
+    @params.no_speech_thold = 0.2
+    assert_in_delta 0.2, @params.no_speech_thold
+  end
+
+  def test_vad
+    assert_false @params.vad
+    @params.vad = true
+    assert_true @params.vad
+  end
+
+  def test_vad_model_path
+    assert_nil @params.vad_model_path
+    @params.vad_model_path = "silero-v5.1.2"
+    assert_equal Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path, @params.vad_model_path
+  end
+
+  def test_vad_model_path_with_nil
+    @params.vad_model_path = "silero-v5.1.2"
+    @params.vad_model_path = nil
+    assert_nil @params.vad_model_path
+  end
+
+  def test_vad_model_path_with_invalid
+    assert_raise TypeError do
+      @params.vad_model_path = Object.new
+    end
+  end
+
+  def test_vad_model_path_with_URI_string
+    @params.vad_model_path = "https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v5.1.2.bin"
+    assert_equal @params.vad_model_path, Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path
+  end
+
+  def test_vad_model_path_with_URI
+    @params.vad_model_path = URI("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v5.1.2.bin")
+    assert_equal @params.vad_model_path, Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path
+  end
+
+  def test_vad_params
+    assert_kind_of Whisper::VAD::Params, @params.vad_params
+    default_params = @params.vad_params
+    assert_same default_params, @params.vad_params
+    assert_equal 0.5, default_params.threshold
+    new_params = Whisper::VAD::Params.new
+    @params.vad_params = new_params
+    assert_same new_params, @params.vad_params
+  end
+
+  def test_new_with_kw_args
+    params = Whisper::Params.new(language: "es")
+    assert_equal "es", params.language
+    assert_equal 1.0, params.max_initial_ts
+  end
+
+  def test_new_with_kw_args_non_existent
+    assert_raise ArgumentError do
+      Whisper::Params.new(non_existent: "value")
+    end
+  end
+
+  def test_new_with_kw_args_wrong_type
+    assert_raise TypeError do
+      Whisper::Params.new(language: 3)
+    end
+  end
+
+  data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
+  def test_new_with_kw_args_default_values(param)
+    default_value = @params.send(param)
+    value = case [param, default_value]
+            in [*, true | false]
+              !default_value
+            in [*, Integer | Float]
+              default_value + 1
+            in [:language, *]
+              "es"
+            in [:initial_prompt, *]
+              "Initial prompt"
+            in [/_callback\Z/, *]
+              proc {}
+            in [/_user_data\Z/, *]
+              Object.new
+            in [:vad_model_path, *]
+              Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path
+            in [:vad_params, *]
+              Whisper::VAD::Params.new
+            end
+    params = Whisper::Params.new(param => value)
+    if Float === value
+      assert_in_delta value, params.send(param)
+    else
+      assert_equal value, params.send(param)
+    end
+
+    PARAM_NAMES.reject {|name| name == param}.each do |name|
+      expected = @params.send(name)
+      actual = params.send(name)
+      if Float === expected
+        assert_in_delta expected, actual
+      else
+        assert_equal expected, actual
+      end
+    end
+  end
+end
diff --git a/bindings/ruby/test/test_segment.rb b/bindings/ruby/test/test_segment.rb
new file mode 100644 (file)
index 0000000..e8b9987
--- /dev/null
@@ -0,0 +1,74 @@
+require_relative "helper"
+
+class TestSegment < TestBase
+  def test_iteration
+    whisper.each_segment do |segment|
+      assert_instance_of Whisper::Segment, segment
+    end
+  end
+
+  def test_enumerator
+    enum = whisper.each_segment
+    assert_instance_of Enumerator, enum
+    enum.to_a.each_with_index do |segment, index|
+      assert_instance_of Whisper::Segment, segment
+      assert_kind_of Integer, index
+    end
+  end
+
+  def test_start_time
+    i = 0
+    whisper.each_segment do |segment|
+      assert_equal 0, segment.start_time if i == 0
+      i += 1
+    end
+  end
+
+  def test_end_time
+    i = 0
+    whisper.each_segment do |segment|
+      assert_equal whisper.full_get_segment_t1(i) * 10, segment.end_time
+      i += 1
+    end
+  end
+
+  def test_no_speech_prob
+    no_speech_prob = nil
+    whisper.each_segment do |segment|
+      no_speech_prob = segment.no_speech_prob
+    end
+    assert no_speech_prob > 0.0
+  end
+
+  def test_on_new_segment
+    params = Whisper::Params.new
+    seg = nil
+    index = 0
+    params.on_new_segment do |segment|
+      assert_instance_of Whisper::Segment, segment
+      if index == 0
+        seg = segment
+        assert_equal 0, segment.start_time
+        assert_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)
+      end
+      index += 1
+    end
+    whisper.transcribe(AUDIO, params)
+    assert_equal 0, seg.start_time
+    assert_match(/ask not what your country can do for you, ask what you can do for your country/, seg.text)
+  end
+
+  # Registers two on_new_segment callbacks and expects the second to receive the
+  # very same Segment object that the first one saw.
+  def test_on_new_segment_twice
+    params = Whisper::Params.new
+    seg = nil
+    params.on_new_segment do |segment|
+      seg = segment
+      # NOTE(review): `return` inside a do-block is a non-local return from this test
+      # method, not just from the block. Confirm the binding handles this so that the
+      # second callback (and its assertion) still runs; `next` may be what is intended.
+      return
+    end
+    params.on_new_segment do |segment|
+      assert_same seg, segment
+      return
+    end
+    whisper.transcribe(AUDIO, params)
+  end
+end
diff --git a/bindings/ruby/test/test_vad.rb b/bindings/ruby/test/test_vad.rb
new file mode 100644 (file)
index 0000000..cb5e3c7
--- /dev/null
@@ -0,0 +1,19 @@
+require_relative "helper"
+
+class TestVAD < TestBase
+  def setup
+    @whisper = Whisper::Context.new("base.en")
+    vad_params = Whisper::VAD::Params.new
+    @params = Whisper::Params.new(
+      vad: true,
+      vad_model_path: "silero-v5.1.2",
+      vad_params:
+    )
+  end
+
+  def test_transcribe
+    @whisper.transcribe(TestBase::AUDIO, @params) do |text|
+      assert_match(/ask not what your country can do for you[,.] ask what you can do for your country/i, text)
+    end
+  end
+end
diff --git a/bindings/ruby/test/test_vad_params.rb b/bindings/ruby/test/test_vad_params.rb
new file mode 100644 (file)
index 0000000..add4899
--- /dev/null
@@ -0,0 +1,103 @@
+require_relative "helper"
+
+class TestVADParams < TestBase
+  PARAM_NAMES = [
+    :threshold,
+    :min_speech_duration_ms,
+    :min_silence_duration_ms,
+    :max_speech_duration_s,
+    :speech_pad_ms,
+    :samples_overlap
+  ]
+
+  def setup
+    @params = Whisper::VAD::Params.new
+  end
+
+  def test_new
+    params = Whisper::VAD::Params.new
+    assert_kind_of Whisper::VAD::Params, params
+  end
+
+  def test_threshold
+    assert_in_delta @params.threshold, 0.5
+    @params.threshold = 0.7
+    assert_in_delta @params.threshold, 0.7
+  end
+
+  def test_min_speech_duration
+    pend
+  end
+
+  def test_min_speech_duration_ms
+    assert_equal 250, @params.min_speech_duration_ms
+    @params.min_speech_duration_ms = 500
+    assert_equal 500, @params.min_speech_duration_ms
+  end
+
+  def test_min_silence_duration_ms
+    assert_equal 100, @params.min_silence_duration_ms
+    @params.min_silence_duration_ms = 200
+    assert_equal 200, @params.min_silence_duration_ms
+  end
+
+  def test_max_speech_duration
+    pend
+  end
+
+  def test_max_speech_duration_s
+    assert @params.max_speech_duration_s >= 10e37 # Defaults to FLT_MAX
+    @params.max_speech_duration_s = 60.0
+    assert_equal 60.0, @params.max_speech_duration_s
+  end
+
+  def test_speech_pad_ms
+    assert_equal 30, @params.speech_pad_ms
+    @params.speech_pad_ms = 50
+    assert_equal 50, @params.speech_pad_ms
+  end
+
+  def test_samples_overlap
+    assert_in_delta @params.samples_overlap, 0.1
+    @params.samples_overlap = 0.5
+    assert_in_delta @params.samples_overlap, 0.5
+  end
+
+  def test_equal
+    assert_equal @params, Whisper::VAD::Params.new
+  end
+
+  def test_new_with_kw_args
+    params = Whisper::VAD::Params.new(threshold: 0.7)
+    assert_in_delta params.threshold, 0.7
+    assert_equal 250, params.min_speech_duration_ms
+  end
+
+  def test_new_with_kw_args_non_existent
+    assert_raise ArgumentError do
+      Whisper::VAD::Params.new(non_existent: "value")
+    end
+  end
+
+  data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
+  def test_new_with_kw_args_default_values(param)
+    default_value = @params.send(param)
+    value = default_value + 1
+    params = Whisper::VAD::Params.new(param => value)
+    if Float === value
+      assert_in_delta value, params.send(param)
+    else
+      assert_equal value, params.send(param)
+    end
+
+    PARAM_NAMES.reject {|name| name == param}.each do |name|
+      expected = @params.send(name)
+      actual = params.send(name)
+      if Float === expected
+        assert_in_delta expected, actual
+      else
+        assert_equal expected, actual
+      end
+    end
+  end
+end
diff --git a/bindings/ruby/test/test_whisper.rb b/bindings/ruby/test/test_whisper.rb
new file mode 100644 (file)
index 0000000..8f1e69d
--- /dev/null
@@ -0,0 +1,248 @@
+require_relative "helper"
+require "stringio"
+require "etc"
+
+# Exists to detect memory-related bug
+Whisper.log_set ->(level, buffer, user_data) {}, nil
+
+class TestWhisper < TestBase
+  def setup
+    @params  = Whisper::Params.new
+  end
+
+  def test_whisper
+    @whisper = Whisper::Context.new("base.en")
+    params  = Whisper::Params.new
+    params.print_timestamps = false
+
+    @whisper.transcribe(AUDIO, params) {|text|
+      assert_match(/ask not what your country can do for you, ask what you can do for your country/, text)
+    }
+  end
+
+  def test_transcribe_non_parallel
+    @whisper = Whisper::Context.new("base.en")
+    params  = Whisper::Params.new
+
+    @whisper.transcribe(AUDIO, params, n_processors: 1) {|text|
+      assert_match(/ask not what your country can do for you, ask what you can do for your country/, text)
+    }
+  end
+
+  def test_transcribe_n_processors
+    @whisper = Whisper::Context.new("base.en")
+    params  = Whisper::Params.new
+
+    @whisper.transcribe(AUDIO, params, n_processors: 4) {|text|
+      assert_match(/ask not what your country can do for you[,.] ask what you can do for your country/i, text)
+    }
+  end
+
+  sub_test_case "After transcription" do
+    def test_full_n_segments
+      assert_equal 1, whisper.full_n_segments
+    end
+
+    def test_full_lang_id
+      assert_equal 0, whisper.full_lang_id
+    end
+
+    def test_full_get_segment
+      segment = whisper.full_get_segment(0)
+      assert_equal 0, segment.start_time
+      assert_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)
+    end
+
+    def test_full_get_segment_t0
+      assert_equal 0, whisper.full_get_segment_t0(0)
+      assert_raise IndexError do
+        whisper.full_get_segment_t0(whisper.full_n_segments)
+      end
+      assert_raise IndexError do
+        whisper.full_get_segment_t0(-1)
+      end
+    end
+
+    def test_full_get_segment_t1
+      t1 = whisper.full_get_segment_t1(0)
+      assert_kind_of Integer, t1
+      assert t1 > 0
+      assert_raise IndexError do
+        whisper.full_get_segment_t1(whisper.full_n_segments)
+      end
+    end
+
+    def test_full_get_segment_speaker_turn_next
+      assert_false whisper.full_get_segment_speaker_turn_next(0)
+    end
+
+    def test_full_get_segment_text
+      assert_match(/ask not what your country can do for you, ask what you can do for your country/, whisper.full_get_segment_text(0))
+    end
+
+    def test_full_get_segment_no_speech_prob
+      prob = whisper.full_get_segment_no_speech_prob(0)
+      assert prob > 0.0
+      assert prob < 1.0
+    end
+  end
+
+  def test_lang_max_id
+    assert_kind_of Integer, Whisper.lang_max_id
+  end
+
+  def test_lang_id
+    assert_equal 0, Whisper.lang_id("en")
+    assert_raise ArgumentError do
+      Whisper.lang_id("non existing language")
+    end
+  end
+
+  def test_lang_str
+    assert_equal "en", Whisper.lang_str(0)
+    assert_raise IndexError do
+      Whisper.lang_str(Whisper.lang_max_id + 1)
+    end
+  end
+
+  def test_lang_str_full
+    assert_equal "english", Whisper.lang_str_full(0)
+    assert_raise IndexError do
+      Whisper.lang_str_full(Whisper.lang_max_id + 1)
+    end
+  end
+
+  def test_system_info_str
+    assert_match /\AWHISPER : COREML = \d | OPENVINO = \d |/, Whisper.system_info_str
+  end
+
+  def test_log_set
+    user_data = Object.new
+    logs = []
+    log_callback = ->(level, buffer, udata) {
+      logs << [level, buffer, udata]
+    }
+    Whisper.log_set log_callback, user_data
+    Whisper::Context.new("base.en")
+
+    assert logs.length > 30
+    logs.each do |log|
+      assert_include [Whisper::LOG_LEVEL_DEBUG, Whisper::LOG_LEVEL_INFO, Whisper::LOG_LEVEL_WARN], log[0]
+      assert_same user_data, log[2]
+    end
+  end
+
+  def test_log_suppress
+    stderr = $stderr
+    Whisper.log_set ->(level, buffer, user_data) {
+      # do nothing
+    }, nil
+    dev = StringIO.new("")
+    $stderr = dev
+    Whisper::Context.new("base.en")
+    assert_empty dev.string
+  ensure
+    $stderr = stderr
+  end
+
+  sub_test_case "full" do
+    def setup
+      super
+      @whisper = Whisper::Context.new("base.en")
+      @samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15}
+    end
+
+    def test_full
+      @whisper.full(@params, @samples, @samples.length)
+
+      assert_equal 1, @whisper.full_n_segments
+      assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text)
+    end
+
+    def test_full_without_length
+      @whisper.full(@params, @samples)
+
+      assert_equal 1, @whisper.full_n_segments
+      assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text)
+    end
+
+    def test_full_enumerator
+      samples = @samples.each
+      @whisper.full(@params, samples, @samples.length)
+
+      assert_equal 1, @whisper.full_n_segments
+      assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text)
+    end
+
+    def test_full_enumerator_without_length
+      samples = @samples.each
+      assert_raise ArgumentError do
+        @whisper.full(@params, samples)
+      end
+    end
+
+    def test_full_enumerator_with_too_large_length
+      samples = @samples.each.take(10).to_enum
+      assert_raise StopIteration do
+        @whisper.full(@params, samples, 11)
+      end
+    end
+
+    def test_full_with_memory_view
+      samples = JFKReader.new(AUDIO)
+      @whisper.full(@params, samples)
+
+      assert_equal 1, @whisper.full_n_segments
+      assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text)
+    end
+
+    def test_full_parallel
+      nprocessors = 2
+      @whisper.full_parallel(@params, @samples, @samples.length, nprocessors)
+
+      assert_equal nprocessors, @whisper.full_n_segments
+      text = @whisper.each_segment.collect(&:text).join
+      assert_match(/ask what you can do/i, text)
+      assert_match(/for your country/i, text)
+    end
+
+    def test_full_parallel_with_memory_view
+      nprocessors = 2
+      samples = JFKReader.new(AUDIO)
+      @whisper.full_parallel(@params, samples, nil, nprocessors)
+
+      assert_equal nprocessors, @whisper.full_n_segments
+      text = @whisper.each_segment.collect(&:text).join
+      assert_match(/ask what you can do/i, text)
+      assert_match(/for your country/i, text)
+    end
+
+    def test_full_parallel_without_length_and_n_processors
+      @whisper.full_parallel(@params, @samples)
+
+      assert_equal 1, @whisper.full_n_segments
+      text = @whisper.each_segment.collect(&:text).join
+      assert_match(/ask what you can do/i, text)
+      assert_match(/for your country/i, text)
+    end
+
+    def test_full_parallel_without_length
+      nprocessors = 2
+      @whisper.full_parallel(@params, @samples, nil, nprocessors)
+
+      assert_equal nprocessors, @whisper.full_n_segments
+      text = @whisper.each_segment.collect(&:text).join
+      assert_match(/ask what you can do/i, text)
+      assert_match(/for your country/i, text)
+    end
+
+    def test_full_parallel_without_n_processors
+      @whisper.full_parallel(@params, @samples, @samples.length)
+
+      assert_equal 1, @whisper.full_n_segments
+      text = @whisper.each_segment.collect(&:text).join
+      assert_match(/ask what you can do/i, text)
+      assert_match(/for your country/i, text)
+    end
+  end
+end
diff --git a/bindings/ruby/tests/helper.rb b/bindings/ruby/tests/helper.rb
deleted file mode 100644 (file)
index 389e15c..0000000
+++ /dev/null
@@ -1,24 +0,0 @@
-require "test/unit"
-require "whisper"
-require_relative "jfk_reader/jfk_reader"
-
-class TestBase < Test::Unit::TestCase
-  AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
-
-  class << self
-    def whisper
-      return @whisper if @whisper
-
-      @whisper = Whisper::Context.new("base.en")
-      params = Whisper::Params.new
-      params.print_timestamps = false
-      @whisper.transcribe(TestBase::AUDIO, params)
-    end
-  end
-
-  private
-
-  def whisper
-    self.class.whisper
-  end
-end
diff --git a/bindings/ruby/tests/jfk_reader/.gitignore b/bindings/ruby/tests/jfk_reader/.gitignore
deleted file mode 100644 (file)
index 656da8d..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-Makefile
-jfk_reader.o
-jfk_reader.so
-jfk_reader.bundle
-jfk_reader.dll
diff --git a/bindings/ruby/tests/jfk_reader/extconf.rb b/bindings/ruby/tests/jfk_reader/extconf.rb
deleted file mode 100644 (file)
index 0d842d0..0000000
+++ /dev/null
@@ -1,3 +0,0 @@
-require "mkmf"
-
-create_makefile("jfk_reader")
diff --git a/bindings/ruby/tests/jfk_reader/jfk_reader.c b/bindings/ruby/tests/jfk_reader/jfk_reader.c
deleted file mode 100644 (file)
index 6657176..0000000
+++ /dev/null
@@ -1,68 +0,0 @@
-#include <ruby.h>
-#include <ruby/memory_view.h>
-#include <ruby/encoding.h>
-
-static VALUE
-jfk_reader_initialize(VALUE self, VALUE audio_path)
-{
-  rb_iv_set(self, "audio_path", audio_path);
-  return Qnil;
-}
-
-static bool
-jfk_reader_get_memory_view(const VALUE obj, rb_memory_view_t *view, int flags)
-{
-  VALUE audio_path = rb_iv_get(obj, "audio_path");
-  const char *audio_path_str = StringValueCStr(audio_path);
-  const int n_samples = 176000;
-  float *data = (float *)malloc(n_samples * sizeof(float));
-  short *samples = (short *)malloc(n_samples * sizeof(short));
-  FILE *file = fopen(audio_path_str, "rb");
-
-  fseek(file, 78, SEEK_SET);
-  fread(samples, sizeof(short), n_samples, file);
-  fclose(file);
-  for (int i = 0; i < n_samples; i++) {
-    data[i] = samples[i]/32768.0;
-  }
-
-  view->obj = obj;
-  view->data = (void *)data;
-  view->byte_size = sizeof(float) * n_samples;
-  view->readonly = true;
-  view->format = "f";
-  view->item_size = sizeof(float);
-  view->item_desc.components = NULL;
-  view->item_desc.length = 0;
-  view->ndim = 1;
-  view->shape = NULL;
-  view->sub_offsets = NULL;
-  view->private_data = NULL;
-
-  return true;
-}
-
-static bool
-jfk_reader_release_memory_view(const VALUE obj, rb_memory_view_t *view)
-{
-  return true;
-}
-
-static bool
-jfk_reader_memory_view_available_p(const VALUE obj)
-{
-  return true;
-}
-
-static const rb_memory_view_entry_t jfk_reader_view_entry = {
-  jfk_reader_get_memory_view,
-  jfk_reader_release_memory_view,
-  jfk_reader_memory_view_available_p
-};
-
-void Init_jfk_reader(void)
-{
-  VALUE cJFKReader = rb_define_class("JFKReader", rb_cObject);
-  rb_memory_view_register(cJFKReader, &jfk_reader_view_entry);
-  rb_define_method(cJFKReader, "initialize", jfk_reader_initialize, 1);
-}
diff --git a/bindings/ruby/tests/test_callback.rb b/bindings/ruby/tests/test_callback.rb
deleted file mode 100644 (file)
index a7f4924..0000000
+++ /dev/null
@@ -1,202 +0,0 @@
-require_relative "helper"
-
-class TestCallback < TestBase
-  def setup
-    GC.start
-    @params = Whisper::Params.new
-    @whisper = Whisper::Context.new("base.en")
-    @audio = File.join(AUDIO)
-  end
-
-  def test_new_segment_callback
-    @params.new_segment_callback = ->(context, state, n_new, user_data) {
-      assert_kind_of Integer, n_new
-      assert n_new > 0
-      assert_same @whisper, context
-
-      n_segments = context.full_n_segments
-      n_new.times do |i|
-        i_segment = n_segments - 1 + i
-        start_time = context.full_get_segment_t0(i_segment) * 10
-        end_time = context.full_get_segment_t1(i_segment) * 10
-        text = context.full_get_segment_text(i_segment)
-
-        assert_kind_of Integer, start_time
-        assert start_time >= 0
-        assert_kind_of Integer, end_time
-        assert end_time > 0
-        assert_match(/ask not what your country can do for you, ask what you can do for your country/, text) if i_segment == 0
-      end
-    }
-
-    @whisper.transcribe(@audio, @params)
-  end
-
-  def test_new_segment_callback_closure
-    search_word = "what"
-    @params.new_segment_callback = ->(context, state, n_new, user_data) {
-      n_segments = context.full_n_segments
-      n_new.times do |i|
-        i_segment = n_segments - 1 + i
-        text = context.full_get_segment_text(i_segment)
-        if text.include?(search_word)
-          t0 = context.full_get_segment_t0(i_segment)
-          t1 = context.full_get_segment_t1(i_segment)
-          raise "search word '#{search_word}' found at between #{t0} and #{t1}"
-        end
-      end
-    }
-
-    assert_raise RuntimeError do
-      @whisper.transcribe(@audio, @params)
-    end
-  end
-
-  def test_new_segment_callback_user_data
-    udata = Object.new
-    @params.new_segment_callback_user_data = udata
-    @params.new_segment_callback = ->(context, state, n_new, user_data) {
-      assert_same udata, user_data
-    }
-
-    @whisper.transcribe(@audio, @params)
-  end
-
-  def test_new_segment_callback_user_data_gc
-    @params.new_segment_callback_user_data = "My user data"
-    @params.new_segment_callback = ->(context, state, n_new, user_data) {
-      assert_equal "My user data", user_data
-    }
-    GC.start
-
-    assert_same @whisper, @whisper.transcribe(@audio, @params)
-  end
-
-  def test_progress_callback
-    first = nil
-    last = nil
-    @params.progress_callback = ->(context, state, progress, user_data) {
-      assert_kind_of Integer, progress
-      assert 0 <= progress && progress <= 100
-      assert_same @whisper, context
-      first = progress if first.nil?
-      last = progress
-    }
-    @whisper.transcribe(@audio, @params)
-    assert_equal 0, first
-    assert_equal 100, last
-  end
-
-  def test_progress_callback_user_data
-    udata = Object.new
-    @params.progress_callback_user_data = udata
-    @params.progress_callback = ->(context, state, n_new, user_data) {
-      assert_same udata, user_data
-    }
-
-    @whisper.transcribe(@audio, @params)
-  end
-
-  def test_on_progress
-    first = nil
-    last = nil
-    @params.on_progress do |progress|
-      assert_kind_of Integer, progress
-      assert 0 <= progress && progress <= 100
-      first = progress if first.nil?
-      last = progress
-    end
-    @whisper.transcribe(@audio, @params)
-    assert_equal 0, first
-    assert_equal 100, last
-  end
-
-  def test_encoder_begin_callback
-    i = 0
-    @params.encoder_begin_callback = ->(context, state, user_data) {
-      i += 1
-    }
-    @whisper.transcribe(@audio, @params)
-    assert i > 0
-  end
-
-  def test_encoder_begin_callback_abort
-    logs = []
-    Whisper.log_set -> (level, buffer, user_data) {
-      logs << buffer if level == Whisper::LOG_LEVEL_ERROR
-    }, logs
-    @params.encoder_begin_callback = ->(context, state, user_data) {
-      return false
-    }
-    @whisper.transcribe(@audio, @params)
-    assert_match(/encoder_begin_callback returned false - aborting/, logs.join)
-    Whisper.log_set ->(level, buffer, user_data) {}, nil
-  end
-
-  def test_encoder_begin_callback_user_data
-    udata = Object.new
-    @params.encoder_begin_callback_user_data = udata
-    yielded = nil
-    @params.encoder_begin_callback = ->(context, state, user_data) {
-      yielded = user_data
-    }
-    @whisper.transcribe(@audio, @params)
-    assert_same udata, yielded
-  end
-
-  def test_on_encoder_begin
-    i = 0
-    @params.on_encoder_begin do
-      i += 1
-    end
-    @whisper.transcribe(@audio, @params)
-    assert i > 0
-  end
-
-  def test_abort_callback
-    i = 0
-    @params.abort_callback = ->(user_data) {
-      assert_nil user_data
-      i += 1
-      return false
-    }
-    @whisper.transcribe(@audio, @params)
-    assert i > 0
-  end
-
-  def test_abort_callback_abort
-    i = 0
-    @params.abort_callback = ->(user_data) {
-      i += 1
-      return i == 3
-    }
-    @whisper.transcribe(@audio, @params)
-    assert_equal 3, i
-  end
-
-  def test_abort_callback_user_data
-    udata = Object.new
-    @params.abort_callback_user_data = udata
-    yielded = nil
-    @params.abort_callback = ->(user_data) {
-      yielded = user_data
-    }
-    @whisper.transcribe(@audio, @params)
-    assert_same udata, yielded
-  end
-
-  def test_abort_on
-    do_abort = false
-    _aborted_from_callback = false
-    @params.on_new_segment do |segment|
-      do_abort = true if segment.text.match?(/ask/)
-    end
-    i = 0
-    @params.abort_on do
-      i += 1
-      do_abort
-    end
-    @whisper.transcribe(@audio, @params)
-    assert i > 0
-  end
-end
diff --git a/bindings/ruby/tests/test_error.rb b/bindings/ruby/tests/test_error.rb
deleted file mode 100644 (file)
index 2f28849..0000000
+++ /dev/null
@@ -1,20 +0,0 @@
-require_relative "helper"
-
-class TestError < TestBase
-  def test_error
-    error = Whisper::Error.new(-2)
-    assert_equal "failed to compute log mel spectrogram", error.message
-    assert_equal(-2, error.code)
-  end
-
-  def test_unknown_error
-    error = Whisper::Error.new(-20)
-    assert_equal "unknown error", error.message
-  end
-
-  def test_non_int_code
-    assert_raise TypeError do
-      _error = Whisper::Error.new("non int")
-    end
-  end
-end
diff --git a/bindings/ruby/tests/test_model.rb b/bindings/ruby/tests/test_model.rb
deleted file mode 100644 (file)
index 5648fc3..0000000
+++ /dev/null
@@ -1,118 +0,0 @@
-require_relative "helper"
-require "pathname"
-
-class TestModel < TestBase
-  def test_model
-    whisper = Whisper::Context.new("base.en")
-    assert_instance_of Whisper::Model, whisper.model
-  end
-
-  def test_attributes
-    whisper = Whisper::Context.new("base.en")
-    model = whisper.model
-
-    assert_equal 51864, model.n_vocab
-    assert_equal 1500, model.n_audio_ctx
-    assert_equal 512, model.n_audio_state
-    assert_equal 8, model.n_audio_head
-    assert_equal 6, model.n_audio_layer
-    assert_equal 448, model.n_text_ctx
-    assert_equal 512, model.n_text_state
-    assert_equal 8, model.n_text_head
-    assert_equal 6, model.n_text_layer
-    assert_equal 80, model.n_mels
-    assert_equal 1, model.ftype
-    assert_equal "base", model.type
-  end
-
-  def test_gc
-    model = Whisper::Context.new("base.en").model
-    GC.start
-
-    assert_equal 51864, model.n_vocab
-    assert_equal 1500, model.n_audio_ctx
-    assert_equal 512, model.n_audio_state
-    assert_equal 8, model.n_audio_head
-    assert_equal 6, model.n_audio_layer
-    assert_equal 448, model.n_text_ctx
-    assert_equal 512, model.n_text_state
-    assert_equal 8, model.n_text_head
-    assert_equal 6, model.n_text_layer
-    assert_equal 80, model.n_mels
-    assert_equal 1, model.ftype
-    assert_equal "base", model.type
-  end
-
-  def test_pathname
-    path = Pathname(Whisper::Model.pre_converted_models["base.en"].to_path)
-    whisper = Whisper::Context.new(path)
-    model = whisper.model
-
-    assert_equal 51864, model.n_vocab
-    assert_equal 1500, model.n_audio_ctx
-    assert_equal 512, model.n_audio_state
-    assert_equal 8, model.n_audio_head
-    assert_equal 6, model.n_audio_layer
-    assert_equal 448, model.n_text_ctx
-    assert_equal 512, model.n_text_state
-    assert_equal 8, model.n_text_head
-    assert_equal 6, model.n_text_layer
-    assert_equal 80, model.n_mels
-    assert_equal 1, model.ftype
-    assert_equal "base", model.type
-  end
-
-  def test_auto_download
-    path = Whisper::Model.pre_converted_models["base.en"].to_path
-
-    assert_path_exist path
-    assert_equal 147964211, File.size(path)
-  end
-
-  def test_uri_string
-    path = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin"
-    whisper = Whisper::Context.new(path)
-    model = whisper.model
-
-    assert_equal 51864, model.n_vocab
-    assert_equal 1500, model.n_audio_ctx
-    assert_equal 512, model.n_audio_state
-    assert_equal 8, model.n_audio_head
-    assert_equal 6, model.n_audio_layer
-    assert_equal 448, model.n_text_ctx
-    assert_equal 512, model.n_text_state
-    assert_equal 8, model.n_text_head
-    assert_equal 6, model.n_text_layer
-    assert_equal 80, model.n_mels
-    assert_equal 1, model.ftype
-    assert_equal "base", model.type
-  end
-
-  def test_uri
-    path = URI("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin")
-    whisper = Whisper::Context.new(path)
-    model = whisper.model
-
-    assert_equal 51864, model.n_vocab
-    assert_equal 1500, model.n_audio_ctx
-    assert_equal 512, model.n_audio_state
-    assert_equal 8, model.n_audio_head
-    assert_equal 6, model.n_audio_layer
-    assert_equal 448, model.n_text_ctx
-    assert_equal 512, model.n_text_state
-    assert_equal 8, model.n_text_head
-    assert_equal 6, model.n_text_layer
-    assert_equal 80, model.n_mels
-    assert_equal 1, model.ftype
-    assert_equal "base", model.type
-  end
-
-  def test_coreml_model_auto_download
-    uri = Whisper::Model.coreml_compiled_models[Whisper::Model.pre_converted_models["tiny"]]
-    model_path = Pathname(uri.to_path).sub_ext("")
-    model_path.rmtree if model_path.exist?
-
-    uri.cache
-    assert_path_exist model_path
-  end
-end
diff --git a/bindings/ruby/tests/test_package.rb b/bindings/ruby/tests/test_package.rb
deleted file mode 100644 (file)
index 33cd2a3..0000000
+++ /dev/null
@@ -1,50 +0,0 @@
-require_relative "helper"
-require 'tempfile'
-require 'tmpdir'
-require 'shellwords'
-
-class TestPackage < TestBase
-  def test_build
-    Tempfile.create do |file|
-      assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
-      assert file.size > 0
-      assert_path_exist file.to_path
-    end
-  end
-
-  sub_test_case "Building binary on installation" do
-    def setup
-      system "rake", "build", exception: true
-    end
-
-    def test_install
-      gemspec = Gem::Specification.load("whispercpp.gemspec")
-      Dir.mktmpdir do |dir|
-        system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{gemspec.file_name.shellescape}", exception: true
-        assert_installed dir, gemspec.version
-      end
-    end
-
-    def test_install_with_coreml
-      omit_unless RUBY_PLATFORM.match?(/darwin/) do
-        gemspec = Gem::Specification.load("whispercpp.gemspec")
-        Dir.mktmpdir do |dir|
-          system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{gemspec.file_name.shellescape}", "--", "--enable-whisper-coreml", exception: true
-          assert_installed dir, gemspec.version
-          assert_nothing_raised do
-            libdir = File.join(dir, "gems", "#{gemspec.name}-#{gemspec.version}", "lib")
-            system "ruby", "-I", libdir, "-r", "whisper", "-e", "Whisper::Context.new('tiny')", exception: true
-          end
-        end
-      end
-    end
-
-    private
-
-    def assert_installed(dir, version)
-      assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", "whisper.#{RbConfig::CONFIG["DLEXT"]}")
-      assert_path_exist File.join(dir, "gems/whispercpp-#{version}/LICENSE")
-      assert_path_not_exist File.join(dir, "gems/whispercpp-#{version}/ext/build")
-    end
-  end
-end
diff --git a/bindings/ruby/tests/test_params.rb b/bindings/ruby/tests/test_params.rb
deleted file mode 100644 (file)
index 9a95357..0000000
+++ /dev/null
@@ -1,297 +0,0 @@
-require_relative "helper"
-
-class TestParams < TestBase
-  PARAM_NAMES = [
-    :language,
-    :translate,
-    :no_context,
-    :single_segment,
-    :print_special,
-    :print_progress,
-    :print_realtime,
-    :print_timestamps,
-    :suppress_blank,
-    :suppress_nst,
-    :token_timestamps,
-    :split_on_word,
-    :initial_prompt,
-    :diarize,
-    :offset,
-    :duration,
-    :max_text_tokens,
-    :temperature,
-    :max_initial_ts,
-    :length_penalty,
-    :temperature_inc,
-    :entropy_thold,
-    :logprob_thold,
-    :no_speech_thold,
-    :new_segment_callback,
-    :new_segment_callback_user_data,
-    :progress_callback,
-    :progress_callback_user_data,
-    :abort_callback,
-    :abort_callback_user_data,
-    :vad,
-    :vad_model_path,
-    :vad_params,
-  ]
-
-  def setup
-    @params  = Whisper::Params.new
-  end
-
-  def test_language
-    @params.language = "en"
-    assert_equal @params.language, "en"
-    @params.language = "auto"
-    assert_equal @params.language, "auto"
-  end
-
-  def test_offset
-    @params.offset = 10_000
-    assert_equal @params.offset, 10_000
-    @params.offset = 0
-    assert_equal @params.offset, 0
-  end
-
-  def test_duration
-    @params.duration = 60_000
-    assert_equal @params.duration, 60_000
-    @params.duration = 0
-    assert_equal @params.duration, 0
-  end
-
-  def test_max_text_tokens
-    @params.max_text_tokens = 300
-    assert_equal @params.max_text_tokens, 300
-    @params.max_text_tokens = 0
-    assert_equal @params.max_text_tokens, 0
-  end
-
-  def test_translate
-    @params.translate = true
-    assert @params.translate
-    @params.translate = false
-    assert !@params.translate
-  end
-
-  def test_no_context
-    @params.no_context = true
-    assert @params.no_context
-    @params.no_context = false
-    assert !@params.no_context
-  end
-
-  def test_single_segment
-    @params.single_segment = true
-    assert @params.single_segment
-    @params.single_segment = false
-    assert !@params.single_segment
-  end
-
-  def test_print_special
-    @params.print_special = true
-    assert @params.print_special
-    @params.print_special = false
-    assert !@params.print_special
-  end
-
-  def test_print_progress
-    @params.print_progress = true
-    assert @params.print_progress
-    @params.print_progress = false
-    assert !@params.print_progress
-  end
-
-  def test_print_realtime
-    @params.print_realtime = true
-    assert @params.print_realtime
-    @params.print_realtime = false
-    assert !@params.print_realtime
-  end
-
-  def test_print_timestamps
-    @params.print_timestamps = true
-    assert @params.print_timestamps
-    @params.print_timestamps = false
-    assert !@params.print_timestamps
-  end
-
-  def test_suppress_blank
-    @params.suppress_blank = true
-    assert @params.suppress_blank
-    @params.suppress_blank = false
-    assert !@params.suppress_blank
-  end
-
-  def test_suppress_nst
-    @params.suppress_nst = true
-    assert @params.suppress_nst
-    @params.suppress_nst = false
-    assert !@params.suppress_nst
-  end
-
-  def test_token_timestamps
-    @params.token_timestamps = true
-    assert @params.token_timestamps
-    @params.token_timestamps = false
-    assert !@params.token_timestamps
-  end
-
-  def test_split_on_word
-    @params.split_on_word = true
-    assert @params.split_on_word
-    @params.split_on_word = false
-    assert !@params.split_on_word
-  end
-
-  def test_initial_prompt
-    assert_nil @params.initial_prompt
-    @params.initial_prompt = "You are a polite person."
-    assert_equal "You are a polite person.", @params.initial_prompt
-  end
-
-  def test_temperature
-    assert_equal 0.0, @params.temperature
-    @params.temperature = 0.5
-    assert_equal 0.5, @params.temperature
-  end
-
-  def test_max_initial_ts
-    assert_equal 1.0, @params.max_initial_ts
-    @params.max_initial_ts = 600.0
-    assert_equal 600.0, @params.max_initial_ts
-  end
-
-  def test_length_penalty
-    assert_equal(-1.0, @params.length_penalty)
-    @params.length_penalty = 0.5
-    assert_equal 0.5, @params.length_penalty
-  end
-
-  def test_temperature_inc
-    assert_in_delta 0.2, @params.temperature_inc
-    @params.temperature_inc = 0.5
-    assert_in_delta 0.5, @params.temperature_inc
-  end
-
-  def test_entropy_thold
-    assert_in_delta 2.4, @params.entropy_thold
-    @params.entropy_thold = 3.0
-    assert_in_delta 3.0, @params.entropy_thold
-  end
-
-  def test_logprob_thold
-    assert_in_delta(-1.0, @params.logprob_thold)
-    @params.logprob_thold = -0.5
-    assert_in_delta(-0.5, @params.logprob_thold)
-  end
-
-  def test_no_speech_thold
-    assert_in_delta 0.6, @params.no_speech_thold
-    @params.no_speech_thold = 0.2
-    assert_in_delta 0.2, @params.no_speech_thold
-  end
-
-  def test_vad
-    assert_false @params.vad
-    @params.vad = true
-    assert_true @params.vad
-  end
-
-  def test_vad_model_path
-    assert_nil @params.vad_model_path
-    @params.vad_model_path = "silero-v5.1.2"
-    assert_equal Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path, @params.vad_model_path
-  end
-
-  def test_vad_model_path_with_nil
-    @params.vad_model_path = "silero-v5.1.2"
-    @params.vad_model_path = nil
-    assert_nil @params.vad_model_path
-  end
-
-  def test_vad_model_path_with_invalid
-    assert_raise TypeError do
-      @params.vad_model_path = Object.new
-    end
-  end
-
-  def test_vad_model_path_with_URI_string
-    @params.vad_model_path = "https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v5.1.2.bin"
-    assert_equal @params.vad_model_path, Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path
-  end
-
-  def test_vad_model_path_with_URI
-    @params.vad_model_path = URI("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v5.1.2.bin")
-    assert_equal @params.vad_model_path, Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path
-  end
-
-  def test_vad_params
-    assert_kind_of Whisper::VAD::Params, @params.vad_params
-    default_params = @params.vad_params
-    assert_same default_params, @params.vad_params
-    assert_equal 0.5, default_params.threshold
-    new_params = Whisper::VAD::Params.new
-    @params.vad_params = new_params
-    assert_same new_params, @params.vad_params
-  end
-
-  def test_new_with_kw_args
-    params = Whisper::Params.new(language: "es")
-    assert_equal "es", params.language
-    assert_equal 1.0, params.max_initial_ts
-  end
-
-  def test_new_with_kw_args_non_existent
-    assert_raise ArgumentError do
-      Whisper::Params.new(non_existent: "value")
-    end
-  end
-
-  def test_new_with_kw_args_wrong_type
-    assert_raise TypeError do
-      Whisper::Params.new(language: 3)
-    end
-  end
-
-  data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
-  def test_new_with_kw_args_default_values(param)
-    default_value = @params.send(param)
-    value = case [param, default_value]
-            in [*, true | false]
-              !default_value
-            in [*, Integer | Float]
-              default_value + 1
-            in [:language, *]
-              "es"
-            in [:initial_prompt, *]
-              "Initial prompt"
-            in [/_callback\Z/, *]
-              proc {}
-            in [/_user_data\Z/, *]
-              Object.new
-            in [:vad_model_path, *]
-              Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path
-            in [:vad_params, *]
-              Whisper::VAD::Params.new
-            end
-    params = Whisper::Params.new(param => value)
-    if Float === value
-      assert_in_delta value, params.send(param)
-    else
-      assert_equal value, params.send(param)
-    end
-
-    PARAM_NAMES.reject {|name| name == param}.each do |name|
-      expected = @params.send(name)
-      actual = params.send(name)
-      if Float === expected
-        assert_in_delta expected, actual
-      else
-        assert_equal expected, actual
-      end
-    end
-  end
-end
diff --git a/bindings/ruby/tests/test_segment.rb b/bindings/ruby/tests/test_segment.rb
deleted file mode 100644 (file)
index e8b9987..0000000
+++ /dev/null
@@ -1,74 +0,0 @@
-require_relative "helper"
-
-class TestSegment < TestBase
-  def test_iteration
-    whisper.each_segment do |segment|
-      assert_instance_of Whisper::Segment, segment
-    end
-  end
-
-  def test_enumerator
-    enum = whisper.each_segment
-    assert_instance_of Enumerator, enum
-    enum.to_a.each_with_index do |segment, index|
-      assert_instance_of Whisper::Segment, segment
-      assert_kind_of Integer, index
-    end
-  end
-
-  def test_start_time
-    i = 0
-    whisper.each_segment do |segment|
-      assert_equal 0, segment.start_time if i == 0
-      i += 1
-    end
-  end
-
-  def test_end_time
-    i = 0
-    whisper.each_segment do |segment|
-      assert_equal whisper.full_get_segment_t1(i) * 10, segment.end_time
-      i += 1
-    end
-  end
-
-  def test_no_speech_prob
-    no_speech_prob = nil
-    whisper.each_segment do |segment|
-      no_speech_prob = segment.no_speech_prob
-    end
-    assert no_speech_prob > 0.0
-  end
-
-  def test_on_new_segment
-    params = Whisper::Params.new
-    seg = nil
-    index = 0
-    params.on_new_segment do |segment|
-      assert_instance_of Whisper::Segment, segment
-      if index == 0
-        seg = segment
-        assert_equal 0, segment.start_time
-        assert_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)
-      end
-      index += 1
-    end
-    whisper.transcribe(AUDIO, params)
-    assert_equal 0, seg.start_time
-    assert_match(/ask not what your country can do for you, ask what you can do for your country/, seg.text)
-  end
-
-  def test_on_new_segment_twice
-    params = Whisper::Params.new
-    seg = nil
-    params.on_new_segment do |segment|
-      seg = segment
-      return
-    end
-    params.on_new_segment do |segment|
-      assert_same seg, segment
-      return
-    end
-    whisper.transcribe(AUDIO, params)
-  end
-end
diff --git a/bindings/ruby/tests/test_vad.rb b/bindings/ruby/tests/test_vad.rb
deleted file mode 100644 (file)
index cb5e3c7..0000000
+++ /dev/null
@@ -1,19 +0,0 @@
-require_relative "helper"
-
-class TestVAD < TestBase
-  def setup
-    @whisper = Whisper::Context.new("base.en")
-    vad_params = Whisper::VAD::Params.new
-    @params = Whisper::Params.new(
-      vad: true,
-      vad_model_path: "silero-v5.1.2",
-      vad_params:
-    )
-  end
-
-  def test_transcribe
-    @whisper.transcribe(TestBase::AUDIO, @params) do |text|
-      assert_match(/ask not what your country can do for you[,.] ask what you can do for your country/i, text)
-    end
-  end
-end
diff --git a/bindings/ruby/tests/test_vad_params.rb b/bindings/ruby/tests/test_vad_params.rb
deleted file mode 100644 (file)
index add4899..0000000
+++ /dev/null
@@ -1,103 +0,0 @@
-require_relative "helper"
-
-class TestVADParams < TestBase
-  PARAM_NAMES = [
-    :threshold,
-    :min_speech_duration_ms,
-    :min_silence_duration_ms,
-    :max_speech_duration_s,
-    :speech_pad_ms,
-    :samples_overlap
-  ]
-
-  def setup
-    @params = Whisper::VAD::Params.new
-  end
-
-  def test_new
-    params = Whisper::VAD::Params.new
-    assert_kind_of Whisper::VAD::Params, params
-  end
-
-  def test_threshold
-    assert_in_delta @params.threshold, 0.5
-    @params.threshold = 0.7
-    assert_in_delta @params.threshold, 0.7
-  end
-
-  def test_min_speech_duration
-    pend
-  end
-
-  def test_min_speech_duration_ms
-    assert_equal 250, @params.min_speech_duration_ms
-    @params.min_speech_duration_ms = 500
-    assert_equal 500, @params.min_speech_duration_ms
-  end
-
-  def test_min_silence_duration_ms
-    assert_equal 100, @params.min_silence_duration_ms
-    @params.min_silence_duration_ms = 200
-    assert_equal 200, @params.min_silence_duration_ms
-  end
-
-  def test_max_speech_duration
-    pend
-  end
-
-  def test_max_speech_duration_s
-    assert @params.max_speech_duration_s >= 10e37 # Defaults to FLT_MAX
-    @params.max_speech_duration_s = 60.0
-    assert_equal 60.0, @params.max_speech_duration_s
-  end
-
-  def test_speech_pad_ms
-    assert_equal 30, @params.speech_pad_ms
-    @params.speech_pad_ms = 50
-    assert_equal 50, @params.speech_pad_ms
-  end
-
-  def test_samples_overlap
-    assert_in_delta @params.samples_overlap, 0.1
-    @params.samples_overlap = 0.5
-    assert_in_delta @params.samples_overlap, 0.5
-  end
-
-  def test_equal
-    assert_equal @params, Whisper::VAD::Params.new
-  end
-
-  def test_new_with_kw_args
-    params = Whisper::VAD::Params.new(threshold: 0.7)
-    assert_in_delta params.threshold, 0.7
-    assert_equal 250, params.min_speech_duration_ms
-  end
-
-  def test_new_with_kw_args_non_existent
-    assert_raise ArgumentError do
-      Whisper::VAD::Params.new(non_existent: "value")
-    end
-  end
-
-  data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
-  def test_new_with_kw_args_default_values(param)
-    default_value = @params.send(param)
-    value = default_value + 1
-    params = Whisper::VAD::Params.new(param => value)
-    if Float === value
-      assert_in_delta value, params.send(param)
-    else
-      assert_equal value, params.send(param)
-    end
-
-    PARAM_NAMES.reject {|name| name == param}.each do |name|
-      expected = @params.send(name)
-      actual = params.send(name)
-      if Float === expected
-        assert_in_delta expected, actual
-      else
-        assert_equal expected, actual
-      end
-    end
-  end
-end
diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb
deleted file mode 100644 (file)
index d915041..0000000
+++ /dev/null
@@ -1,230 +0,0 @@
-require_relative "helper"
-require "stringio"
-require "etc"
-
-# Exists to detect memory-related bug
-Whisper.log_set ->(level, buffer, user_data) {}, nil
-
-class TestWhisper < TestBase
-  def setup
-    @params  = Whisper::Params.new
-  end
-
-  def test_whisper
-    @whisper = Whisper::Context.new("base.en")
-    params  = Whisper::Params.new
-    params.print_timestamps = false
-
-    @whisper.transcribe(AUDIO, params) {|text|
-      assert_match(/ask not what your country can do for you, ask what you can do for your country/, text)
-    }
-  end
-
-  sub_test_case "After transcription" do
-    def test_full_n_segments
-      assert_equal 1, whisper.full_n_segments
-    end
-
-    def test_full_lang_id
-      assert_equal 0, whisper.full_lang_id
-    end
-
-    def test_full_get_segment
-      segment = whisper.full_get_segment(0)
-      assert_equal 0, segment.start_time
-      assert_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)
-    end
-
-    def test_full_get_segment_t0
-      assert_equal 0, whisper.full_get_segment_t0(0)
-      assert_raise IndexError do
-        whisper.full_get_segment_t0(whisper.full_n_segments)
-      end
-      assert_raise IndexError do
-        whisper.full_get_segment_t0(-1)
-      end
-    end
-
-    def test_full_get_segment_t1
-      t1 = whisper.full_get_segment_t1(0)
-      assert_kind_of Integer, t1
-      assert t1 > 0
-      assert_raise IndexError do
-        whisper.full_get_segment_t1(whisper.full_n_segments)
-      end
-    end
-
-    def test_full_get_segment_speaker_turn_next
-      assert_false whisper.full_get_segment_speaker_turn_next(0)
-    end
-
-    def test_full_get_segment_text
-      assert_match(/ask not what your country can do for you, ask what you can do for your country/, whisper.full_get_segment_text(0))
-    end
-
-    def test_full_get_segment_no_speech_prob
-      prob = whisper.full_get_segment_no_speech_prob(0)
-      assert prob > 0.0
-      assert prob < 1.0
-    end
-  end
-
-  def test_lang_max_id
-    assert_kind_of Integer, Whisper.lang_max_id
-  end
-
-  def test_lang_id
-    assert_equal 0, Whisper.lang_id("en")
-    assert_raise ArgumentError do
-      Whisper.lang_id("non existing language")
-    end
-  end
-
-  def test_lang_str
-    assert_equal "en", Whisper.lang_str(0)
-    assert_raise IndexError do
-      Whisper.lang_str(Whisper.lang_max_id + 1)
-    end
-  end
-
-  def test_lang_str_full
-    assert_equal "english", Whisper.lang_str_full(0)
-    assert_raise IndexError do
-      Whisper.lang_str_full(Whisper.lang_max_id + 1)
-    end
-  end
-
-  def test_system_info_str
-    assert_match /\AWHISPER : COREML = \d | OPENVINO = \d |/, Whisper.system_info_str
-  end
-
-  def test_log_set
-    user_data = Object.new
-    logs = []
-    log_callback = ->(level, buffer, udata) {
-      logs << [level, buffer, udata]
-    }
-    Whisper.log_set log_callback, user_data
-    Whisper::Context.new("base.en")
-
-    assert logs.length > 30
-    logs.each do |log|
-      assert_include [Whisper::LOG_LEVEL_DEBUG, Whisper::LOG_LEVEL_INFO, Whisper::LOG_LEVEL_WARN], log[0]
-      assert_same user_data, log[2]
-    end
-  end
-
-  def test_log_suppress
-    stderr = $stderr
-    Whisper.log_set ->(level, buffer, user_data) {
-      # do nothing
-    }, nil
-    dev = StringIO.new("")
-    $stderr = dev
-    Whisper::Context.new("base.en")
-    assert_empty dev.string
-  ensure
-    $stderr = stderr
-  end
-
-  sub_test_case "full" do
-    def setup
-      super
-      @whisper = Whisper::Context.new("base.en")
-      @samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15}
-    end
-
-    def test_full
-      @whisper.full(@params, @samples, @samples.length)
-
-      assert_equal 1, @whisper.full_n_segments
-      assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text)
-    end
-
-    def test_full_without_length
-      @whisper.full(@params, @samples)
-
-      assert_equal 1, @whisper.full_n_segments
-      assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text)
-    end
-
-    def test_full_enumerator
-      samples = @samples.each
-      @whisper.full(@params, samples, @samples.length)
-
-      assert_equal 1, @whisper.full_n_segments
-      assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text)
-    end
-
-    def test_full_enumerator_without_length
-      samples = @samples.each
-      assert_raise ArgumentError do
-        @whisper.full(@params, samples)
-      end
-    end
-
-    def test_full_enumerator_with_too_large_length
-      samples = @samples.each.take(10).to_enum
-      assert_raise StopIteration do
-        @whisper.full(@params, samples, 11)
-      end
-    end
-
-    def test_full_with_memory_view
-      samples = JFKReader.new(AUDIO)
-      @whisper.full(@params, samples)
-
-      assert_equal 1, @whisper.full_n_segments
-      assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text)
-    end
-
-    def test_full_parallel
-      nprocessors = 2
-      @whisper.full_parallel(@params, @samples, @samples.length, nprocessors)
-
-      assert_equal nprocessors, @whisper.full_n_segments
-      text = @whisper.each_segment.collect(&:text).join
-      assert_match(/ask what you can do/i, text)
-      assert_match(/for your country/i, text)
-    end
-
-    def test_full_parallel_with_memory_view
-      nprocessors = 2
-      samples = JFKReader.new(AUDIO)
-      @whisper.full_parallel(@params, samples, nil, nprocessors)
-
-      assert_equal nprocessors, @whisper.full_n_segments
-      text = @whisper.each_segment.collect(&:text).join
-      assert_match(/ask what you can do/i, text)
-      assert_match(/for your country/i, text)
-    end
-
-    def test_full_parallel_without_length_and_n_processors
-      @whisper.full_parallel(@params, @samples)
-
-      assert_equal 1, @whisper.full_n_segments
-      text = @whisper.each_segment.collect(&:text).join
-      assert_match(/ask what you can do/i, text)
-      assert_match(/for your country/i, text)
-    end
-
-    def test_full_parallel_without_length
-      nprocessors = 2
-      @whisper.full_parallel(@params, @samples, nil, nprocessors)
-
-      assert_equal nprocessors, @whisper.full_n_segments
-      text = @whisper.each_segment.collect(&:text).join
-      assert_match(/ask what you can do/i, text)
-      assert_match(/for your country/i, text)
-    end
-
-    def test_full_parallel_without_n_processors
-      @whisper.full_parallel(@params, @samples, @samples.length)
-
-      assert_equal 1, @whisper.full_n_segments
-      text = @whisper.each_segment.collect(&:text).join
-      assert_match(/ask what you can do/i, text)
-      assert_match(/for your country/i, text)
-    end
-  end
-end
index 06bef943510ae3a75ca7caeaa9adb7142cda954d..b838aa9fbde558533bf199b0b9497a596bd66cf8 100644 (file)
@@ -4,7 +4,7 @@ Gem::Specification.new do |s|
   s.name    = "whispercpp"
   s.authors = ["Georgi Gerganov", "Todd A. Fisher"]
   s.version = '1.3.3'
-  s.date    = '2025-06-01'
+  s.date    = '2025-06-03'
   s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby}
   s.email   = 'todd.fisher@gmail.com'
   s.extra_rdoc_files = ['LICENSE', 'README.md']
@@ -21,7 +21,7 @@ Gem::Specification.new do |s|
               }
 
   s.summary = %q{Ruby whisper.cpp bindings}
-  s.test_files = s.files.select {|file| file.start_with? "tests/"}
+  s.test_files = s.files.select {|file| file.start_with? "test/"}
 
   s.extensions << 'ext/extconf.rb'
   s.required_ruby_version = '>= 3.1.0'