]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ruby : add `VAD::Context#segments_from_samples`, allow Pathname, etc. (#3633)
authorKITAITI Makoto <redacted>
Fri, 30 Jan 2026 13:59:36 +0000 (22:59 +0900)
committerGitHub <redacted>
Fri, 30 Jan 2026 13:59:36 +0000 (22:59 +0900)
* ruby : Bump version to 1.3.6

* Fix code in example

* Add sample code to transcribe from MemoryView

* Define GetVADContext macro

* Use GetVADContext

* Extract parse_full_args function

* Use parse_full_args in ruby_whisper_full_parallel

* Free samples after use

* Check return value of parse_full_args()

* Define GetVADParams macro

* Add VAD::Context#segments_from_samples

* Add tests for VAD::Context#segments_from_samples

* Add signature for VAD::Context#segments_from_samples

* Add sample code for VAD::Context#segments_from_samples

* Add test for Whisper::Context#transcribe with Pathname

* Make Whisper::Context#transcribe and Whisper::VAD::Context#detect accept Pathname

* Update signature of Whisper::Context#transcribe

* Fix variable name

* Don't free memory view

* Make parse_full_args return struct

* Fallback when failed to get MemoryView

* Add num of samples when too long

* Check members of MemoryView

* Fix a typo

* Remove unnecessary include

* Fix a typo

* Fix a typo

* Care the case of MemoryView doesn't fit spec

* Add TODO comment

* Add optimazation option to compiler flags

* Use ALLOC_N instead of malloc

* Add description to sample code

* Rename and change args: parse_full_args -> parse_samples

* Free samples when exception raised

* Assign type check result to a variable

* Define wrapper function of whisper_full

* Change signature of parse_samples for rb_ensure

* Ensure release MemoryView

* Extract fill_samples function

* Free samples memory when filling it failed

* Free samples memory when transcription failed

* Prepare transcription in wrapper funciton

* Change function name

* Simplify function boundary

19 files changed:
bindings/ruby/README.md
bindings/ruby/ext/extconf.rb
bindings/ruby/ext/ruby_whisper.c
bindings/ruby/ext/ruby_whisper.h
bindings/ruby/ext/ruby_whisper_context.c
bindings/ruby/ext/ruby_whisper_model.c
bindings/ruby/ext/ruby_whisper_params.c
bindings/ruby/ext/ruby_whisper_segment.c
bindings/ruby/ext/ruby_whisper_token.c
bindings/ruby/ext/ruby_whisper_transcribe.cpp
bindings/ruby/ext/ruby_whisper_vad_context.c
bindings/ruby/ext/ruby_whisper_vad_context_detect.cpp
bindings/ruby/ext/ruby_whisper_vad_params.c
bindings/ruby/ext/ruby_whisper_vad_segment.c
bindings/ruby/ext/ruby_whisper_vad_segments.c
bindings/ruby/sig/whisper.rbs
bindings/ruby/test/test_vad_context.rb
bindings/ruby/test/test_whisper.rb
bindings/ruby/whispercpp.gemspec

index ea202753b677fd3f986362fe447ce3b4247ec33f..86774158355f4fca04e4d2c24853ac8025e2c6a8 100644 (file)
@@ -323,7 +323,24 @@ whisper
   end
 ```
 
-The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
+The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView.
+
+If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
+
+```ruby
+require "torchaudio"
+require "arrow-numo-narray"
+require "whisper"
+
+waveform, sample_rate = TorchAudio.load("test/fixtures/jfk.wav")
+# Convert Torch::Tensor to Arrow::Array via Numo::NArray
+samples = waveform.squeeze.numo.to_arrow.to_arrow_array
+
+whisper = Whisper::Context.new("base")
+whisper
+  # Arrow::Array exports MemoryView
+  .full(Whisper::Params.new, samples)
+```
 
 Using VAD separately from ASR
 -----------------------------
@@ -334,13 +351,27 @@ VAD feature itself is useful. You can use it separately from ASR:
 vad = Whisper::VAD::Context.new("silero-v6.2.0")
 vad
   .detect("path/to/audio.wav", Whisper::VAD::Params.new)
-  .each_with_index do |segment, index|
+  .each.with_index do |segment, index|
     segment => {start_time: st, end_time: ed} # `Segment` responds to `#deconstruct_keys`
 
     puts "[%{nth}: %{st} --> %{ed}]" % {nth: index + 1, st:, ed:}
   end
 ```
 
+You may also low level API `Whisper::VAD::Context#segments_from_samples` as such `Whisper::Context#full`:
+
+```ruby
+# Ruby Array
+reader = WaveFile::Reader.new("path/to/audio.wav", WaveFile::Format.new(:mono, :float, 16000))
+samples = reader.enum_for(:each_buffer).map(&:samples).flatten
+
+# Or, object which exports MemoryView
+waveform, sample_rate = TorchAudio.load("test/fixtures/jfk.wav")
+samples = waveform.squeeze.numo.to_arrow.to_arrow_array
+
+segments = vad.segments_from_samples(Whisper::VAD::Params.new, samples)
+```
+
 Development
 -----------
 
index 8a5ac67457b048cb29e6d1d2ea3e1603f8bd012e..acff501aa3b3b8215fc2399f355455409d4d33a0 100644 (file)
@@ -7,6 +7,7 @@ options = Options.new(cmake).to_s
 have_library("gomp") rescue nil
 libs = Dependencies.new(cmake, options).to_s
 
+$CFLAGS << " -O3 -march=native"
 $INCFLAGS << " -Isources/include -Isources/ggml/include -Isources/examples"
 $LOCAL_LIBS << " #{libs}"
 $cleanfiles << " build #{libs}"
index ac677e9e3df8992d94806bee5913a7247a92f150..eb95829c03228815d26a7da8ac662eaa810fe964 100644 (file)
@@ -1,5 +1,3 @@
-#include <ruby.h>
-#include <ruby/memory_view.h>
 #include "ruby_whisper.h"
 
 VALUE mWhisper;
index 3f5660c374dcc31551f4f192c5aaba2a6ea5ac42..c2c9866ae0de868dbc1165beffd3ef92b8fb3c5a 100644 (file)
@@ -1,6 +1,8 @@
 #ifndef RUBY_WHISPER_H
 #define RUBY_WHISPER_H
 
+#include <ruby.h>
+#include <ruby/memory_view.h>
 #include "whisper.h"
 
 typedef struct {
@@ -55,6 +57,13 @@ typedef struct {
   struct whisper_vad_context *context;
 } ruby_whisper_vad_context;
 
+typedef struct parsed_samples_t {
+  float *samples;
+  int n_samples;
+  rb_memory_view_t memview;
+  bool memview_exported;
+} parsed_samples_t;
+
 #define GetContext(obj, rw) do { \
   TypedData_Get_Struct((obj), ruby_whisper, &ruby_whisper_type, (rw)); \
   if ((rw)->context == NULL) { \
@@ -69,6 +78,17 @@ typedef struct {
   } \
 } while (0)
 
+#define GetVADContext(obj, rwvc) do { \
+    TypedData_Get_Struct((obj), ruby_whisper_vad_context, &ruby_whisper_vad_context_type, (rwvc)); \
+    if ((rwvc)->context == NULL) { \
+      rb_raise(rb_eRuntimeError, "Not initialized"); \
+    } \
+} while (0)
+
+#define GetVADParams(obj, rwvp) do { \
+  TypedData_Get_Struct((obj), ruby_whisper_vad_params, &ruby_whisper_vad_params_type, (rwvp)); \
+} while (0)
+
 #define GetVADSegments(obj, rwvss) do { \
   TypedData_Get_Struct((obj), ruby_whisper_vad_segments, &ruby_whisper_vad_segments_type, (rwvss)); \
   if ((rwvss)->segments == NULL) { \
index a7b5f8513db56200955fa36257d584e58afa2760..84790e3dedfa90d64d7511be57bdf15e5fbdaa6a 100644 (file)
@@ -1,5 +1,3 @@
-#include <ruby.h>
-#include <ruby/memory_view.h>
 #include "ruby_whisper.h"
 
 extern ID id_to_s;
@@ -27,6 +25,27 @@ extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context);
 
 ID transcribe_option_names[1];
 
