From: KITAITI Makoto Date: Fri, 30 Jan 2026 13:59:36 +0000 (+0900) Subject: ruby : add `VAD::Context#segments_from_samples`, allow Pathname, etc. (#3633) X-Git-Tag: upstream/1.8.3+155~81 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=aa1bc0d1a6dfd70dbb9f60c11df12441e03a9075;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp ruby : add `VAD::Context#segments_from_samples`, allow Pathname, etc. (#3633) * ruby : Bump version to 1.3.6 * Fix code in example * Add sample code to transcribe from MemoryView * Define GetVADContext macro * Use GetVADContext * Extract parse_full_args function * Use parse_full_args in ruby_whisper_full_parallel * Free samples after use * Check return value of parse_full_args() * Define GetVADParams macro * Add VAD::Context#segments_from_samples * Add tests for VAD::Context#segments_from_samples * Add signature for VAD::Context#segments_from_samples * Add sample code for VAD::Context#segments_from_samples * Add test for Whisper::Context#transcribe with Pathname * Make Whisper::Context#transcribe and Whisper::VAD::Context#detect accept Pathname * Update signature of Whisper::Context#transcribe * Fix variable name * Don't free memory view * Make parse_full_args return struct * Fallback when failed to get MemoryView * Add num of samples when too long * Check members of MemoryView * Fix a typo * Remove unnecessary include * Fix a typo * Fix a typo * Care the case of MemoryView doesn't fit spec * Add TODO comment * Add optimization option to compiler flags * Use ALLOC_N instead of malloc * Add description to sample code * Rename and change args: parse_full_args -> parse_samples * Free samples when exception raised * Assign type check result to a variable * Define wrapper function of whisper_full * Change signature of parse_samples for rb_ensure * Ensure release MemoryView * Extract fill_samples function * Free samples memory when filling it failed * Free samples memory when transcription failed * Prepare transcription in wrapper function * Change function 
name * Simplify function boundary --- diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index ea202753..86774158 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -323,7 +323,24 @@ whisper end ``` -The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy. +The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. + +If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy. + +```ruby +require "torchaudio" +require "arrow-numo-narray" +require "whisper" + +waveform, sample_rate = TorchAudio.load("test/fixtures/jfk.wav") +# Convert Torch::Tensor to Arrow::Array via Numo::NArray +samples = waveform.squeeze.numo.to_arrow.to_arrow_array + +whisper = Whisper::Context.new("base") +whisper + # Arrow::Array exports MemoryView + .full(Whisper::Params.new, samples) +``` Using VAD separately from ASR ----------------------------- @@ -334,13 +351,27 @@ VAD feature itself is useful. 
You can use it separately from ASR: vad = Whisper::VAD::Context.new("silero-v6.2.0") vad .detect("path/to/audio.wav", Whisper::VAD::Params.new) - .each_with_index do |segment, index| + .each.with_index do |segment, index| segment => {start_time: st, end_time: ed} # `Segment` responds to `#deconstruct_keys` puts "[%{nth}: %{st} --> %{ed}]" % {nth: index + 1, st:, ed:} end ``` +You may also use the low-level API `Whisper::VAD::Context#segments_from_samples` in the same way as `Whisper::Context#full`: + +```ruby +# Ruby Array +reader = WaveFile::Reader.new("path/to/audio.wav", WaveFile::Format.new(:mono, :float, 16000)) +samples = reader.enum_for(:each_buffer).map(&:samples).flatten + +# Or, object which exports MemoryView +waveform, sample_rate = TorchAudio.load("test/fixtures/jfk.wav") +samples = waveform.squeeze.numo.to_arrow.to_arrow_array + +segments = vad.segments_from_samples(Whisper::VAD::Params.new, samples) +``` + Development ----------- diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index 8a5ac674..acff501a 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -7,6 +7,7 @@ options = Options.new(cmake).to_s have_library("gomp") rescue nil libs = Dependencies.new(cmake, options).to_s +$CFLAGS << " -O3 -march=native" $INCFLAGS << " -Isources/include -Isources/ggml/include -Isources/examples" $LOCAL_LIBS << " #{libs}" $cleanfiles << " build #{libs}" diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index ac677e9e..eb95829c 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -1,5 +1,3 @@ -#include -#include #include "ruby_whisper.h" VALUE mWhisper; diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 3f5660c3..c2c9866a 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -1,6 +1,8 @@ #ifndef RUBY_WHISPER_H #define RUBY_WHISPER_H +#include +#include #include "whisper.h" typedef struct { @@ -55,6 
+57,13 @@ typedef struct { struct whisper_vad_context *context; } ruby_whisper_vad_context; +typedef struct parsed_samples_t { + float *samples; + int n_samples; + rb_memory_view_t memview; + bool memview_exported; +} parsed_samples_t; + #define GetContext(obj, rw) do { \ TypedData_Get_Struct((obj), ruby_whisper, &ruby_whisper_type, (rw)); \ if ((rw)->context == NULL) { \ @@ -69,6 +78,17 @@ typedef struct { } \ } while (0) +#define GetVADContext(obj, rwvc) do { \ + TypedData_Get_Struct((obj), ruby_whisper_vad_context, &ruby_whisper_vad_context_type, (rwvc)); \ + if ((rwvc)->context == NULL) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetVADParams(obj, rwvp) do { \ + TypedData_Get_Struct((obj), ruby_whisper_vad_params, &ruby_whisper_vad_params_type, (rwvp)); \ +} while (0) + #define GetVADSegments(obj, rwvss) do { \ TypedData_Get_Struct((obj), ruby_whisper_vad_segments, &ruby_whisper_vad_segments_type, (rwvss)); \ if ((rwvss)->segments == NULL) { \ diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index a7b5f851..84790e3d 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -1,5 +1,3 @@ -#include -#include #include "ruby_whisper.h" extern ID id_to_s; @@ -27,6 +25,27 @@ extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context); ID transcribe_option_names[1]; +typedef struct fill_samples_args { + float *dest; + VALUE *src; + int n_samples; +} fill_samples_args; + +typedef struct full_args { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; +} full_args; + +typedef struct full_parallel_args { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; + int n_processors; +} full_parallel_args; + static void ruby_whisper_free(ruby_whisper *rw) { @@ -272,82 +291,175 @@ VALUE ruby_whisper_model_type(VALUE self) return rb_str_new2(whisper_model_type_readable(rw->context)); } -/* - * Run 
the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text - * Not thread safe for same context - * Uses the specified decoding strategy to obtain the text. - * - * call-seq: - * full(params, samples, n_samples) -> nil - * full(params, samples) -> nil - * - * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. - */ -VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) +static bool +check_memory_view(rb_memory_view_t *memview) { - if (argc < 2 || argc > 3) { - rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + if (strcmp(memview->format, "f") != 0) { + rb_warn("currently only format \"f\" is supported for MemoryView, but given: %s", memview->format); + return false; + } + if (memview->ndim != 1) { + rb_warn("currently only 1 dimensional MemoryView is supported, but given: %zd", memview->ndim); + return false; } - ruby_whisper *rw; - ruby_whisper_params *rwp; - GetContext(self, rw); - VALUE params = argv[0]; - TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - VALUE samples = argv[1]; - int n_samples; - rb_memory_view_t view; - const bool memory_view_available_p = rb_memory_view_available_p(samples); - if (argc == 3) { - n_samples = NUM2INT(argv[2]); - if (TYPE(samples) == T_ARRAY) { - if (RARRAY_LEN(samples) < n_samples) { - rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples); + return true; +} + +static VALUE +fill_samples(VALUE rb_args) +{ + fill_samples_args *args = (fill_samples_args *)rb_args; + + if (RB_TYPE_P(*args->src, T_ARRAY)) { + for (int i = 0; i < args->n_samples; i++) { + args->dest[i] = RFLOAT_VALUE(rb_ary_entry(*args->src, i)); + } + } else { + // TODO: use rb_block_call + VALUE iter = rb_funcall(*args->src, id_to_enum, 1, rb_str_new2("each")); + for (int i = 0; i < args->n_samples; i++) { + // TODO: 
check if iter is exhausted and raise ArgumentError appropriately + VALUE sample = rb_funcall(iter, id_next, 0); + args->dest[i] = RFLOAT_VALUE(sample); + } + } + + return Qnil; +} + +struct parsed_samples_t +parse_samples(VALUE *samples, VALUE *n_samples) +{ + bool memview_available = rb_memory_view_available_p(*samples); + struct parsed_samples_t parsed = {0}; + parsed.memview_exported = false; + const bool is_array = RB_TYPE_P(*samples, T_ARRAY); + + if (!NIL_P(*n_samples)) { + parsed.n_samples = NUM2INT(*n_samples); + if (is_array) { + if (RARRAY_LEN(*samples) < parsed.n_samples) { + rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(*samples), parsed.n_samples); } } // Should check when samples.respond_to?(:length)? } else { - if (TYPE(samples) == T_ARRAY) { - if (RARRAY_LEN(samples) > INT_MAX) { + if (is_array) { + if (RARRAY_LEN(*samples) > INT_MAX) { rb_raise(rb_eArgError, "samples are too long"); } - n_samples = (int)RARRAY_LEN(samples); - } else if (memory_view_available_p) { - if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) { - view.obj = Qnil; - rb_raise(rb_eArgError, "unable to get a memory view"); + parsed.n_samples = (int)RARRAY_LEN(*samples); + } else if (memview_available) { + bool memview_got = rb_memory_view_get(*samples, &parsed.memview, RUBY_MEMORY_VIEW_SIMPLE); + if (memview_got) { + parsed.memview_exported = check_memory_view(&parsed.memview); + if (!parsed.memview_exported) { + rb_memory_view_release(&parsed.memview); + parsed.memview = (rb_memory_view_t){0}; + } } - ssize_t n_samples_size = view.byte_size / view.item_size; - if (n_samples_size > INT_MAX) { - rb_raise(rb_eArgError, "samples are too long"); + if (parsed.memview_exported) { + ssize_t n_samples_size = parsed.memview.byte_size / parsed.memview.item_size; + if (n_samples_size > INT_MAX) { + rb_memory_view_release(&parsed.memview); + rb_raise(rb_eArgError, "samples are too long: %zd", n_samples_size); + } + parsed.n_samples = 
(int)n_samples_size; + } else { + rb_warn("unable to get a memory view. fallbacks to Ruby object"); + if (rb_respond_to(*samples, id_length)) { + parsed.n_samples = NUM2INT(rb_funcall(*samples, id_length, 0)); + } else { + rb_raise(rb_eArgError, "samples must respond to :length"); + } } - n_samples = (int)n_samples_size; - } else if (rb_respond_to(samples, id_length)) { - n_samples = NUM2INT(rb_funcall(samples, id_length, 0)); + } else if (rb_respond_to(*samples, id_length)) { + parsed.n_samples = NUM2INT(rb_funcall(*samples, id_length, 0)); } else { - rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given"); + rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of float when n_samples is not given"); } } - float * c_samples = (float *)malloc(n_samples * sizeof(float)); - if (memory_view_available_p) { - c_samples = (float *)view.data; + + if (parsed.memview_exported) { + parsed.samples = (float *)parsed.memview.data; } else { - if (TYPE(samples) == T_ARRAY) { - for (int i = 0; i < n_samples; i++) { - c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i)); - } - } else { - // TODO: use rb_block_call - VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each")); - for (int i = 0; i < n_samples; i++) { - // TODO: check if iter is exhausted and raise ArgumentError appropriately - VALUE sample = rb_funcall(iter, id_next, 0); - c_samples[i] = RFLOAT_VALUE(sample); - } + parsed.samples = ALLOC_N(float, parsed.n_samples); + fill_samples_args args = { + parsed.samples, + samples, + parsed.n_samples, + }; + int state; + rb_protect(fill_samples, (VALUE)&args, &state); + if (state) { + xfree(parsed.samples); + rb_jump_tag(state); } } - prepare_transcription(rwp, &self); - const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples); + + return parsed; +} + +VALUE +release_samples(VALUE rb_parsed_args) +{ + parsed_samples_t *parsed_args = 
(parsed_samples_t *)rb_parsed_args; + + if (parsed_args->memview_exported) { + rb_memory_view_release(&parsed_args->memview); + } else { + xfree(parsed_args->samples); + } + *parsed_args = (parsed_samples_t){0}; + + return Qnil; +} + +static VALUE +full_body(VALUE rb_args) +{ + full_args *args = (full_args *)rb_args; + + ruby_whisper *rw; + ruby_whisper_params *rwp; + GetContext(*args->context, rw); + TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); + + prepare_transcription(rwp, args->context); + int result = whisper_full(rw->context, rwp->params, args->samples, args->n_samples); + + return INT2NUM(result); +} + +/* + * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + * Not thread safe for same context + * Uses the specified decoding strategy to obtain the text. + * + * call-seq: + * full(params, samples, n_samples) -> nil + * full(params, samples) -> nil + * + * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. + */ +VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) +{ + if (argc < 2 || argc > 3) { + rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + } + + VALUE n_samples = argc == 2 ? 
Qnil : argv[2]; + + struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); + full_args args = { + &self, + &argv[0], + parsed.samples, + parsed.n_samples, + }; + VALUE rb_result = rb_ensure(full_body, (VALUE)&args, release_samples, (VALUE)&parsed); + const int result = NUM2INT(rb_result); if (0 == result) { return self; } else { @@ -355,6 +467,22 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) } } +static VALUE +full_parallel_body(VALUE rb_args) +{ + full_parallel_args *args = (full_parallel_args *)rb_args; + + ruby_whisper *rw; + ruby_whisper_params *rwp; + GetContext(*args->context, rw); + TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); + + prepare_transcription(rwp, args->context); + int result = whisper_full_parallel(rw->context, rwp->params, args->samples, args->n_samples, args->n_processors); + + return INT2NUM(result); +} + /* * Split the input audio in chunks and process each chunk separately using whisper_full_with_state() * Result is stored in the default state of the context @@ -372,19 +500,11 @@ static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) { if (argc < 2 || argc > 4) { - rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..4)", argc); } - ruby_whisper *rw; - ruby_whisper_params *rwp; - GetContext(self, rw); - VALUE params = argv[0]; - TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - VALUE samples = argv[1]; - int n_samples; + VALUE n_samples = argc == 2 ? 
Qnil : argv[2]; int n_processors; - rb_memory_view_t view; - const bool memory_view_available_p = rb_memory_view_available_p(samples); switch (argc) { case 2: n_processors = 1; @@ -396,56 +516,16 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) n_processors = NUM2INT(argv[3]); break; } - if (argc >= 3 && !NIL_P(argv[2])) { - n_samples = NUM2INT(argv[2]); - if (TYPE(samples) == T_ARRAY) { - if (RARRAY_LEN(samples) < n_samples) { - rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples); - } - } - // Should check when samples.respond_to?(:length)? - } else if (memory_view_available_p) { - if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) { - view.obj = Qnil; - rb_raise(rb_eArgError, "unable to get a memory view"); - } - ssize_t n_samples_size = view.byte_size / view.item_size; - if (n_samples_size > INT_MAX) { - rb_raise(rb_eArgError, "samples are too long"); - } - n_samples = (int)n_samples_size; - } else { - if (TYPE(samples) == T_ARRAY) { - if (RARRAY_LEN(samples) > INT_MAX) { - rb_raise(rb_eArgError, "samples are too long"); - } - n_samples = (int)RARRAY_LEN(samples); - } else if (rb_respond_to(samples, id_length)) { - n_samples = NUM2INT(rb_funcall(samples, id_length, 0)); - } else { - rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given"); - } - } - float * c_samples = (float *)malloc(n_samples * sizeof(float)); - if (memory_view_available_p) { - c_samples = (float *)view.data; - } else { - if (TYPE(samples) == T_ARRAY) { - for (int i = 0; i < n_samples; i++) { - c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i)); - } - } else { - // FIXME: use rb_block_call - VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each")); - for (int i = 0; i < n_samples; i++) { - // TODO: check if iter is exhausted and raise ArgumentError - VALUE sample = rb_funcall(iter, id_next, 0); - c_samples[i] = 
RFLOAT_VALUE(sample); - } - } - } - prepare_transcription(rwp, &self); - const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors); + struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); + const full_parallel_args args = { + &self, + &argv[0], + parsed.samples, + parsed.n_samples, + n_processors, + }; + const VALUE rb_result = rb_ensure(full_parallel_body, (VALUE)&args, release_samples, (VALUE)&parsed); + const int result = NUM2INT(rb_result); if (0 == result) { return self; } else { diff --git a/bindings/ruby/ext/ruby_whisper_model.c b/bindings/ruby/ext/ruby_whisper_model.c index b196a8b5..0e91fb3f 100644 --- a/bindings/ruby/ext/ruby_whisper_model.c +++ b/bindings/ruby/ext/ruby_whisper_model.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" extern const rb_data_type_t ruby_whisper_type; diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 4dfe2575..61eb1733 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #define BOOL_PARAMS_SETTER(self, prop, value) \ diff --git a/bindings/ruby/ext/ruby_whisper_segment.c b/bindings/ruby/ext/ruby_whisper_segment.c index 5229cb53..ee0d66c4 100644 --- a/bindings/ruby/ext/ruby_whisper_segment.c +++ b/bindings/ruby/ext/ruby_whisper_segment.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #define N_KEY_NAMES 6 diff --git a/bindings/ruby/ext/ruby_whisper_token.c b/bindings/ruby/ext/ruby_whisper_token.c index ea4f4e63..56a7eab2 100644 --- a/bindings/ruby/ext/ruby_whisper_token.c +++ b/bindings/ruby/ext/ruby_whisper_token.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #define N_KEY_NAMES 11 diff --git a/bindings/ruby/ext/ruby_whisper_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_transcribe.cpp index 594b2db9..c00fbcd1 100644 --- a/bindings/ruby/ext/ruby_whisper_transcribe.cpp +++ 
b/bindings/ruby/ext/ruby_whisper_transcribe.cpp @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #include "common-whisper.h" #include @@ -13,6 +12,7 @@ extern const rb_data_type_t ruby_whisper_params_type; extern ID id_to_s; extern ID id_call; +extern ID id_to_path; extern ID transcribe_option_names[1]; extern void @@ -50,6 +50,9 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { rb_raise(rb_eRuntimeError, "Expected file path to wave file"); } + if (rb_respond_to(wave_file_path, id_to_path)) { + wave_file_path = rb_funcall(wave_file_path, id_to_path, 0); + } std::string fname_inp = StringValueCStr(wave_file_path); std::vector pcmf32; // mono-channel F32 PCM diff --git a/bindings/ruby/ext/ruby_whisper_vad_context.c b/bindings/ruby/ext/ruby_whisper_vad_context.c index bf2ed2ba..97c9736b 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_context.c +++ b/bindings/ruby/ext/ruby_whisper_vad_context.c @@ -1,12 +1,23 @@ -#include #include "ruby_whisper.h" extern ID id_to_s; extern VALUE cVADContext; +extern const rb_data_type_t ruby_whisper_vad_params_type; extern VALUE ruby_whisper_vad_detect(VALUE self, VALUE file_path, VALUE params); extern VALUE ruby_whisper_normalize_model_path(VALUE model_path); +extern parsed_samples_t parse_samples(VALUE *samples, VALUE *n_samples); +extern VALUE release_samples(VALUE parsed); + +extern VALUE ruby_whisper_vad_segments_s_init(struct whisper_vad_segments *segments); + +typedef struct segments_from_samples_args { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; +} segments_from_samples_args; static size_t ruby_whisper_vad_context_memsize(const void *p) @@ -66,10 +77,46 @@ ruby_whisper_vad_context_initialize(VALUE self, VALUE model_path) return Qnil; } +static VALUE +segments_from_samples_body(VALUE rb_args) +{ + segments_from_samples_args *args = (segments_from_samples_args *)rb_args; + + ruby_whisper_vad_context *rwvc; + ruby_whisper_vad_params *rwvp; + GetVADContext(*args->context, rwvc); + 
GetVADParams(*args->params, rwvp); + + struct whisper_vad_segments *segments = whisper_vad_segments_from_samples(rwvc->context, rwvp->params, args->samples, args->n_samples); + + return ruby_whisper_vad_segments_s_init(segments); +} + +static VALUE +ruby_whisper_vad_segments_from_samples(int argc, VALUE *argv, VALUE self) +{ + if (argc < 2 || argc > 3) { + rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + } + + VALUE n_samples = argc == 2 ? Qnil : argv[2]; + struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); + segments_from_samples_args args = { + &self, + &argv[0], + parsed.samples, + parsed.n_samples, + }; + VALUE segments = rb_ensure(segments_from_samples_body, (VALUE)&args, release_samples, (VALUE)&parsed); + + return segments; +} + void init_ruby_whisper_vad_context(VALUE *mVAD) { cVADContext = rb_define_class_under(*mVAD, "Context", rb_cObject); rb_define_alloc_func(cVADContext, ruby_whisper_vad_context_s_allocate); rb_define_method(cVADContext, "initialize", ruby_whisper_vad_context_initialize, 1); + rb_define_method(cVADContext, "segments_from_samples", ruby_whisper_vad_segments_from_samples, -1); rb_define_method(cVADContext, "detect", ruby_whisper_vad_detect, 2); } diff --git a/bindings/ruby/ext/ruby_whisper_vad_context_detect.cpp b/bindings/ruby/ext/ruby_whisper_vad_context_detect.cpp index 58609f87..802b0222 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_context_detect.cpp +++ b/bindings/ruby/ext/ruby_whisper_vad_context_detect.cpp @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #include "common-whisper.h" #include @@ -8,6 +7,8 @@ extern "C" { #endif +extern ID id_to_path; + extern VALUE cVADSegments; extern const rb_data_type_t ruby_whisper_vad_context_type; @@ -25,12 +26,12 @@ ruby_whisper_vad_detect(VALUE self, VALUE file_path, VALUE params) { std::vector> pcmf32s; whisper_vad_segments *segments; - TypedData_Get_Struct(self, ruby_whisper_vad_context, &ruby_whisper_vad_context_type, rwvc); - 
if (rwvc->context == NULL) { - rb_raise(rb_eRuntimeError, "Doesn't have referenxe to context internally"); - } + GetVADContext(self, rwvc); TypedData_Get_Struct(params, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp); + if (rb_respond_to(file_path, id_to_path)) { + file_path = rb_funcall(file_path, id_to_path, 0); + } cpp_file_path = StringValueCStr(file_path); if (!read_audio_data(cpp_file_path, pcmf32, pcmf32s, false)) { diff --git a/bindings/ruby/ext/ruby_whisper_vad_params.c b/bindings/ruby/ext/ruby_whisper_vad_params.c index f254bfa2..28256650 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_params.c +++ b/bindings/ruby/ext/ruby_whisper_vad_params.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #define DEFINE_PARAM(param_name, nth) \ diff --git a/bindings/ruby/ext/ruby_whisper_vad_segment.c b/bindings/ruby/ext/ruby_whisper_vad_segment.c index 49ff0aad..84a007bb 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_segment.c +++ b/bindings/ruby/ext/ruby_whisper_vad_segment.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #define N_KEY_NAMES 2 diff --git a/bindings/ruby/ext/ruby_whisper_vad_segments.c b/bindings/ruby/ext/ruby_whisper_vad_segments.c index 1bb37593..db62fdb6 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_segments.c +++ b/bindings/ruby/ext/ruby_whisper_vad_segments.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" extern ID id___method__; diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index 1137e3f3..0e7b2c27 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -37,8 +37,8 @@ module Whisper # puts text # end # - def transcribe: (string, Params, ?n_processors: Integer) -> self - | (string, Params, ?n_processors: Integer) { (String) -> void } -> self + def transcribe: (path, Params, ?n_processors: Integer) -> self + | (path, Params, ?n_processors: Integer) { (String) -> void } -> self def model_n_vocab: () -> Integer def model_n_audio_ctx: () -> Integer @@ -603,6 +603,8 
@@ module Whisper class Context def self.new: (String | path | ::URI::HTTP model_name_or_path) -> instance + def segments_from_samples: (Params, Array[Float] samples, ?Integer n_samples) -> Segments + | (Params, _Samples, ?Integer n_samples) -> Segments def detect: (path wav_file_path, Params) -> Segments end diff --git a/bindings/ruby/test/test_vad_context.rb b/bindings/ruby/test/test_vad_context.rb index 704916db..b4558d34 100644 --- a/bindings/ruby/test/test_vad_context.rb +++ b/bindings/ruby/test/test_vad_context.rb @@ -9,6 +9,25 @@ class TestVADContext < TestBase def test_detect context = Whisper::VAD::Context.new("silero-v6.2.0") segments = context.detect(AUDIO, Whisper::VAD::Params.new) + assert_segments segments + end + + def test_invalid_model_type + assert_raise TypeError do + Whisper::VAD::Context.new(Object.new) + end + end + + def test_allocate + vad = Whisper::VAD::Context.allocate + assert_raise do + vad.detect(AUDIO, Whisper::VAD::Params.new) + end + end + + private + + def assert_segments(segments) assert_instance_of Whisper::VAD::Segments, segments i = 0 @@ -35,16 +54,47 @@ class TestVADContext < TestBase assert_equal 4, segments.length end - def test_invalid_model_type - assert_raise TypeError do - Whisper::VAD::Context.new(Object.new) + sub_test_case "from samples" do + def setup + super + @vad = Whisper::VAD::Context.new("silero-v6.2.0") + @samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15} end - end - def test_allocate - vad = Whisper::VAD::Context.allocate - assert_raise do - vad.detect(AUDIO, Whisper::VAD::Params.new) + def test_segments_from_samples + segments = @vad.segments_from_samples(Whisper::VAD::Params.new, @samples, @samples.length) + assert_segments segments + end + + def test_segments_from_samples_without_length + segments = @vad.segments_from_samples(Whisper::VAD::Params.new, @samples) + assert_segments segments + end + + def test_segments_from_samples_enumerator + samples = @samples.each + segments = 
@vad.segments_from_samples(Whisper::VAD::Params.new, samples, @samples.length) + assert_segments segments + end + + def test_segments_from_samples_enumerator_without_length + samples = @samples.each + assert_raise ArgumentError do + @vad.segments_from_samples(Whisper::VAD::Params.new, samples) + end + end + + def test_segments_from_samples_enumerator_with_too_large_length + samples = @samples.each.take(10).to_enum + assert_raise StopIteration do + @vad.segments_from_samples(Whisper::VAD::Params.new, samples, 11) + end + end + + def test_segments_from_samples_with_memory_view + samples = JFKReader.new(AUDIO) + segments = @vad.segments_from_samples(Whisper::VAD::Params.new, samples) + assert_segments segments end end end diff --git a/bindings/ruby/test/test_whisper.rb b/bindings/ruby/test/test_whisper.rb index 96e248ac..29071210 100644 --- a/bindings/ruby/test/test_whisper.rb +++ b/bindings/ruby/test/test_whisper.rb @@ -1,6 +1,7 @@ require_relative "helper" require "stringio" require "etc" +require "pathname" # Exists to detect memory-related bug Whisper.log_set ->(level, buffer, user_data) {}, nil @@ -20,6 +21,15 @@ class TestWhisper < TestBase } end + def test_whisper_pathname + @whisper = Whisper::Context.new("base.en") + params = Whisper::Params.new + + @whisper.transcribe(Pathname(AUDIO), params) {|text| + assert_match(/ask not what your country can do for you, ask what you can do for your country/, text) + } + end + def test_transcribe_non_parallel @whisper = Whisper::Context.new("base.en") params = Whisper::Params.new @@ -207,6 +217,16 @@ class TestWhisper < TestBase assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text) end + def test_full_with_memroy_view_gc + samples = JFKReader.new(AUDIO) + @whisper.full(@params, samples) + GC.start + require "fiddle" + Fiddle::MemoryView.export samples do |view| + assert_equal 176000, view.to_s.unpack("#{view.format}*").length + end + end + def 
test_full_parallel nprocessors = 2 @whisper.full_parallel(@params, @samples, @samples.length, nprocessors) diff --git a/bindings/ruby/whispercpp.gemspec b/bindings/ruby/whispercpp.gemspec index 2e05769a..88b94e7e 100644 --- a/bindings/ruby/whispercpp.gemspec +++ b/bindings/ruby/whispercpp.gemspec @@ -3,7 +3,7 @@ require_relative "extsources" Gem::Specification.new do |s| s.name = "whispercpp" s.authors = ["Georgi Gerganov", "Todd A. Fisher"] - s.version = '1.3.5' + s.version = '1.3.6' s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby} s.email = 'todd.fisher@gmail.com' s.extra_rdoc_files = ['LICENSE', 'README.md']