#include <ruby.h>
+#include <ruby/memory_view.h>
#include "ruby_whisper.h"
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
VALUE mWhisper;
VALUE cContext;
VALUE cParams;
+VALUE eError;
static ID id_to_s;
static ID id_call;
static ID id___method__;
static ID id_to_enum;
+static ID id_length;
+static ID id_next;
+static ID id_new;
static bool is_log_callback_finalized = false;
* log_set ->(level, buffer, user_data) { ... }, user_data -> nil
*/
static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) {
- VALUE old_callback = rb_iv_get(self, "@log_callback");
+ VALUE old_callback = rb_iv_get(self, "log_callback");
if (!NIL_P(old_callback)) {
rb_undefine_finalizer(old_callback);
}
- rb_iv_set(self, "@log_callback", log_callback);
- rb_iv_set(self, "@user_data", user_data);
+ rb_iv_set(self, "log_callback", log_callback);
+ rb_iv_set(self, "user_data", user_data);
VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
rb_define_finalizer(log_callback, finalize_log_callback);
if (is_log_callback_finalized) {
return;
}
- VALUE log_callback = rb_iv_get(mWhisper, "@log_callback");
- VALUE udata = rb_iv_get(mWhisper, "@user_data");
+ VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
+ VALUE udata = rb_iv_get(mWhisper, "user_data");
rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata);
}, nullptr);
return rb_str_new2(whisper_model_type_readable(rw->context));
}
+/*
+ * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
+ * Not thread safe for same context
+ * Uses the specified decoding strategy to obtain the text.
+ *
+ * call-seq:
+ * full(params, samples, n_samples) -> nil
+ * full(params, samples) -> nil
+ *
+ * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
+ */
+VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) {
+ if (argc < 2 || argc > 3) {
+ rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
+ }
+
+ ruby_whisper *rw;
+ ruby_whisper_params *rwp;
+ Data_Get_Struct(self, ruby_whisper, rw);
+ VALUE params = argv[0];
+ Data_Get_Struct(params, ruby_whisper_params, rwp);
+ VALUE samples = argv[1];
+ int n_samples;
+ rb_memory_view_t view;
+ const bool memory_view_available_p = rb_memory_view_available_p(samples);
+ if (argc == 3) {
+ n_samples = NUM2INT(argv[2]);
+ if (TYPE(samples) == T_ARRAY) {
+ if (RARRAY_LEN(samples) < n_samples) {
+ rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
+ }
+ }
+ // Should check when samples.respond_to?(:length)?
+ } else {
+ if (TYPE(samples) == T_ARRAY) {
+ n_samples = RARRAY_LEN(samples);
+ } else if (memory_view_available_p) {
+ if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
+ view.obj = Qnil;
+ rb_raise(rb_eArgError, "unable to get a memory view");
+ }
+ n_samples = view.byte_size / view.item_size;
+ } else if (rb_respond_to(samples, id_length)) {
+ n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
+ } else {
+ rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
+ }
+ }
+ float * c_samples = (float *)malloc(n_samples * sizeof(float));
+ if (memory_view_available_p) {
+ c_samples = (float *)view.data;
+ } else {
+ if (TYPE(samples) == T_ARRAY) {
+ for (int i = 0; i < n_samples; i++) {
+ c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
+ }
+ } else {
+ // TODO: use rb_block_call
+ VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
+ for (int i = 0; i < n_samples; i++) {
+ // TODO: check if iter is exhausted and raise ArgumentError appropriately
+ VALUE sample = rb_funcall(iter, id_next, 0);
+ c_samples[i] = RFLOAT_VALUE(sample);
+ }
+ }
+ }
+ const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
+ if (0 == result) {
+ return Qnil;
+ } else {
+ rb_exc_raise(rb_funcall(eError, id_new, 1, result));
+ }
+}
+
+/*
+ * Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
+ * Result is stored in the default state of the context
+ * Not thread safe if executed in parallel on the same context.
+ * It seems this approach can offer some speedup in some cases.
+ * However, the transcription accuracy can be worse at the beginning and end of each chunk.
+ *
+ * call-seq:
+ * full_parallel(params, samples) -> nil
+ * full_parallel(params, samples, n_samples) -> nil
+ * full_parallel(params, samples, n_samples, n_processors) -> nil
+ * full_parallel(params, samples, nil, n_processors) -> nil
+ */
+static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) {
+ if (argc < 2 || argc > 4) {
+ rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
+ }
+
+ ruby_whisper *rw;
+ ruby_whisper_params *rwp;
+ Data_Get_Struct(self, ruby_whisper, rw);
+ VALUE params = argv[0];
+ Data_Get_Struct(params, ruby_whisper_params, rwp);
+ VALUE samples = argv[1];
+ int n_samples;
+ int n_processors;
+ rb_memory_view_t view;
+ const bool memory_view_available_p = rb_memory_view_available_p(samples);
+ switch (argc) {
+ case 2:
+ n_processors = 1;
+ break;
+ case 3:
+ n_processors = 1;
+ break;
+ case 4:
+ n_processors = NUM2INT(argv[3]);
+ break;
+ }
+ if (argc >= 3 && !NIL_P(argv[2])) {
+ n_samples = NUM2INT(argv[2]);
+ if (TYPE(samples) == T_ARRAY) {
+ if (RARRAY_LEN(samples) < n_samples) {
+ rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
+ }
+ }
+ // Should check when samples.respond_to?(:length)?
+ } else if (memory_view_available_p) {
+ if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
+ view.obj = Qnil;
+ rb_raise(rb_eArgError, "unable to get a memory view");
+ }
+ n_samples = view.byte_size / view.item_size;
+ } else {
+ if (TYPE(samples) == T_ARRAY) {
+ n_samples = RARRAY_LEN(samples);
+ } else if (rb_respond_to(samples, id_length)) {
+ n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
+ } else {
+ rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
+ }
+ }
+ float * c_samples = (float *)malloc(n_samples * sizeof(float));
+ if (memory_view_available_p) {
+ c_samples = (float *)view.data;
+ } else {
+ if (TYPE(samples) == T_ARRAY) {
+ for (int i = 0; i < n_samples; i++) {
+ c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
+ }
+ } else {
+ // FIXME: use rb_block_call
+ VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
+ for (int i = 0; i < n_samples; i++) {
+ // TODO: check if iter is exhausted and raise ArgumentError
+ VALUE sample = rb_funcall(iter, id_next, 0);
+ c_samples[i] = RFLOAT_VALUE(sample);
+ }
+ }
+ }
+ const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
+ if (0 == result) {
+ return Qnil;
+ } else {
+ rb_exc_raise(rb_funcall(eError, id_new, 1, result));
+ }
+}
+
/*
* Number of segments.
*
return rb_str_new2(whisper_model_type_readable(rw->context));
}
+static VALUE ruby_whisper_error_initialize(VALUE self, VALUE code) {
+ const int c_code = NUM2INT(code);
+ char *raw_message;
+ switch (c_code) {
+ case -2:
+ raw_message = "failed to compute log mel spectrogram";
+ break;
+ case -3:
+ raw_message = "failed to auto-detect language";
+ break;
+ case -4:
+ raw_message = "too many decoders requested";
+ break;
+ case -5:
+ raw_message = "audio_ctx is larger than the maximum allowed";
+ break;
+ case -6:
+ raw_message = "failed to encode";
+ break;
+ case -7:
+ raw_message = "whisper_kv_cache_init() failed for self-attention cache";
+ break;
+ case -8:
+ raw_message = "failed to decode";
+ break;
+ case -9:
+ raw_message = "failed to decode";
+ break;
+ default:
+ raw_message = "unknown error";
+ break;
+ }
+ const VALUE message = rb_str_new2(raw_message);
+ rb_call_super(1, &message);
+ rb_iv_set(self, "@code", code);
+
+ return self;
+}
+
+
void Init_whisper() {
id_to_s = rb_intern("to_s");
id_call = rb_intern("call");
id___method__ = rb_intern("__method__");
id_to_enum = rb_intern("to_enum");
+ id_length = rb_intern("length");
+ id_next = rb_intern("next");
+ id_new = rb_intern("new");
mWhisper = rb_define_module("Whisper");
cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);
+ eError = rb_define_class_under(mWhisper, "Error", rb_eStandardError);
rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO));
rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
+ rb_define_method(cContext, "full", ruby_whisper_full, -1);
+ rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
rb_define_method(cParams, "abort_callback=", ruby_whisper_params_set_abort_callback, 1);
rb_define_method(cParams, "abort_callback_user_data=", ruby_whisper_params_set_abort_callback_user_data, 1);
+ rb_define_attr(eError, "code", true, false);
+ rb_define_method(eError, "initialize", ruby_whisper_error_initialize, 1);
+
// High leve
cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject);
--- /dev/null
+#include <ruby.h>
+#include <ruby/memory_view.h>
+#include <ruby/encoding.h>
+
+static VALUE
+jfk_reader_initialize(VALUE self, VALUE audio_path)
+{
+ rb_iv_set(self, "audio_path", audio_path);
+ return Qnil;
+}
+
+static bool
+jfk_reader_get_memory_view(const VALUE obj, rb_memory_view_t *view, int flags)
+{
+ VALUE audio_path = rb_iv_get(obj, "audio_path");
+ const char *audio_path_str = StringValueCStr(audio_path);
+ const int n_samples = 176000;
+ float *data = (float *)malloc(n_samples * sizeof(float));
+ short *samples = (short *)malloc(n_samples * sizeof(short));
+ FILE *file = fopen(audio_path_str, "rb");
+
+ fseek(file, 78, SEEK_SET);
+ fread(samples, sizeof(short), n_samples, file);
+ fclose(file);
+ for (int i = 0; i < n_samples; i++) {
+ data[i] = samples[i]/32768.0;
+ }
+
+ view->obj = obj;
+ view->data = (void *)data;
+ view->byte_size = sizeof(float) * n_samples;
+ view->readonly = true;
+ view->format = "f";
+ view->item_size = sizeof(float);
+ view->item_desc.components = NULL;
+ view->item_desc.length = 0;
+ view->ndim = 1;
+ view->shape = NULL;
+ view->sub_offsets = NULL;
+ view->private_data = NULL;
+
+ return true;
+}
+
+static bool
+jfk_reader_release_memory_view(const VALUE obj, rb_memory_view_t *view)
+{
+ return true;
+}
+
+static bool
+jfk_reader_memory_view_available_p(const VALUE obj)
+{
+ return true;
+}
+
+static const rb_memory_view_entry_t jfk_reader_view_entry = {
+ jfk_reader_get_memory_view,
+ jfk_reader_release_memory_view,
+ jfk_reader_memory_view_available_p
+};
+
+static VALUE
+read_jfk(int argc, VALUE *argv, VALUE obj)
+{
+ const char *audio_path_str = StringValueCStr(argv[0]);
+ const int n_samples = 176000;
+
+ short samples[n_samples];
+ FILE *file = fopen(audio_path_str, "rb");
+
+ fseek(file, 78, SEEK_SET);
+ fread(samples, sizeof(short), n_samples, file);
+ fclose(file);
+
+ VALUE rb_samples = rb_ary_new2(n_samples);
+ for (int i = 0; i < n_samples; i++) {
+ rb_ary_push(rb_samples, INT2FIX(samples[i]));
+ }
+
+ VALUE rb_data = rb_ary_new2(n_samples);
+ for (int i = 0; i < n_samples; i++) {
+ rb_ary_push(rb_data, DBL2NUM(samples[i]/32768.0));
+ }
+
+ float data[n_samples];
+ for (int i = 0; i < n_samples; i++) {
+ data[i] = samples[i]/32768.0;
+ }
+ void *c_data = (void *)data;
+ VALUE rb_void = rb_enc_str_new((const char *)c_data, sizeof(data), rb_ascii8bit_encoding());
+
+ VALUE rb_result = rb_ary_new3(3, rb_samples, rb_data, rb_void);
+ return rb_result;
+}
+
+void Init_jfk_reader(void)
+{
+ VALUE cJFKReader = rb_define_class("JFKReader", rb_cObject);
+ rb_memory_view_register(cJFKReader, &jfk_reader_view_entry);
+ rb_define_method(cJFKReader, "initialize", jfk_reader_initialize, 1);
+
+
+ rb_define_global_function("read_jfk", read_jfk, -1);
+
+
+
+}
require_relative "helper"
require "stringio"
+require "etc"
# Exists to detect memory-related bug
Whisper.log_set ->(level, buffer, user_data) {}, nil
ensure
$stderr = stderr
end
+
+ sub_test_case "full" do
+ def setup
+ super
+ @whisper = Whisper::Context.new(MODEL)
+ @samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15}
+ end
+
+ def test_full
+ @whisper.full(@params, @samples, @samples.length)
+
+ assert_equal 1, @whisper.full_n_segments
+ assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
+ end
+
+ def test_full_without_length
+ @whisper.full(@params, @samples)
+
+ assert_equal 1, @whisper.full_n_segments
+ assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
+ end
+
+ def test_full_enumerator
+ samples = @samples.each
+ @whisper.full(@params, samples, @samples.length)
+
+ assert_equal 1, @whisper.full_n_segments
+ assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
+ end
+
+ def test_full_enumerator_without_length
+ samples = @samples.each
+ assert_raise ArgumentError do
+ @whisper.full(@params, samples)
+ end
+ end
+
+ def test_full_enumerator_with_too_large_length
+ samples = @samples.each.take(10).to_enum
+ assert_raise StopIteration do
+ @whisper.full(@params, samples, 11)
+ end
+ end
+
+ def test_full_with_memory_view
+ samples = JFKReader.new(AUDIO)
+ @whisper.full(@params, samples)
+
+ assert_equal 1, @whisper.full_n_segments
+ assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
+ end
+
+ def test_full_parallel
+ @whisper.full_parallel(@params, @samples, @samples.length, Etc.nprocessors)
+
+ assert_equal Etc.nprocessors, @whisper.full_n_segments
+ text = @whisper.each_segment.collect(&:text).join
+ assert_match /ask what you can do/i, text
+ assert_match /for your country/i, text
+ end
+
+ def test_full_parallel_with_memory_view
+ samples = JFKReader.new(AUDIO)
+ @whisper.full_parallel(@params, samples, nil, Etc.nprocessors)
+
+ assert_equal Etc.nprocessors, @whisper.full_n_segments
+ text = @whisper.each_segment.collect(&:text).join
+ assert_match /ask what you can do/i, text
+ assert_match /for your country/i, text
+ end
+
+ def test_full_parallel_without_length_and_n_processors
+ @whisper.full_parallel(@params, @samples)
+
+ assert_equal 1, @whisper.full_n_segments
+ text = @whisper.each_segment.collect(&:text).join
+ assert_match /ask what you can do/i, text
+ assert_match /for your country/i, text
+ end
+
+ def test_full_parallel_without_length
+ @whisper.full_parallel(@params, @samples, nil, Etc.nprocessors)
+
+ assert_equal Etc.nprocessors, @whisper.full_n_segments
+ text = @whisper.each_segment.collect(&:text).join
+ assert_match /ask what you can do/i, text
+ assert_match /for your country/i, text
+ end
+
+ def test_full_parallel_without_n_processors
+ @whisper.full_parallel(@params, @samples, @samples.length)
+
+ assert_equal 1, @whisper.full_n_segments
+ text = @whisper.each_segment.collect(&:text).join
+ assert_match /ask what you can do/i, text
+ assert_match /for your country/i, text
+ end
+ end
end