+typedef struct fill_samples_args {
+  float *dest;
+  VALUE *src;
+  int n_samples;
+} fill_samples_args;
+
+typedef struct full_args {
+  VALUE *context;
+  VALUE *params;
+  float *samples;
+  int n_samples;
+} full_args;
+
+typedef struct full_parallel_args {
+  VALUE *context;
+  VALUE *params;
+  float *samples;
+  int n_samples;
+  int n_processors;
+} full_parallel_args;
+
 static void
 ruby_whisper_free(ruby_whisper *rw)
 {
@@ -272,82 +291,175 @@ VALUE ruby_whisper_model_type(VALUE self)
   return rb_str_new2(whisper_model_type_readable(rw->context));
 }
 
-/*
- * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
- * Not thread safe for same context
- * Uses the specified decoding strategy to obtain the text.
- *
- * call-seq:
- *   full(params, samples, n_samples) -> nil
- *   full(params, samples) -> nil
- *
- * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
- */
-VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
+static bool
+check_memory_view(rb_memory_view_t *memview)
 {
-  if (argc < 2 || argc > 3) {
-    rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
+  if (strcmp(memview->format, "f") != 0) {
+    rb_warn("currently only format \"f\" is supported for MemoryView, but given: %s", memview->format);
+    return false;
+  }
+  if (memview->ndim != 1) {
+    rb_warn("currently only 1 dimensional MemoryView is supported, but given: %zd", memview->ndim);
+    return false;
   }
 
-  ruby_whisper *rw;
-  ruby_whisper_params *rwp;
-  GetContext(self, rw);
-  VALUE params = argv[0];
-  TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
-  VALUE samples = argv[1];
-  int n_samples;
-  rb_memory_view_t view;
-  const bool memory_view_available_p = rb_memory_view_available_p(samples);
-  if (argc == 3) {
-    n_samples = NUM2INT(argv[2]);
-    if (TYPE(samples) == T_ARRAY) {
-      if (RARRAY_LEN(samples) < n_samples) {
-        rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
+  return true;
+}
+
+static VALUE
+fill_samples(VALUE rb_args)
+{
+  fill_samples_args *args = (fill_samples_args *)rb_args;
+
+  if (RB_TYPE_P(*args->src, T_ARRAY)) {
+    for (int i = 0; i < args->n_samples; i++) {
+      args->dest[i] = RFLOAT_VALUE(rb_ary_entry(*args->src, i));
+    }
+  } else {
+    // TODO: use rb_block_call
+    VALUE iter = rb_funcall(*args->src, id_to_enum, 1, rb_str_new2("each"));
+    for (int i = 0; i < args->n_samples; i++) {
+      // TODO: check if iter is exhausted and raise ArgumentError appropriately
+      VALUE sample = rb_funcall(iter, id_next, 0);
+      args->dest[i] = RFLOAT_VALUE(sample);
+    }
+  }
+
+  return Qnil;
+}
+
+struct parsed_samples_t
+parse_samples(VALUE *samples, VALUE *n_samples)
+{
+  bool memview_available = rb_memory_view_available_p(*samples);
+  struct parsed_samples_t parsed = {0};
+  parsed.memview_exported = false;
+  const bool is_array = RB_TYPE_P(*samples, T_ARRAY);
+
+  if (!NIL_P(*n_samples)) {
+    parsed.n_samples = NUM2INT(*n_samples);
+    if (is_array) {
+      if (RARRAY_LEN(*samples) < parsed.n_samples) {
+        rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(*samples), parsed.n_samples);
       }
     }
     // Should check when samples.respond_to?(:length)?
   } else {
-    if (TYPE(samples) == T_ARRAY) {
-      if (RARRAY_LEN(samples) > INT_MAX) {
+    if (is_array) {
+      if (RARRAY_LEN(*samples) > INT_MAX) {
         rb_raise(rb_eArgError, "samples are too long");
       }
-      n_samples = (int)RARRAY_LEN(samples);
-    } else if (memory_view_available_p) {
-      if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
-        view.obj = Qnil;
-        rb_raise(rb_eArgError, "unable to get a memory view");
+      parsed.n_samples = (int)RARRAY_LEN(*samples);
+    } else if (memview_available) {
+      bool memview_got = rb_memory_view_get(*samples, &parsed.memview, RUBY_MEMORY_VIEW_SIMPLE);
+      if (memview_got) {
+        parsed.memview_exported = check_memory_view(&parsed.memview);
+        if (!parsed.memview_exported) {
+          rb_memory_view_release(&parsed.memview);
+          parsed.memview = (rb_memory_view_t){0};
+        }
       }
-      ssize_t n_samples_size = view.byte_size / view.item_size;
-      if (n_samples_size > INT_MAX) {
-        rb_raise(rb_eArgError, "samples are too long");
+      if (parsed.memview_exported) {
+        ssize_t n_samples_size = parsed.memview.byte_size / parsed.memview.item_size;
+        if (n_samples_size > INT_MAX) {
+          rb_memory_view_release(&parsed.memview);
+          rb_raise(rb_eArgError, "samples are too long: %zd", n_samples_size);
+        }
+        parsed.n_samples = (int)n_samples_size;
+      } else {
+        rb_warn("unable to get a memory view. fallbacks to Ruby object");
+        if (rb_respond_to(*samples, id_length)) {
+          parsed.n_samples = NUM2INT(rb_funcall(*samples, id_length, 0));
+        } else {
+          rb_raise(rb_eArgError, "samples must respond to :length");
+        }
       }
-      n_samples = (int)n_samples_size;
-    } else if (rb_respond_to(samples, id_length)) {
-      n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
+    } else if (rb_respond_to(*samples, id_length)) {
+      parsed.n_samples = NUM2INT(rb_funcall(*samples, id_length, 0));
     } else {
-      rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
+      rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of float when n_samples is not given");
     }
   }
-  float * c_samples = (float *)malloc(n_samples * sizeof(float));
-  if (memory_view_available_p)  {
-    c_samples = (float *)view.data;
+
+  if (parsed.memview_exported)  {
+    parsed.samples = (float *)parsed.memview.data;
   } else {
-    if (TYPE(samples) == T_ARRAY) {
-      for (int i = 0; i < n_samples; i++) {
-        c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
-      }
-    } else {
-      // TODO: use rb_block_call
-      VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
-      for (int i = 0; i < n_samples; i++) {
-        // TODO: check if iter is exhausted and raise ArgumentError appropriately
-        VALUE sample = rb_funcall(iter, id_next, 0);
-        c_samples[i] = RFLOAT_VALUE(sample);
-      }
+    parsed.samples = ALLOC_N(float, parsed.n_samples);
+    fill_samples_args args = {
+      parsed.samples,
+      samples,
+      parsed.n_samples,
+    };
+    int state;
+    rb_protect(fill_samples, (VALUE)&args, &state);
+    if (state) {
+      xfree(parsed.samples);
+      rb_jump_tag(state);
     }
   }
-  prepare_transcription(rwp, &self);
-  const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
+
+  return parsed;
+}
+
+VALUE
+release_samples(VALUE rb_parsed_args)
+{
+  parsed_samples_t *parsed_args = (parsed_samples_t *)rb_parsed_args;
+
+  if (parsed_args->memview_exported) {
+    rb_memory_view_release(&parsed_args->memview);
+  } else {
+    xfree(parsed_args->samples);
+  }
+  *parsed_args = (parsed_samples_t){0};
+
+  return Qnil;
+}
+
+static VALUE
+full_body(VALUE rb_args)
+{
+  full_args *args = (full_args *)rb_args;
+
+  ruby_whisper *rw;
+  ruby_whisper_params *rwp;
+  GetContext(*args->context, rw);
+  TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
+
+  prepare_transcription(rwp, args->context);
+  int result = whisper_full(rw->context, rwp->params, args->samples, args->n_samples);
+
+  return INT2NUM(result);
+}
+
+/*
+ * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
+ * Not thread safe for same context
+ * Uses the specified decoding strategy to obtain the text.
+ *
+ * call-seq:
+ *   full(params, samples, n_samples) -> nil
+ *   full(params, samples) -> nil
+ *
+ * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
+ */
+VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
+{
+  if (argc < 2 || argc > 3) {
+    rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
+  }
+
+  VALUE n_samples = argc == 2 ? Qnil : argv[2];
+
+  struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples);
+  full_args args = {
+    &self,
+    &argv[0],
+    parsed.samples,
+    parsed.n_samples,
+  };
+  VALUE rb_result = rb_ensure(full_body, (VALUE)&args, release_samples, (VALUE)&parsed);
+  const int result = NUM2INT(rb_result);
   if (0 == result) {
     return self;
   } else {
@@ -355,6 +467,22 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
   }
 }
 
+static VALUE
+full_parallel_body(VALUE rb_args)
+{
+  full_parallel_args *args = (full_parallel_args *)rb_args;
+
+  ruby_whisper *rw;
+  ruby_whisper_params *rwp;
+  GetContext(*args->context, rw);
+  TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
+
+  prepare_transcription(rwp, args->context);
+  int result = whisper_full_parallel(rw->context, rwp->params, args->samples, args->n_samples, args->n_processors);
+
+  return INT2NUM(result);
+}
+
 /*
  * Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
  * Result is stored in the default state of the context
@@ -372,19 +500,11 @@ static VALUE
 ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
 {
   if (argc < 2 || argc > 4) {
-    rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
+    rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..4)", argc);
   }
 
-  ruby_whisper *rw;
-  ruby_whisper_params *rwp;
-  GetContext(self, rw);
-  VALUE params = argv[0];
-  TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
-  VALUE samples = argv[1];
-  int n_samples;
+  VALUE n_samples = argc == 2 ? Qnil : argv[2];
   int n_processors;
-  rb_memory_view_t view;
-  const bool memory_view_available_p = rb_memory_view_available_p(samples);
   switch (argc) {
   case 2:
     n_processors = 1;
@@ -396,56 +516,16 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
     n_processors = NUM2INT(argv[3]);
     break;
   }
-  if (argc >= 3 && !NIL_P(argv[2])) {
-    n_samples = NUM2INT(argv[2]);
-    if (TYPE(samples) == T_ARRAY) {
-      if (RARRAY_LEN(samples) < n_samples) {
-        rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
-      }
-    }
-    // Should check when samples.respond_to?(:length)?
-  } else if (memory_view_available_p) {
-    if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
-      view.obj = Qnil;
-      rb_raise(rb_eArgError, "unable to get a memory view");
-    }
-    ssize_t n_samples_size = view.byte_size / view.item_size;
-    if (n_samples_size > INT_MAX) {
-      rb_raise(rb_eArgError, "samples are too long");
-    }
-    n_samples = (int)n_samples_size;
-  } else {
-    if (TYPE(samples) == T_ARRAY) {
-      if (RARRAY_LEN(samples) > INT_MAX) {
-        rb_raise(rb_eArgError, "samples are too long");
-      }
-      n_samples = (int)RARRAY_LEN(samples);
-    } else if (rb_respond_to(samples, id_length)) {
-      n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
-    } else {
-      rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
-    }
-  }
-  float * c_samples = (float *)malloc(n_samples * sizeof(float));
-  if (memory_view_available_p) {
-    c_samples = (float *)view.data;
-  } else {
-    if (TYPE(samples) == T_ARRAY) {
-      for (int i = 0; i < n_samples; i++) {
-        c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
-      }
-    } else {
-      // FIXME: use rb_block_call
-      VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
-      for (int i = 0; i < n_samples; i++) {
-        // TODO: check if iter is exhausted and raise ArgumentError
-        VALUE sample = rb_funcall(iter, id_next, 0);
-        c_samples[i] = RFLOAT_VALUE(sample);
-      }
-    }
-  }
-  prepare_transcription(rwp, &self);
-  const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
+  struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples);
+  const full_parallel_args args = {
+    &self,
+    &argv[0],
+    parsed.samples,
+    parsed.n_samples,
+    n_processors,
+  };
+  const VALUE rb_result = rb_ensure(full_parallel_body, (VALUE)&args, release_samples, (VALUE)&parsed);
+  const int result = NUM2INT(rb_result);
   if (0 == result) {
     return self;
   } else {
index b196a8b5cb5248b93651d670aef653cc283d4cfc..0e91fb3f87f852cc392014831061fc8fde974740 100644 (file)
@@ -1,4 +1,3 @@
-#include <ruby.h>
 #include "ruby_whisper.h"
 
 extern const rb_data_type_t ruby_whisper_type;
index 4dfe2575a39d78ee5245e60d8c94318b23040259..61eb17336767c9300e160fc3c8bd6294d4e4579d 100644 (file)
@@ -1,4 +1,3 @@
-#include <ruby.h>
 #include "ruby_whisper.h"
 
 #define BOOL_PARAMS_SETTER(self, prop, value) \
index 5229cb539003c3369a8fe21179c512ef33d3fac1..ee0d66c4cc8f7dd4f61045915705fad68c08cca7 100644 (file)
@@ -1,4 +1,3 @@
-#include <ruby.h>
 #include "ruby_whisper.h"
 
 #define N_KEY_NAMES 6
index ea4f4e635d2681270895a43d71a979761bcb096b..56a7eab2231c157345d1a6ca3b8469e0aed44105 100644 (file)
@@ -1,4 +1,3 @@
-#include <ruby.h>
 #include "ruby_whisper.h"
 
 #define N_KEY_NAMES 11
index 594b2db90e3b33635180ab909c0850fe6c810bfe..c00fbcd1defbc945ff397df32be57f0bc049795f 100644 (file)
@@ -1,4 +1,3 @@
-#include <ruby.h>
 #include "ruby_whisper.h"
 #include "common-whisper.h"
 #include <string>
@@ -13,6 +12,7 @@ extern const rb_data_type_t ruby_whisper_params_type;
 
 extern ID id_to_s;
 extern ID id_call;
+extern ID id_to_path;
 extern ID transcribe_option_names[1];
 
 extern void
@@ -50,6 +50,9 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
     rb_raise(rb_eRuntimeError, "Expected file path to wave file");
   }
 
+  if (rb_respond_to(wave_file_path, id_to_path)) {
+    wave_file_path = rb_funcall(wave_file_path, id_to_path, 0);
+  }
   std::string fname_inp = StringValueCStr(wave_file_path);
 
   std::vector<float> pcmf32; // mono-channel F32 PCM
index bf2ed2ba4651668b005a3676ef4a91f6fee19a53..97c9736b6f425efcdb6f07521b299042f6bef5c1 100644 (file)
@@ -1,12 +1,23 @@
-#include <ruby.h>
 #include "ruby_whisper.h"
 
 extern ID id_to_s;
 
 extern VALUE cVADContext;
 
+extern const rb_data_type_t ruby_whisper_vad_params_type;
 extern VALUE ruby_whisper_vad_detect(VALUE self, VALUE file_path, VALUE params);
 extern VALUE ruby_whisper_normalize_model_path(VALUE model_path);
+extern parsed_samples_t parse_samples(VALUE *samples, VALUE *n_samples);
+extern VALUE release_samples(VALUE parsed);
+
+extern VALUE ruby_whisper_vad_segments_s_init(struct whisper_vad_segments *segments);
+
+typedef struct segments_from_samples_args {
+  VALUE *context;
+  VALUE *params;
+  float *samples;
+  int n_samples;
+} segments_from_samples_args;
 
 static size_t
 ruby_whisper_vad_context_memsize(const void *p)
@@ -66,10 +77,46 @@ ruby_whisper_vad_context_initialize(VALUE self, VALUE model_path)
   return Qnil;
 }
 
+static VALUE
+segments_from_samples_body(VALUE rb_args)
+{
+  segments_from_samples_args *args = (segments_from_samples_args *)rb_args;
+
+  ruby_whisper_vad_context *rwvc;
+  ruby_whisper_vad_params *rwvp;
+  GetVADContext(*args->context, rwvc);
+  GetVADParams(*args->params, rwvp);
+
+  struct whisper_vad_segments *segments = whisper_vad_segments_from_samples(rwvc->context, rwvp->params, args->samples, args->n_samples);
+
+  return ruby_whisper_vad_segments_s_init(segments);
+}
+
+static VALUE
+ruby_whisper_vad_segments_from_samples(int argc, VALUE *argv, VALUE self)
+{
+  if (argc < 2 || argc > 3) {
+    rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
+  }
+
+  VALUE n_samples = argc == 2 ? Qnil : argv[2];
+  struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples);
+  segments_from_samples_args args = {
+    &self,
+    &argv[0],
+    parsed.samples,
+    parsed.n_samples,
+  };
+  VALUE segments = rb_ensure(segments_from_samples_body, (VALUE)&args, release_samples, (VALUE)&parsed);
+
+  return segments;
+}
+
 void init_ruby_whisper_vad_context(VALUE *mVAD)
 {
   cVADContext = rb_define_class_under(*mVAD, "Context", rb_cObject);
   rb_define_alloc_func(cVADContext, ruby_whisper_vad_context_s_allocate);
   rb_define_method(cVADContext, "initialize", ruby_whisper_vad_context_initialize, 1);
+  rb_define_method(cVADContext, "segments_from_samples", ruby_whisper_vad_segments_from_samples, -1);
   rb_define_method(cVADContext, "detect", ruby_whisper_vad_detect, 2);
 }
index 58609f877429cc27f87acaed3d40105b5f40dfe3..802b0222dbd27d9497c55323aeac2f3d10170664 100644 (file)
@@ -1,4 +1,3 @@
-#include <ruby.h>
 #include "ruby_whisper.h"
 #include "common-whisper.h"
 #include <string>
@@ -8,6 +7,8 @@
 extern "C" {
 #endif
 
+extern ID id_to_path;
+
 extern VALUE cVADSegments;
 
 extern const rb_data_type_t ruby_whisper_vad_context_type;
@@ -25,12 +26,12 @@ ruby_whisper_vad_detect(VALUE self, VALUE file_path, VALUE params) {
   std::vector<std::vector<float>> pcmf32s;
   whisper_vad_segments *segments;
 
-  TypedData_Get_Struct(self, ruby_whisper_vad_context, &ruby_whisper_vad_context_type, rwvc);
-  if (rwvc->context == NULL) {
-    rb_raise(rb_eRuntimeError, "Doesn't have referenxe to context internally");
-  }
+  GetVADContext(self, rwvc);
   TypedData_Get_Struct(params, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
 
+  if (rb_respond_to(file_path, id_to_path)) {
+    file_path = rb_funcall(file_path, id_to_path, 0);
+  }
   cpp_file_path = StringValueCStr(file_path);
 
   if (!read_audio_data(cpp_file_path, pcmf32, pcmf32s, false)) {
index f254bfa2138781a58e1c387d8f63b1748df7c525..28256650e32beac425164ca03c73e08f76d66260 100644 (file)
@@ -1,4 +1,3 @@
-#include <ruby.h>
 #include "ruby_whisper.h"
 
 #define DEFINE_PARAM(param_name, nth) \
index 49ff0aadcce81a46b7570dde37fc8d09a34cfea3..84a007bb7257633652956f3aa686170f146553bf 100644 (file)
@@ -1,4 +1,3 @@
-#include <ruby.h>
 #include "ruby_whisper.h"
 
 #define N_KEY_NAMES 2
index 1bb375937a43e3a556532d259bdf40d31c2ec1b8..db62fdb622222297c36e7b1d6d167c8fdb5490eb 100644 (file)
@@ -1,4 +1,3 @@
-#include <ruby.h>
 #include "ruby_whisper.h"
 
 extern ID id___method__;
index 1137e3f36abcb85cf4b889ffa947e6a405d519e8..0e7b2c276e8f24599b167643f1e745460298b203 100644 (file)
@@ -37,8 +37,8 @@ module Whisper
     #       puts text
     #     end
     #
-    def transcribe: (string, Params, ?n_processors: Integer) -> self
-                  | (string, Params, ?n_processors: Integer) { (String) -> void } -> self
+    def transcribe: (path, Params, ?n_processors: Integer) -> self
+                  | (path, Params, ?n_processors: Integer) { (String) -> void } -> self
 
     def model_n_vocab: () -> Integer
     def model_n_audio_ctx: () -> Integer
@@ -603,6 +603,8 @@ module Whisper
 
     class Context
       def self.new: (String | path | ::URI::HTTP model_name_or_path) -> instance
+      def segments_from_samples: (Params, Array[Float] samples, ?Integer n_samples) -> Segments
+                               | (Params, _Samples, ?Integer n_samples) -> Segments
       def detect: (path wav_file_path, Params) -> Segments
     end
 
index 704916db6de5adc17e7984e71032b2bda625b60e..b4558d34fafb8b502bb9473a05df54da4ec4a5ab 100644 (file)
@@ -9,6 +9,25 @@ class TestVADContext < TestBase
   def test_detect
     context = Whisper::VAD::Context.new("silero-v6.2.0")
     segments = context.detect(AUDIO, Whisper::VAD::Params.new)
+    assert_segments segments
+  end
+
+  def test_invalid_model_type
+    assert_raise TypeError do
+      Whisper::VAD::Context.new(Object.new)
+    end
+  end
+
+  def test_allocate
+    vad = Whisper::VAD::Context.allocate
+    assert_raise do
+      vad.detect(AUDIO, Whisper::VAD::Params.new)
+    end
+  end
+
+  private
+
+  def assert_segments(segments)
     assert_instance_of Whisper::VAD::Segments, segments
 
     i = 0
@@ -35,16 +54,47 @@ class TestVADContext < TestBase
     assert_equal 4, segments.length
   end
 
-  def test_invalid_model_type
-    assert_raise TypeError do
-      Whisper::VAD::Context.new(Object.new)
+  sub_test_case "from samples" do
+    def setup
+      super
+      @vad = Whisper::VAD::Context.new("silero-v6.2.0")
+      @samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15}
     end
-  end
 
-  def test_allocate
-    vad = Whisper::VAD::Context.allocate
-    assert_raise do
-      vad.detect(AUDIO, Whisper::VAD::Params.new)
+    def test_segments_from_samples
+      segments = @vad.segments_from_samples(Whisper::VAD::Params.new, @samples, @samples.length)
+      assert_segments segments
+    end
+
+    def test_segments_from_samples_without_length
+      segments = @vad.segments_from_samples(Whisper::VAD::Params.new, @samples)
+      assert_segments segments
+    end
+
+    def test_segments_from_samples_enumerator
+      samples = @samples.each
+      segments = @vad.segments_from_samples(Whisper::VAD::Params.new, samples, @samples.length)
+      assert_segments segments
+    end
+
+    def test_segments_from_samples_enumerator_without_length
+      samples = @samples.each
+      assert_raise ArgumentError do
+        @vad.segments_from_samples(Whisper::VAD::Params.new, samples)
+      end
+    end
+
+    def test_segments_from_samples_enumerator_with_too_large_length
+      samples = @samples.each.take(10).to_enum
+      assert_raise StopIteration do
+        @vad.segments_from_samples(Whisper::VAD::Params.new, samples, 11)
+      end
+    end
+
+    def test_segments_from_samples_with_memory_view
+      samples = JFKReader.new(AUDIO)
+      segments = @vad.segments_from_samples(Whisper::VAD::Params.new, samples)
+      assert_segments segments
     end
   end
 end
index 96e248aca3a9513bc1606fcfc2e0a3503bc796e6..29071210072a355bba1f2ed92bda0beace2b724b 100644 (file)
@@ -1,6 +1,7 @@
 require_relative "helper"
 require "stringio"
 require "etc"
+require "pathname"
 
 # Exists to detect memory-related bug
 Whisper.log_set ->(level, buffer, user_data) {}, nil
@@ -20,6 +21,15 @@ class TestWhisper < TestBase
     }
   end
 
+  def test_whisper_pathname
+    @whisper = Whisper::Context.new("base.en")
+    params  = Whisper::Params.new
+
+    @whisper.transcribe(Pathname(AUDIO), params) {|text|
+      assert_match(/ask not what your country can do for you, ask what you can do for your country/, text)
+    }
+  end
+
   def test_transcribe_non_parallel
     @whisper = Whisper::Context.new("base.en")
     params  = Whisper::Params.new
@@ -207,6 +217,16 @@ class TestWhisper < TestBase
       assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text)
     end
 
+    def test_full_with_memroy_view_gc
+      samples = JFKReader.new(AUDIO)
+      @whisper.full(@params, samples)
+      GC.start
+      require "fiddle"
+      Fiddle::MemoryView.export samples do |view|
+        assert_equal 176000, view.to_s.unpack("#{view.format}*").length
+      end
+    end
+
     def test_full_parallel
       nprocessors = 2
       @whisper.full_parallel(@params, @samples, @samples.length, nprocessors)
index 2e05769a22c8b49ba280da1078e6916cead2233e..88b94e7eb8aa3e5518bcc88471fe486f3c41b483 100644 (file)
@@ -3,7 +3,7 @@ require_relative "extsources"
 Gem::Specification.new do |s|
   s.name    = "whispercpp"
   s.authors = ["Georgi Gerganov", "Todd A. Fisher"]
-  s.version = '1.3.5'
+  s.version = '1.3.6'
   s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby}
   s.email   = 'todd.fisher@gmail.com'
   s.extra_rdoc_files = ['LICENSE', 'README.md']