-#include <ruby.h>
-#include <ruby/memory_view.h>
#include "ruby_whisper.h"
extern ID id_to_s;
ID transcribe_option_names[1];
+typedef struct fill_samples_args {
+ float *dest;
+ VALUE *src;
+ int n_samples;
+} fill_samples_args;
+
+typedef struct full_args {
+ VALUE *context;
+ VALUE *params;
+ float *samples;
+ int n_samples;
+} full_args;
+
+typedef struct full_parallel_args {
+ VALUE *context;
+ VALUE *params;
+ float *samples;
+ int n_samples;
+ int n_processors;
+} full_parallel_args;
+
static void
ruby_whisper_free(ruby_whisper *rw)
{
return rb_str_new2(whisper_model_type_readable(rw->context));
}
-/*
- * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
- * Not thread safe for same context
- * Uses the specified decoding strategy to obtain the text.
- *
- * call-seq:
- * full(params, samples, n_samples) -> nil
- * full(params, samples) -> nil
- *
- * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
- */
-VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
+static bool
+check_memory_view(rb_memory_view_t *memview)
{
- if (argc < 2 || argc > 3) {
- rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
+ if (strcmp(memview->format, "f") != 0) {
+ rb_warn("currently only format \"f\" is supported for MemoryView, but given: %s", memview->format);
+ return false;
+ }
+ if (memview->ndim != 1) {
+ rb_warn("currently only 1 dimensional MemoryView is supported, but given: %zd", memview->ndim);
+ return false;
}
- ruby_whisper *rw;
- ruby_whisper_params *rwp;
- GetContext(self, rw);
- VALUE params = argv[0];
- TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
- VALUE samples = argv[1];
- int n_samples;
- rb_memory_view_t view;
- const bool memory_view_available_p = rb_memory_view_available_p(samples);
- if (argc == 3) {
- n_samples = NUM2INT(argv[2]);
- if (TYPE(samples) == T_ARRAY) {
- if (RARRAY_LEN(samples) < n_samples) {
- rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
+ return true;
+}
+
+static VALUE
+fill_samples(VALUE rb_args)
+{
+ fill_samples_args *args = (fill_samples_args *)rb_args;
+
+ if (RB_TYPE_P(*args->src, T_ARRAY)) {
+ for (int i = 0; i < args->n_samples; i++) {
+ args->dest[i] = RFLOAT_VALUE(rb_ary_entry(*args->src, i));
+ }
+ } else {
+ // TODO: use rb_block_call
+ VALUE iter = rb_funcall(*args->src, id_to_enum, 1, rb_str_new2("each"));
+ for (int i = 0; i < args->n_samples; i++) {
+ // TODO: check if iter is exhausted and raise ArgumentError appropriately
+ VALUE sample = rb_funcall(iter, id_next, 0);
+ args->dest[i] = RFLOAT_VALUE(sample);
+ }
+ }
+
+ return Qnil;
+}
+
+struct parsed_samples_t
+parse_samples(VALUE *samples, VALUE *n_samples)
+{
+ bool memview_available = rb_memory_view_available_p(*samples);
+ struct parsed_samples_t parsed = {0};
+ parsed.memview_exported = false;
+ const bool is_array = RB_TYPE_P(*samples, T_ARRAY);
+
+ if (!NIL_P(*n_samples)) {
+ parsed.n_samples = NUM2INT(*n_samples);
+ if (is_array) {
+ if (RARRAY_LEN(*samples) < parsed.n_samples) {
+ rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(*samples), parsed.n_samples);
}
}
// Should check when samples.respond_to?(:length)?
} else {
- if (TYPE(samples) == T_ARRAY) {
- if (RARRAY_LEN(samples) > INT_MAX) {
+ if (is_array) {
+ if (RARRAY_LEN(*samples) > INT_MAX) {
rb_raise(rb_eArgError, "samples are too long");
}
- n_samples = (int)RARRAY_LEN(samples);
- } else if (memory_view_available_p) {
- if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
- view.obj = Qnil;
- rb_raise(rb_eArgError, "unable to get a memory view");
+ parsed.n_samples = (int)RARRAY_LEN(*samples);
+ } else if (memview_available) {
+ bool memview_got = rb_memory_view_get(*samples, &parsed.memview, RUBY_MEMORY_VIEW_SIMPLE);
+ if (memview_got) {
+ parsed.memview_exported = check_memory_view(&parsed.memview);
+ if (!parsed.memview_exported) {
+ rb_memory_view_release(&parsed.memview);
+ parsed.memview = (rb_memory_view_t){0};
+ }
}
- ssize_t n_samples_size = view.byte_size / view.item_size;
- if (n_samples_size > INT_MAX) {
- rb_raise(rb_eArgError, "samples are too long");
+ if (parsed.memview_exported) {
+ ssize_t n_samples_size = parsed.memview.byte_size / parsed.memview.item_size;
+ if (n_samples_size > INT_MAX) {
+ rb_memory_view_release(&parsed.memview);
+ rb_raise(rb_eArgError, "samples are too long: %zd", n_samples_size);
+ }
+ parsed.n_samples = (int)n_samples_size;
+ } else {
+ rb_warn("unable to get a memory view. fallbacks to Ruby object");
+ if (rb_respond_to(*samples, id_length)) {
+ parsed.n_samples = NUM2INT(rb_funcall(*samples, id_length, 0));
+ } else {
+ rb_raise(rb_eArgError, "samples must respond to :length");
+ }
}
- n_samples = (int)n_samples_size;
- } else if (rb_respond_to(samples, id_length)) {
- n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
+ } else if (rb_respond_to(*samples, id_length)) {
+ parsed.n_samples = NUM2INT(rb_funcall(*samples, id_length, 0));
} else {
- rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
+ rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of float when n_samples is not given");
}
}
- float * c_samples = (float *)malloc(n_samples * sizeof(float));
- if (memory_view_available_p) {
- c_samples = (float *)view.data;
+
+ if (parsed.memview_exported) {
+ parsed.samples = (float *)parsed.memview.data;
} else {
- if (TYPE(samples) == T_ARRAY) {
- for (int i = 0; i < n_samples; i++) {
- c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
- }
- } else {
- // TODO: use rb_block_call
- VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
- for (int i = 0; i < n_samples; i++) {
- // TODO: check if iter is exhausted and raise ArgumentError appropriately
- VALUE sample = rb_funcall(iter, id_next, 0);
- c_samples[i] = RFLOAT_VALUE(sample);
- }
+ parsed.samples = ALLOC_N(float, parsed.n_samples);
+ fill_samples_args args = {
+ parsed.samples,
+ samples,
+ parsed.n_samples,
+ };
+ int state;
+ rb_protect(fill_samples, (VALUE)&args, &state);
+ if (state) {
+ xfree(parsed.samples);
+ rb_jump_tag(state);
}
}
- prepare_transcription(rwp, &self);
- const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
+
+ return parsed;
+}
+
+VALUE
+release_samples(VALUE rb_parsed_args)
+{
+ parsed_samples_t *parsed_args = (parsed_samples_t *)rb_parsed_args;
+
+ if (parsed_args->memview_exported) {
+ rb_memory_view_release(&parsed_args->memview);
+ } else {
+ xfree(parsed_args->samples);
+ }
+ *parsed_args = (parsed_samples_t){0};
+
+ return Qnil;
+}
+
+static VALUE
+full_body(VALUE rb_args)
+{
+ full_args *args = (full_args *)rb_args;
+
+ ruby_whisper *rw;
+ ruby_whisper_params *rwp;
+ GetContext(*args->context, rw);
+ TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
+
+ prepare_transcription(rwp, args->context);
+ int result = whisper_full(rw->context, rwp->params, args->samples, args->n_samples);
+
+ return INT2NUM(result);
+}
+
+/*
+ * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
+ * Not thread safe for same context
+ * Uses the specified decoding strategy to obtain the text.
+ *
+ * call-seq:
+ * full(params, samples, n_samples) -> nil
+ * full(params, samples) -> nil
+ *
+ * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
+ */
+VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
+{
+ if (argc < 2 || argc > 3) {
+ rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
+ }
+
+ VALUE n_samples = argc == 2 ? Qnil : argv[2];
+
+ struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples);
+ full_args args = {
+ &self,
+ &argv[0],
+ parsed.samples,
+ parsed.n_samples,
+ };
+ VALUE rb_result = rb_ensure(full_body, (VALUE)&args, release_samples, (VALUE)&parsed);
+ const int result = NUM2INT(rb_result);
if (0 == result) {
return self;
} else {
}
}
+static VALUE
+full_parallel_body(VALUE rb_args)
+{
+ full_parallel_args *args = (full_parallel_args *)rb_args;
+
+ ruby_whisper *rw;
+ ruby_whisper_params *rwp;
+ GetContext(*args->context, rw);
+ TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
+
+ prepare_transcription(rwp, args->context);
+ int result = whisper_full_parallel(rw->context, rwp->params, args->samples, args->n_samples, args->n_processors);
+
+ return INT2NUM(result);
+}
+
/*
* Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
* Result is stored in the default state of the context
ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
{
if (argc < 2 || argc > 4) {
- rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
+ rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..4)", argc);
}
- ruby_whisper *rw;
- ruby_whisper_params *rwp;
- GetContext(self, rw);
- VALUE params = argv[0];
- TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
- VALUE samples = argv[1];
- int n_samples;
+ VALUE n_samples = argc == 2 ? Qnil : argv[2];
int n_processors;
- rb_memory_view_t view;
- const bool memory_view_available_p = rb_memory_view_available_p(samples);
switch (argc) {
case 2:
n_processors = 1;
n_processors = NUM2INT(argv[3]);
break;
}
- if (argc >= 3 && !NIL_P(argv[2])) {
- n_samples = NUM2INT(argv[2]);
- if (TYPE(samples) == T_ARRAY) {
- if (RARRAY_LEN(samples) < n_samples) {
- rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
- }
- }
- // Should check when samples.respond_to?(:length)?
- } else if (memory_view_available_p) {
- if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
- view.obj = Qnil;
- rb_raise(rb_eArgError, "unable to get a memory view");
- }
- ssize_t n_samples_size = view.byte_size / view.item_size;
- if (n_samples_size > INT_MAX) {
- rb_raise(rb_eArgError, "samples are too long");
- }
- n_samples = (int)n_samples_size;
- } else {
- if (TYPE(samples) == T_ARRAY) {
- if (RARRAY_LEN(samples) > INT_MAX) {
- rb_raise(rb_eArgError, "samples are too long");
- }
- n_samples = (int)RARRAY_LEN(samples);
- } else if (rb_respond_to(samples, id_length)) {
- n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
- } else {
- rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
- }
- }
- float * c_samples = (float *)malloc(n_samples * sizeof(float));
- if (memory_view_available_p) {
- c_samples = (float *)view.data;
- } else {
- if (TYPE(samples) == T_ARRAY) {
- for (int i = 0; i < n_samples; i++) {
- c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
- }
- } else {
- // FIXME: use rb_block_call
- VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
- for (int i = 0; i < n_samples; i++) {
- // TODO: check if iter is exhausted and raise ArgumentError
- VALUE sample = rb_funcall(iter, id_next, 0);
- c_samples[i] = RFLOAT_VALUE(sample);
- }
- }
- }
- prepare_transcription(rwp, &self);
- const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
+ struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples);
+ const full_parallel_args args = {
+ &self,
+ &argv[0],
+ parsed.samples,
+ parsed.n_samples,
+ n_processors,
+ };
+ const VALUE rb_result = rb_ensure(full_parallel_body, (VALUE)&args, release_samples, (VALUE)&parsed);
+ const int result = NUM2INT(rb_result);
if (0 == result) {
return self;
} else {