]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ruby : Make context accept initial parameters, API to retrieve a segment and more... upstream/1.7.4+33
authorKITAITI Makoto <redacted>
Tue, 21 Jan 2025 07:39:54 +0000 (16:39 +0900)
committerGitHub <redacted>
Tue, 21 Jan 2025 07:39:54 +0000 (09:39 +0200)
* Fix type signature for Whisper.log_set

* Use cache file for model when offline

* Extract ruby_whisper_transcribe() into a file

* Extract Whisper::Error

* Use FileList for ext/*.{c,cpp,h}

* Extract Whisper::Segment

* Extract Whisper::Model

* Extract Whisper::Params

* Extract Whisper::Context

* Extract log_callback function

* Write base code in C rather than C++

* Use chdir instead of Dir.chdir in Rakefile

* Define alloc func for Whisper::Model

* Define Whisper::Params' calback and user data reader

* Add test for Whisper::Params.new with keyword arguments

* Make Whisper::Params.new accept keyword arguments

* Update type signatures

* Update README

* Update CLEAN targets

* Fix document comment for Whisper::Params#new_segment_callback=

* Use macro to define params

* Fix dependency of build task

* Set Whisper.finalize_log_callback visibility to private

* Make Whisper::Context#full and full_parallel return self

* Add test for Whisper::Context#full_get_segment

* Add Whisper::Context#full_get_segment

* Update signatures

* Update README

* Fix signature

* Resplace #initialize with .new in signature file [skip ci]

* Fix potential overflow

17 files changed:
bindings/ruby/README.md
bindings/ruby/Rakefile
bindings/ruby/ext/.gitignore
bindings/ruby/ext/extconf.rb
bindings/ruby/ext/ruby_whisper.c [new file with mode: 0644]
bindings/ruby/ext/ruby_whisper.cpp [deleted file]
bindings/ruby/ext/ruby_whisper.h
bindings/ruby/ext/ruby_whisper_context.c [new file with mode: 0644]
bindings/ruby/ext/ruby_whisper_error.c [new file with mode: 0644]
bindings/ruby/ext/ruby_whisper_model.c [new file with mode: 0644]
bindings/ruby/ext/ruby_whisper_params.c [new file with mode: 0644]
bindings/ruby/ext/ruby_whisper_segment.c [new file with mode: 0644]
bindings/ruby/ext/ruby_whisper_transcribe.cpp [new file with mode: 0644]
bindings/ruby/lib/whisper/model/uri.rb
bindings/ruby/sig/whisper.rbs
bindings/ruby/tests/test_params.rb
bindings/ruby/tests/test_whisper.rb

index 13ff1f00ad16e7d13277242857c849d3785424b3..f66d8d651e21a8ff7abe008444b6c8254fdf6709 100644 (file)
@@ -24,14 +24,15 @@ require "whisper"
 
 whisper = Whisper::Context.new("base")
 
-params = Whisper::Params.new
-params.language = "en"
-params.offset = 10_000
-params.duration = 60_000
-params.max_text_tokens = 300
-params.translate = true
-params.print_timestamps = false
-params.initial_prompt = "Initial prompt here."
+params = Whisper::Params.new(
+  language: "en",
+  offset: 10_000,
+  duration: 60_000,
+  max_text_tokens: 300,
+  translate: true,
+  print_timestamps: false,
+  initial_prompt: "Initial prompt here."
+)
 
 whisper.transcribe("path/to/audio.wav", params) do |whole_text|
   puts whole_text
@@ -113,18 +114,18 @@ def format_time(time_ms)
   "%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part]
 end
 
-whisper.transcribe("path/to/audio.wav", params)
-
-whisper.each_segment.with_index do |segment, index|
-  line = "[%{nth}: %{st} --> %{ed}] %{text}" % {
-    nth: index + 1,
-    st: format_time(segment.start_time),
-    ed: format_time(segment.end_time),
-    text: segment.text
-  }
-  line << " (speaker turned)" if segment.speaker_next_turn?
-  puts line
-end
+whisper
+  .transcribe("path/to/audio.wav", params)
+  .each_segment.with_index do |segment, index|
+    line = "[%{nth}: %{st} --> %{ed}] %{text}" % {
+      nth: index + 1,
+      st: format_time(segment.start_time),
+      ed: format_time(segment.end_time),
+      text: segment.text
+    }
+    line << " (speaker turned)" if segment.speaker_next_turn?
+    puts line
+  end
 
 ```
 
@@ -215,10 +216,11 @@ reader = WaveFile::Reader.new("path/to/audio.wav", WaveFile::Format.new(:mono, :
 samples = reader.enum_for(:each_buffer).map(&:samples).flatten
 
 whisper = Whisper::Context.new("base")
-whisper.full(Whisper::Params.new, samples)
-whisper.each_segment do |segment|
-  puts segment.text
-end
+whisper
+  .full(Whisper::Params.new, samples)
+  .each_segment do |segment|
+    puts segment.text
+  end
 ```
 
 The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
index 3a7809b750550ae50b91e76129c9d96b0d383ee5..0d52e88a31a26793830226f9f59477feb871d614 100644 (file)
@@ -18,9 +18,11 @@ EXTSOURCES.each do |src|
 end
 
 CLEAN.include SOURCES
-CLEAN.include FileList["ext/*.o", "ext/*.metal", "ext/whisper.{so,bundle,dll}"]
+CLEAN.include FileList["ext/**/*.o", "ext/**/*.metal", "ext/**/*.tmp", "ext/whisper.{so,bundle,dll}"]
 
-task build: ["ext/Makefile", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp", "whispercpp.gemspec"]
+SRC = FileList["ext/*.{c,cpp,h}"]
+
+task build: SOURCES
 
 directory "pkg"
 CLOBBER.include "pkg"
@@ -29,14 +31,14 @@ LIB_NAME = "whisper".ext(RbConfig::CONFIG["DLEXT"])
 SO_FILE = File.join("ext", LIB_NAME)
 LIB_FILE = File.join("lib", LIB_NAME)
 
-file "ext/Makefile" => ["ext/extconf.rb", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp"] + SOURCES do |t|
-  Dir.chdir "ext" do
+file "ext/Makefile" => SRC + ["ext/extconf.rb"] + SOURCES do |t|
+  chdir "ext" do
     ruby "extconf.rb"
   end
 end
 
 file SO_FILE => "ext/Makefile" do |t|
-  Dir.chdir "ext" do
+  chdir "ext" do
     sh "make"
   end
 end
@@ -54,7 +56,7 @@ end
 
 TEST_MEMORY_VIEW = "tests/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
 file TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c" do |t|
-  Dir.chdir "tests/jfk_reader" do
+  chdir "tests/jfk_reader" do
     ruby "extconf.rb"
     sh "make"
   end
index e96a8584c94322569a9f37e30d1d42be4f20d9e7..7703146ff8ef53706a579dccf96b8291b2e20621 100644 (file)
@@ -4,10 +4,8 @@ whisper.bundle
 whisper.dll
 scripts/get-flags.mk
 *.o
-*.c
-*.cpp
-*.h
-*.m
-*.metal
-!ruby_whisper.cpp
-!ruby_whisper.h
+/*/**/*.c
+/*/**/*.cpp
+/*/**/*.h
+/*/**/*.m
+/*/**/*.metal
index 6ffac109e3f7716ed1503238c15c7f69b2e356c8..af50904d8c55c7b173670a3d789cf76f71135b3f 100644 (file)
@@ -174,7 +174,14 @@ $OBJ_WHISPER <<
   'src/whisper.o'
 
 $objs = $OBJ_GGML + $OBJ_WHISPER + $OBJ_COMMON + $OBJ_SDL
-$objs << "ruby_whisper.o"
+$objs <<
+  "ruby_whisper.o" <<
+  "ruby_whisper_context.o" <<
+  "ruby_whisper_transcribe.o" <<
+  "ruby_whisper_params.o" <<
+  "ruby_whisper_error.o" <<
+  "ruby_whisper_segment.o" <<
+  "ruby_whisper_model.o"
 
 $CPPFLAGS  = "#{$MK_CPPFLAGS} #{$CPPFLAGS}"
 $CFLAGS    = "#{$CPPFLAGS} #{$MK_CFLAGS} #{$GF_CFLAGS} #{$CFLAGS}"
diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c
new file mode 100644 (file)
index 0000000..4322778
--- /dev/null
@@ -0,0 +1,164 @@
+#include <ruby.h>
+#include <ruby/memory_view.h>
+#include "ruby_whisper.h"
+
+VALUE mWhisper;
+VALUE cContext;
+VALUE cParams;
+VALUE eError;
+
+VALUE cSegment;
+VALUE cModel;
+
+ID id_to_s;
+ID id_call;
+ID id___method__;
+ID id_to_enum;
+ID id_length;
+ID id_next;
+ID id_new;
+ID id_to_path;
+ID id_URI;
+ID id_pre_converted_models;
+
+static bool is_log_callback_finalized = false;
+
+// High level API
+extern VALUE ruby_whisper_segment_allocate(VALUE klass);
+
+extern void init_ruby_whisper_context(VALUE *mWhisper);
+extern void init_ruby_whisper_params(VALUE *mWhisper);
+extern void init_ruby_whisper_error(VALUE *mWhisper);
+extern void init_ruby_whisper_segment(VALUE *mWhisper, VALUE *cSegment);
+extern void init_ruby_whisper_model(VALUE *mWhisper);
+extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);
+
+/*
+ * call-seq:
+ *   lang_max_id -> Integer
+ */
+static VALUE ruby_whisper_s_lang_max_id(VALUE self) {
+  return INT2NUM(whisper_lang_max_id());
+}
+
+/*
+ * call-seq:
+ *   lang_id(lang_name) -> Integer
+ */
+static VALUE ruby_whisper_s_lang_id(VALUE self, VALUE lang) {
+  const char * lang_str = StringValueCStr(lang);
+  const int id = whisper_lang_id(lang_str);
+  if (-1 == id) {
+    rb_raise(rb_eArgError, "language not found: %s", lang_str);
+  }
+  return INT2NUM(id);
+}
+
+/*
+ * call-seq:
+ *   lang_str(lang_id) -> String
+ */
+static VALUE ruby_whisper_s_lang_str(VALUE self, VALUE id) {
+  const int lang_id = NUM2INT(id);
+  const char * str = whisper_lang_str(lang_id);
+  if (NULL == str) {
+    rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
+  }
+  return rb_str_new2(str);
+}
+
+/*
+ * call-seq:
+ *   lang_str(lang_id) -> String
+ */
+static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) {
+  const int lang_id = NUM2INT(id);
+  const char * str_full = whisper_lang_str_full(lang_id);
+  if (NULL == str_full) {
+    rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
+  }
+  return rb_str_new2(str_full);
+}
+
+static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
+  is_log_callback_finalized = true;
+  return Qnil;
+}
+
+static void
+ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void * user_data) {
+  if (is_log_callback_finalized) {
+    return;
+  }
+  VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
+  VALUE udata = rb_iv_get(mWhisper, "user_data");
+  rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata);
+}
+
+/*
+ * call-seq:
+ *   log_set ->(level, buffer, user_data) { ... }, user_data -> nil
+ */
+static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) {
+  VALUE old_callback = rb_iv_get(self, "log_callback");
+  if (!NIL_P(old_callback)) {
+    rb_undefine_finalizer(old_callback);
+  }
+
+  rb_iv_set(self, "log_callback", log_callback);
+  rb_iv_set(self, "user_data", user_data);
+
+  VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
+  rb_define_finalizer(log_callback, finalize_log_callback);
+
+  whisper_log_set(ruby_whisper_log_callback, NULL);
+
+  return Qnil;
+}
+
+static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
+  rb_gc_mark(rwm->context);
+}
+
+static VALUE ruby_whisper_model_allocate(VALUE klass) {
+  ruby_whisper_model *rwm;
+  rwm = ALLOC(ruby_whisper_model);
+  return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
+}
+
+void Init_whisper() {
+  id_to_s = rb_intern("to_s");
+  id_call = rb_intern("call");
+  id___method__ = rb_intern("__method__");
+  id_to_enum = rb_intern("to_enum");
+  id_length = rb_intern("length");
+  id_next = rb_intern("next");
+  id_new = rb_intern("new");
+  id_to_path = rb_intern("to_path");
+  id_URI = rb_intern("URI");
+  id_pre_converted_models = rb_intern("pre_converted_models");
+
+  mWhisper = rb_define_module("Whisper");
+
+  rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
+  rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO));
+  rb_define_const(mWhisper, "LOG_LEVEL_WARN", INT2NUM(GGML_LOG_LEVEL_WARN));
+  rb_define_const(mWhisper, "LOG_LEVEL_ERROR", INT2NUM(GGML_LOG_LEVEL_ERROR));
+  rb_define_const(mWhisper, "LOG_LEVEL_DEBUG", INT2NUM(GGML_LOG_LEVEL_DEBUG));
+  rb_define_const(mWhisper, "LOG_LEVEL_CONT", INT2NUM(GGML_LOG_LEVEL_CONT));
+
+  rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0);
+  rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1);
+  rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
+  rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1);
+  rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2);
+  rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1);
+
+  init_ruby_whisper_context(&mWhisper);
+  init_ruby_whisper_params(&mWhisper);
+  init_ruby_whisper_error(&mWhisper);
+  init_ruby_whisper_segment(&mWhisper, &cContext);
+  init_ruby_whisper_model(&mWhisper);
+
+  rb_require("whisper/model/uri");
+}
diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp
deleted file mode 100644 (file)
index 5979f20..0000000
+++ /dev/null
@@ -1,1962 +0,0 @@
-#include <ruby.h>
-#include <ruby/memory_view.h>
-#include "ruby_whisper.h"
-#define DR_WAV_IMPLEMENTATION
-#include "dr_wav.h"
-#include <cmath>
-#include <fstream>
-#include <cstdio>
-#include <string>
-#include <thread>
-#include <vector>
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-#define BOOL_PARAMS_SETTER(self, prop, value) \
-  ruby_whisper_params *rwp; \
-  Data_Get_Struct(self, ruby_whisper_params, rwp); \
-  if (value == Qfalse || value == Qnil) { \
-    rwp->params.prop = false; \
-  } else { \
-    rwp->params.prop = true; \
-  } \
-  return value; \
-
-#define BOOL_PARAMS_GETTER(self,  prop) \
-  ruby_whisper_params *rwp; \
-  Data_Get_Struct(self, ruby_whisper_params, rwp); \
-  if (rwp->params.prop) { \
-    return Qtrue; \
-  } else { \
-    return Qfalse; \
-  }
-
-VALUE mWhisper;
-VALUE cContext;
-VALUE cParams;
-VALUE eError;
-
-VALUE cSegment;
-VALUE cModel;
-
-static ID id_to_s;
-static ID id_call;
-static ID id___method__;
-static ID id_to_enum;
-static ID id_length;
-static ID id_next;
-static ID id_new;
-static ID id_to_path;
-static ID id_URI;
-static ID id_pre_converted_models;
-
-static bool is_log_callback_finalized = false;
-
-// High level API
-static VALUE rb_whisper_segment_initialize(VALUE context, int index);
-
-/*
- * call-seq:
- *   lang_max_id -> Integer
- */
-static VALUE ruby_whisper_s_lang_max_id(VALUE self) {
-  return INT2NUM(whisper_lang_max_id());
-}
-
-/*
- * call-seq:
- *   lang_id(lang_name) -> Integer
- */
-static VALUE ruby_whisper_s_lang_id(VALUE self, VALUE lang) {
-  const char * lang_str = StringValueCStr(lang);
-  const int id = whisper_lang_id(lang_str);
-  if (-1 == id) {
-    rb_raise(rb_eArgError, "language not found: %s", lang_str);
-  }
-  return INT2NUM(id);
-}
-
-/*
- * call-seq:
- *   lang_str(lang_id) -> String
- */
-static VALUE ruby_whisper_s_lang_str(VALUE self, VALUE id) {
-  const int lang_id = NUM2INT(id);
-  const char * str = whisper_lang_str(lang_id);
-  if (nullptr == str) {
-    rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
-  }
-  return rb_str_new2(str);
-}
-
-/*
- * call-seq:
- *   lang_str(lang_id) -> String
- */
-static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) {
-  const int lang_id = NUM2INT(id);
-  const char * str_full = whisper_lang_str_full(lang_id);
-  if (nullptr == str_full) {
-    rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
-  }
-  return rb_str_new2(str_full);
-}
-
-static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
-  is_log_callback_finalized = true;
-  return Qnil;
-}
-
-/*
- * call-seq:
- *   log_set ->(level, buffer, user_data) { ... }, user_data -> nil
- */
-static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) {
-  VALUE old_callback = rb_iv_get(self, "log_callback");
-  if (!NIL_P(old_callback)) {
-    rb_undefine_finalizer(old_callback);
-  }
-
-  rb_iv_set(self, "log_callback", log_callback);
-  rb_iv_set(self, "user_data", user_data);
-
-  VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
-  rb_define_finalizer(log_callback, finalize_log_callback);
-
-  whisper_log_set([](ggml_log_level level, const char * buffer, void * user_data) {
-    if (is_log_callback_finalized) {
-      return;
-    }
-    VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
-    VALUE udata = rb_iv_get(mWhisper, "user_data");
-    rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata);
-  }, nullptr);
-
-  return Qnil;
-}
-
-static void ruby_whisper_free(ruby_whisper *rw) {
-  if (rw->context) {
-    whisper_free(rw->context);
-    rw->context = NULL;
-  }
-}
-
-static void ruby_whisper_params_free(ruby_whisper_params *rwp) {
-}
-
-void rb_whisper_mark(ruby_whisper *rw) {
-  // call rb_gc_mark on any ruby references in rw
-}
-
-void rb_whisper_free(ruby_whisper *rw) {
-  ruby_whisper_free(rw);
-  free(rw);
-}
-
-void rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) {
-  rb_gc_mark(rwc->user_data);
-  rb_gc_mark(rwc->callback);
-  rb_gc_mark(rwc->callbacks);
-}
-
-void rb_whisper_params_mark(ruby_whisper_params *rwp) {
-  rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
-  rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
-  rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
-}
-
-void rb_whisper_params_free(ruby_whisper_params *rwp) {
-  // How to free user_data and callback only when not referred to by others?
-  ruby_whisper_params_free(rwp);
-  free(rwp);
-}
-
-static VALUE ruby_whisper_allocate(VALUE klass) {
-  ruby_whisper *rw;
-  rw = ALLOC(ruby_whisper);
-  rw->context = NULL;
-  return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
-}
-
-static ruby_whisper_callback_container * rb_whisper_callback_container_allocate() {
-  ruby_whisper_callback_container *container;
-  container = ALLOC(ruby_whisper_callback_container);
-  container->context = nullptr;
-  container->user_data = Qnil;
-  container->callback = Qnil;
-  container->callbacks = rb_ary_new();
-  return container;
-}
-
-static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) {
-  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
-
-  // Currently, doesn't support state because
-  // those require to resolve GC-related problems.
-  if (!NIL_P(container->callback)) {
-    rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
-  }
-  const long callbacks_len = RARRAY_LEN(container->callbacks);
-  if (0 == callbacks_len) {
-    return;
-  }
-  const int n_segments = whisper_full_n_segments_from_state(state);
-  for (int i = n_new; i > 0; i--) {
-    int i_segment = n_segments - i;
-    VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
-    for (int j = 0; j < callbacks_len; j++) {
-      VALUE cb = rb_ary_entry(container->callbacks, j);
-      rb_funcall(cb, id_call, 1, segment);
-    }
-  }
-}
-
-static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) {
-  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
-  const VALUE progress = INT2NUM(progress_cur);
-  // Currently, doesn't support state because
-  // those require to resolve GC-related problems.
-  if (!NIL_P(container->callback)) {
-    rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
-  }
-  const long callbacks_len = RARRAY_LEN(container->callbacks);
-  if (0 == callbacks_len) {
-    return;
-  }
-  for (int j = 0; j < callbacks_len; j++) {
-    VALUE cb = rb_ary_entry(container->callbacks, j);
-    rb_funcall(cb, id_call, 1, progress);
-  }
-}
-
-static bool abort_callback(void * user_data) {
-  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
-  if (!NIL_P(container->callback)) {
-    VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
-    if (!NIL_P(result) && Qfalse != result) {
-      return true;
-    }
-  }
-  const long callbacks_len = RARRAY_LEN(container->callbacks);
-  if (0 == callbacks_len) {
-    return false;
-  }
-  for (int j = 0; j < callbacks_len; j++) {
-    VALUE cb = rb_ary_entry(container->callbacks, j);
-    VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
-    if (!NIL_P(result) && Qfalse != result) {
-      return true;
-    }
-  }
-  return false;
-}
-
-static VALUE ruby_whisper_params_allocate(VALUE klass) {
-  ruby_whisper_params *rwp;
-  rwp = ALLOC(ruby_whisper_params);
-  rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
-  rwp->diarize = false;
-  rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
-  rwp->progress_callback_container = rb_whisper_callback_container_allocate();
-  rwp->abort_callback_container = rb_whisper_callback_container_allocate();
-  return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
-}
-
-/*
- * call-seq:
- *   new("base.en") -> Whisper::Context
- *   new("path/to/model.bin") -> Whisper::Context
- *   new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context
- */
-static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
-  ruby_whisper *rw;
-  VALUE whisper_model_file_path;
-
-  // TODO: we can support init from buffer here too maybe another ruby object to expose
-  rb_scan_args(argc, argv, "01", &whisper_model_file_path);
-  Data_Get_Struct(self, ruby_whisper, rw);
-
-  VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
-  VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path);
-  if (!NIL_P(pre_converted_model)) {
-    whisper_model_file_path = pre_converted_model;
-  }
-  if (TYPE(whisper_model_file_path) == T_STRING) {
-    const char * whisper_model_file_path_str = StringValueCStr(whisper_model_file_path);
-    if (strncmp("http://", whisper_model_file_path_str, 7) == 0 || strncmp("https://", whisper_model_file_path_str, 8) == 0) {
-      VALUE uri_class = rb_const_get(cModel, id_URI);
-      whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
-    }
-  }
-  if (rb_obj_is_kind_of(whisper_model_file_path, rb_path2class("URI::HTTP"))) {
-    VALUE uri_class = rb_const_get(cModel, id_URI);
-    whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
-  }
-  if (rb_respond_to(whisper_model_file_path, id_to_path)) {
-    whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0);
-  }
-  if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
-    rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
-  }
-  rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
-  if (rw->context == nullptr) {
-    rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
-  }
-  return self;
-}
-
-static void register_callbacks(ruby_whisper_params * rwp, VALUE * self) {
-  if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
-    rwp->new_segment_callback_container->context = self;
-    rwp->params.new_segment_callback = new_segment_callback;
-    rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
-  }
-
-  if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
-    rwp->progress_callback_container->context = self;
-    rwp->params.progress_callback = progress_callback;
-    rwp->params.progress_callback_user_data = rwp->progress_callback_container;
-  }
-
-  if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
-    rwp->abort_callback_container->context = self;
-    rwp->params.abort_callback = abort_callback;
-    rwp->params.abort_callback_user_data = rwp->abort_callback_container;
-  }
-}
-
-/*
- * transcribe a single file
- * can emit to a block results
- *
- *   params = Whisper::Params.new
- *   params.duration = 60_000
- *   whisper.transcribe "path/to/audio.wav", params do |text|
- *     puts text
- *   end
- *
- * call-seq:
- *   transcribe(path_to_audio, params) {|text| ...}
- **/
-static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
-  ruby_whisper *rw;
-  ruby_whisper_params *rwp;
-  VALUE wave_file_path, blk, params;
-
-  rb_scan_args(argc, argv, "02&", &wave_file_path, &params, &blk);
-  Data_Get_Struct(self, ruby_whisper, rw);
-  Data_Get_Struct(params, ruby_whisper_params, rwp);
-
-  if (!rb_respond_to(wave_file_path, id_to_s)) {
-    rb_raise(rb_eRuntimeError, "Expected file path to wave file");
-  }
-
-  std::string fname_inp = StringValueCStr(wave_file_path);
-
-  std::vector<float> pcmf32; // mono-channel F32 PCM
-  std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
-
-  // WAV input - this is directly from main.cpp example
-  {
-    drwav wav;
-    std::vector<uint8_t> wav_data; // used for pipe input from stdin
-
-    if (fname_inp == "-") {
-      {
-        uint8_t buf[1024];
-        while (true) {
-          const size_t n = fread(buf, 1, sizeof(buf), stdin);
-          if (n == 0) {
-            break;
-          }
-          wav_data.insert(wav_data.end(), buf, buf + n);
-        }
-      }
-
-      if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
-        fprintf(stderr, "error: failed to open WAV file from stdin\n");
-        return self;
-      }
-
-      fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
-    } else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
-      fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
-      return self;
-    }
-
-    if (wav.channels != 1 && wav.channels != 2) {
-      fprintf(stderr, "WAV file '%s' must be mono or stereo\n", fname_inp.c_str());
-      return self;
-    }
-
-    if (rwp->diarize && wav.channels != 2 && rwp->params.print_timestamps == false) {
-      fprintf(stderr, "WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str());
-      return self;
-    }
-
-    if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
-      fprintf(stderr, "WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
-      return self;
-    }
-
-    if (wav.bitsPerSample != 16) {
-      fprintf(stderr, "WAV file '%s' must be 16-bit\n", fname_inp.c_str());
-      return self;
-    }
-
-    const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
-
-    std::vector<int16_t> pcm16;
-    pcm16.resize(n*wav.channels);
-    drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
-    drwav_uninit(&wav);
-
-    // convert to mono, float
-    pcmf32.resize(n);
-    if (wav.channels == 1) {
-      for (uint64_t i = 0; i < n; i++) {
-        pcmf32[i] = float(pcm16[i])/32768.0f;
-      }
-    } else {
-      for (uint64_t i = 0; i < n; i++) {
-        pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
-      }
-    }
-
-    if (rwp->diarize) {
-      // convert to stereo, float
-      pcmf32s.resize(2);
-
-      pcmf32s[0].resize(n);
-      pcmf32s[1].resize(n);
-      for (uint64_t i = 0; i < n; i++) {
-        pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
-        pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
-      }
-    }
-  }
-  {
-    static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
-
-    rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
-      bool is_aborted = *(bool*)user_data;
-      return !is_aborted;
-    };
-    rwp->params.encoder_begin_callback_user_data = &is_aborted;
-  }
-
-  register_callbacks(rwp, &self);
-
-  if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
-    fprintf(stderr, "failed to process audio\n");
-    return self;
-  }
-  const int n_segments = whisper_full_n_segments(rw->context);
-  VALUE output = rb_str_new2("");
-  for (int i = 0; i < n_segments; ++i) {
-    const char * text = whisper_full_get_segment_text(rw->context, i);
-    output = rb_str_concat(output, rb_str_new2(text));
-  }
-  VALUE idCall = id_call;
-  if (blk != Qnil) {
-    rb_funcall(blk, idCall, 1, output);
-  }
-  return self;
-}
-
-/*
- * call-seq:
- *   model_n_vocab -> Integer
- */
-VALUE ruby_whisper_model_n_vocab(VALUE self) {
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_vocab(rw->context));
-}
-
-/*
- * call-seq:
- *   model_n_audio_ctx -> Integer
- */
-VALUE ruby_whisper_model_n_audio_ctx(VALUE self) {
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_audio_ctx(rw->context));
-}
-
-/*
- * call-seq:
- *   model_n_audio_state -> Integer
- */
-VALUE ruby_whisper_model_n_audio_state(VALUE self) {
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_audio_state(rw->context));
-}
-
-/*
- * call-seq:
- *   model_n_audio_head -> Integer
- */
-VALUE ruby_whisper_model_n_audio_head(VALUE self) {
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_audio_head(rw->context));
-}
-
-/*
- * call-seq:
- *   model_n_audio_layer -> Integer
- */
-VALUE ruby_whisper_model_n_audio_layer(VALUE self) {
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_audio_layer(rw->context));
-}
-
-/*
- * call-seq:
- *   model_n_text_ctx -> Integer
- */
-VALUE ruby_whisper_model_n_text_ctx(VALUE self) {
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_text_ctx(rw->context));
-}
-
-/*
- * call-seq:
- *   model_n_text_state -> Integer
- */
-VALUE ruby_whisper_model_n_text_state(VALUE self) {
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_text_state(rw->context));
-}
-
-/*
- * call-seq:
- *   model_n_text_head -> Integer
- */
-VALUE ruby_whisper_model_n_text_head(VALUE self) {
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_text_head(rw->context));
-}
-
-/*
- * call-seq:
- *   model_n_text_layer -> Integer
- */
-VALUE ruby_whisper_model_n_text_layer(VALUE self) {
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_text_layer(rw->context));
-}
-
-/*
- * call-seq:
- *   model_n_mels -> Integer
- */
-VALUE ruby_whisper_model_n_mels(VALUE self) {
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_mels(rw->context));
-}
-
-/*
- * call-seq:
- *   model_ftype -> Integer
- */
-VALUE ruby_whisper_model_ftype(VALUE self) {
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  return INT2NUM(whisper_model_ftype(rw->context));
-}
-
-/*
- * call-seq:
- *   model_type -> String
- */
-VALUE ruby_whisper_model_type(VALUE self) {
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  return rb_str_new2(whisper_model_type_readable(rw->context));
-}
-
-/*
- * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
- * Not thread safe for same context
- * Uses the specified decoding strategy to obtain the text.
- *
- * call-seq:
- *   full(params, samples, n_samples) -> nil
- *   full(params, samples) -> nil
- *
- * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
- */
-VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) {
-  if (argc < 2 || argc > 3) {
-    rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
-  }
-
-  ruby_whisper *rw;
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  VALUE params = argv[0];
-  Data_Get_Struct(params, ruby_whisper_params, rwp);
-  VALUE samples = argv[1];
-  int n_samples;
-  rb_memory_view_t view;
-  const bool memory_view_available_p = rb_memory_view_available_p(samples);
-  if (argc == 3) {
-    n_samples = NUM2INT(argv[2]);
-    if (TYPE(samples) == T_ARRAY) {
-      if (RARRAY_LEN(samples) < n_samples) {
-        rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
-      }
-    }
-    // Should check when samples.respond_to?(:length)?
-  } else {
-    if (TYPE(samples) == T_ARRAY) {
-      n_samples = RARRAY_LEN(samples);
-    } else if (memory_view_available_p) {
-      if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
-        view.obj = Qnil;
-        rb_raise(rb_eArgError, "unable to get a memory view");
-      }
-      n_samples = view.byte_size / view.item_size;
-    } else if (rb_respond_to(samples, id_length)) {
-      n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
-    } else {
-      rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
-    }
-  }
-  float * c_samples = (float *)malloc(n_samples * sizeof(float));
-  if (memory_view_available_p)  {
-    c_samples = (float *)view.data;
-  } else {
-    if (TYPE(samples) == T_ARRAY) {
-      for (int i = 0; i < n_samples; i++) {
-        c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
-      }
-    } else {
-      // TODO: use rb_block_call
-      VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
-      for (int i = 0; i < n_samples; i++) {
-        // TODO: check if iter is exhausted and raise ArgumentError appropriately
-        VALUE sample = rb_funcall(iter, id_next, 0);
-        c_samples[i] = RFLOAT_VALUE(sample);
-      }
-    }
-  }
-  register_callbacks(rwp, &self);
-  const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
-  if (0 == result) {
-    return Qnil;
-  } else {
-    rb_exc_raise(rb_funcall(eError, id_new, 1, result));
-  }
-}
-
-/*
- * Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
- * Result is stored in the default state of the context
- * Not thread safe if executed in parallel on the same context.
- * It seems this approach can offer some speedup in some cases.
- * However, the transcription accuracy can be worse at the beginning and end of each chunk.
- *
- * call-seq:
- *   full_parallel(params, samples) -> nil
- *   full_parallel(params, samples, n_samples) -> nil
- *   full_parallel(params, samples, n_samples, n_processors) -> nil
- *   full_parallel(params, samples, nil, n_processors) -> nil
- */
-static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) {
-  if (argc < 2 || argc > 4) {
-    rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
-  }
-
-  ruby_whisper *rw;
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  VALUE params = argv[0];
-  Data_Get_Struct(params, ruby_whisper_params, rwp);
-  VALUE samples = argv[1];
-  int n_samples;
-  int n_processors;
-  rb_memory_view_t view;
-  const bool memory_view_available_p = rb_memory_view_available_p(samples);
-  switch (argc) {
-  case 2:
-    n_processors = 1;
-    break;
-  case 3:
-    n_processors = 1;
-    break;
-  case 4:
-    n_processors = NUM2INT(argv[3]);
-    break;
-  }
-  if (argc >= 3 && !NIL_P(argv[2])) {
-    n_samples = NUM2INT(argv[2]);
-    if (TYPE(samples) == T_ARRAY) {
-      if (RARRAY_LEN(samples) < n_samples) {
-        rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
-      }
-    }
-    // Should check when samples.respond_to?(:length)?
-  } else if (memory_view_available_p) {
-    if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
-      view.obj = Qnil;
-      rb_raise(rb_eArgError, "unable to get a memory view");
-    }
-    n_samples = view.byte_size / view.item_size;
-  } else {
-    if (TYPE(samples) == T_ARRAY) {
-      n_samples = RARRAY_LEN(samples);
-    } else if (rb_respond_to(samples, id_length)) {
-      n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
-    } else {
-      rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
-    }
-  }
-  float * c_samples = (float *)malloc(n_samples * sizeof(float));
-  if (memory_view_available_p) {
-    c_samples = (float *)view.data;
-  } else {
-    if (TYPE(samples) == T_ARRAY) {
-      for (int i = 0; i < n_samples; i++) {
-        c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
-      }
-    } else {
-      // FIXME: use rb_block_call
-      VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
-      for (int i = 0; i < n_samples; i++) {
-        // TODO: check if iter is exhausted and raise ArgumentError
-        VALUE sample = rb_funcall(iter, id_next, 0);
-        c_samples[i] = RFLOAT_VALUE(sample);
-      }
-    }
-  }
-  register_callbacks(rwp, &self);
-  const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
-  if (0 == result) {
-    return Qnil;
-  } else {
-    rb_exc_raise(rb_funcall(eError, id_new, 1, result));
-  }
-}
-
-/*
- * Number of segments.
- *
- * call-seq:
- *   full_n_segments -> Integer
- */
-static VALUE ruby_whisper_full_n_segments(VALUE self) {
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  return INT2NUM(whisper_full_n_segments(rw->context));
-}
-
-/*
- * Language ID, which can be converted to string by Whisper.lang_str and Whisper.lang_str_full.
- *
- * call-seq:
- *   full_lang_id -> Integer
- */
-static VALUE ruby_whisper_full_lang_id(VALUE self) {
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  return INT2NUM(whisper_full_lang_id(rw->context));
-}
-
-static int ruby_whisper_full_check_segment_index(const ruby_whisper * rw, const VALUE i_segment) {
-  const int c_i_segment = NUM2INT(i_segment);
-  if (c_i_segment < 0 || c_i_segment >= whisper_full_n_segments(rw->context)) {
-    rb_raise(rb_eIndexError, "segment index %d out of range", c_i_segment);
-  }
-  return c_i_segment;
-}
-
-/*
- * Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
- *
- *   full_get_segment_t0(3) # => 1668 (16680 ms)
- *
- * call-seq:
- *   full_get_segment_t0(segment_index) -> Integer
- */
-static VALUE ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment) {
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
-  const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment);
-  return INT2NUM(t0);
-}
-
-/*
- * End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
- *
- *   full_get_segment_t1(3) # => 1668 (16680 ms)
- *
- * call-seq:
- *   full_get_segment_t1(segment_index) -> Integer
- */
-static VALUE ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment) {
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
-  const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment);
-  return INT2NUM(t1);
-}
-
-/*
- * Whether the next segment indexed by +segment_index+ is predicated as a speaker turn.
- *
- *   full_get_segment_speacker_turn_next(3) # => true
- *
- * call-seq:
- *   full_get_segment_speacker_turn_next(segment_index) -> bool
- */
-static VALUE ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment) {
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
-  const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment);
-  return speaker_turn_next ? Qtrue : Qfalse;
-}
-
-/*
- * Text of a segment indexed by +segment_index+.
- *
- *   full_get_segment_text(3) # => "ask not what your country can do for you, ..."
- *
- * call-seq:
- *   full_get_segment_text(segment_index) -> String
- */
-static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) {
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
-  const char * text = whisper_full_get_segment_text(rw->context, c_i_segment);
-  return rb_str_new2(text);
-}
-
-/*
- * call-seq:
- *   full_get_segment_no_speech_prob(segment_index) -> Float
- */
-static VALUE ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment) {
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-  const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
-  const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment);
-  return DBL2NUM(no_speech_prob);
-}
-
-/*
- * params.language = "auto" | "en", etc...
- *
- * call-seq:
- *   language = lang_name -> lang_name
- */
-static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  if (value == Qfalse || value == Qnil) {
-    rwp->params.language = "auto";
-  } else {
-    rwp->params.language = StringValueCStr(value);
-  }
-  return value;
-}
-/*
- * call-seq:
- *   language -> String
- */
-static VALUE ruby_whisper_params_get_language(VALUE self) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  if (rwp->params.language) {
-    return rb_str_new2(rwp->params.language);
-  } else {
-    return rb_str_new2("auto");
-  }
-}
-/*
- * call-seq:
- *   translate = do_translate -> do_translate
- */
-static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) {
-  BOOL_PARAMS_SETTER(self, translate, value)
-}
-/*
- * call-seq:
- *   translate -> bool
- */
-static VALUE ruby_whisper_params_get_translate(VALUE self) {
-  BOOL_PARAMS_GETTER(self, translate)
-}
-/*
- * call-seq:
- *   no_context = dont_use_context -> dont_use_context
- */
-static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) {
-  BOOL_PARAMS_SETTER(self, no_context, value)
-}
-/*
- * If true, does not use past transcription (if any) as initial prompt for the decoder.
- *
- * call-seq:
- *   no_context -> bool
- */
-static VALUE ruby_whisper_params_get_no_context(VALUE self) {
-  BOOL_PARAMS_GETTER(self, no_context)
-}
-/*
- * call-seq:
- *   single_segment = force_single -> force_single
- */
-static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) {
-  BOOL_PARAMS_SETTER(self, single_segment, value)
-}
-/*
- * If true, forces single segment output (useful for streaming).
- *
- * call-seq:
- *   single_segment -> bool
- */
-static VALUE ruby_whisper_params_get_single_segment(VALUE self) {
-  BOOL_PARAMS_GETTER(self, single_segment)
-}
-/*
- * call-seq:
- *   print_special = force_print -> force_print
- */
-static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) {
-  BOOL_PARAMS_SETTER(self, print_special, value)
-}
-/*
- * If true, prints special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.).
- *
- * call-seq:
- *   print_special -> bool
- */
-static VALUE ruby_whisper_params_get_print_special(VALUE self) {
-  BOOL_PARAMS_GETTER(self, print_special)
-}
-/*
- * call-seq:
- *   print_progress = force_print -> force_print
- */
-static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) {
-  BOOL_PARAMS_SETTER(self, print_progress, value)
-}
-/*
- * If true, prints progress information.
- *
- * call-seq:
- *   print_progress -> bool
- */
-static VALUE ruby_whisper_params_get_print_progress(VALUE self) {
-  BOOL_PARAMS_GETTER(self, print_progress)
-}
-/*
- * call-seq:
- *   print_realtime = force_print -> force_print
- */
-static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) {
-  BOOL_PARAMS_SETTER(self, print_realtime, value)
-}
-/*
- * If true, prints results from within whisper.cpp. (avoid it, use callback instead)
- * call-seq:
- *   print_realtime -> bool
- */
-static VALUE ruby_whisper_params_get_print_realtime(VALUE self) {
-  BOOL_PARAMS_GETTER(self, print_realtime)
-}
-/*
- * call-seq:
- *   print_timestamps = force_print -> force_print
- */
-static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) {
-  BOOL_PARAMS_SETTER(self, print_timestamps, value)
-}
-/*
- * If true, prints timestamps for each text segment when printing realtime.
- *
- * call-seq:
- *   print_timestamps -> bool
- */
-static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) {
-  BOOL_PARAMS_GETTER(self, print_timestamps)
-}
-/*
- * call-seq:
- *   suppress_blank = force_suppress -> force_suppress
- */
-static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) {
-  BOOL_PARAMS_SETTER(self, suppress_blank, value)
-}
-/*
- * If true, suppresses blank outputs.
- *
- * call-seq:
- *   suppress_blank -> bool
- */
-static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) {
-  BOOL_PARAMS_GETTER(self, suppress_blank)
-}
-/*
- * call-seq:
- *   suppress_nst = force_suppress -> force_suppress
- */
-static VALUE ruby_whisper_params_set_suppress_nst(VALUE self, VALUE value) {
-  BOOL_PARAMS_SETTER(self, suppress_nst, value)
-}
-/*
- * If true, suppresses non-speech-tokens.
- *
- * call-seq:
- *   suppress_nst -> bool
- */
-static VALUE ruby_whisper_params_get_suppress_nst(VALUE self) {
-  BOOL_PARAMS_GETTER(self, suppress_nst)
-}
-/*
- * If true, enables token-level timestamps.
- *
- * call-seq:
- *   token_timestamps -> bool
- */
-static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) {
-  BOOL_PARAMS_GETTER(self, token_timestamps)
-}
-/*
- * call-seq:
- *   token_timestamps = force_timestamps -> force_timestamps
- */
-static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) {
-  BOOL_PARAMS_SETTER(self, token_timestamps, value)
-}
-/*
- * If true, split on word rather than on token (when used with max_len).
- *
- * call-seq:
- *   translate -> bool
- */
-static VALUE ruby_whisper_params_get_split_on_word(VALUE self) {
-  BOOL_PARAMS_GETTER(self, split_on_word)
-}
-/*
- * call-seq:
- *   split_on_word = force_split -> force_split
- */
-static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) {
-  BOOL_PARAMS_SETTER(self, split_on_word, value)
-}
-/*
- * Tokens to provide to the whisper decoder as initial prompt
- * these are prepended to any existing text context from a previous call
- * use whisper_tokenize() to convert text to tokens.
- * Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224).
- *
- * call-seq:
- *   initial_prompt -> String
- */
-static VALUE ruby_whisper_params_get_initial_prompt(VALUE self) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  return rwp->params.initial_prompt == nullptr ? Qnil : rb_str_new2(rwp->params.initial_prompt);
-}
-/*
- * call-seq:
- *   initial_prompt = prompt -> prompt
- */
-static VALUE ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  rwp->params.initial_prompt = StringValueCStr(value);
-  return value;
-}
-/*
- * If true, enables diarization.
- *
- * call-seq:
- *   diarize -> bool
- */
-static VALUE ruby_whisper_params_get_diarize(VALUE self) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  if (rwp->diarize) {
-    return Qtrue;
-  } else {
-    return Qfalse;
-  }
-}
-/*
- * call-seq:
- *   diarize = force_diarize -> force_diarize
- */
-static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  if (value == Qfalse || value == Qnil) {
-    rwp->diarize = false;
-  } else {
-    rwp->diarize = true;
-  } \
-  return value;
-}
-
-/*
- * Start offset in ms.
- *
- * call-seq:
- *   offset -> Integer
- */
-static VALUE ruby_whisper_params_get_offset(VALUE self) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  return INT2NUM(rwp->params.offset_ms);
-}
-/*
- * call-seq:
- *   offset = offset_ms -> offset_ms
- */
-static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  rwp->params.offset_ms = NUM2INT(value);
-  return value;
-}
-/*
- * Audio duration to process in ms.
- *
- * call-seq:
- *   duration -> Integer
- */
-static VALUE ruby_whisper_params_get_duration(VALUE self) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  return INT2NUM(rwp->params.duration_ms);
-}
-/*
- * call-seq:
- *   duration = duration_ms -> duration_ms
- */
-static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  rwp->params.duration_ms = NUM2INT(value);
-  return value;
-}
-
-/*
- * Max tokens to use from past text as prompt for the decoder.
- *
- * call-seq:
- *   max_text_tokens -> Integer
- */
-static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  return INT2NUM(rwp->params.n_max_text_ctx);
-}
-/*
- * call-seq:
- *   max_text_tokens = n_tokens -> n_tokens
- */
-static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  rwp->params.n_max_text_ctx = NUM2INT(value);
-  return value;
-}
-/*
- * call-seq:
- *   temperature -> Float
- */
-static VALUE ruby_whisper_params_get_temperature(VALUE self) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  return DBL2NUM(rwp->params.temperature);
-}
-/*
- * call-seq:
- *   temperature = temp -> temp
- */
-static VALUE ruby_whisper_params_set_temperature(VALUE self, VALUE value) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  rwp->params.temperature = RFLOAT_VALUE(value);
-  return value;
-}
-/*
- * See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
- *
- * call-seq:
- *   max_initial_ts -> Flaot
- */
-static VALUE ruby_whisper_params_get_max_initial_ts(VALUE self) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  return DBL2NUM(rwp->params.max_initial_ts);
-}
-/*
- * call-seq:
- *   max_initial_ts = timestamp -> timestamp
- */
-static VALUE ruby_whisper_params_set_max_initial_ts(VALUE self, VALUE value) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  rwp->params.max_initial_ts = RFLOAT_VALUE(value);
-  return value;
-}
-/*
- * call-seq:
- *   length_penalty -> Float
- */
-static VALUE ruby_whisper_params_get_length_penalty(VALUE self) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  return DBL2NUM(rwp->params.length_penalty);
-}
-/*
- * call-seq:
- *   length_penalty = penalty -> penalty
- */
-static VALUE ruby_whisper_params_set_length_penalty(VALUE self, VALUE value) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  rwp->params.length_penalty = RFLOAT_VALUE(value);
-  return value;
-}
-/*
- * call-seq:
- *   temperature_inc -> Float
- */
-static VALUE ruby_whisper_params_get_temperature_inc(VALUE self) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  return DBL2NUM(rwp->params.temperature_inc);
-}
-/*
- * call-seq:
- *   temperature_inc = inc -> inc
- */
-static VALUE ruby_whisper_params_set_temperature_inc(VALUE self, VALUE value) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  rwp->params.temperature_inc = RFLOAT_VALUE(value);
-  return value;
-}
-/*
- * Similar to OpenAI's "compression_ratio_threshold"
- *
- * call-seq:
- *   entropy_thold -> Float
- */
-static VALUE ruby_whisper_params_get_entropy_thold(VALUE self) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  return DBL2NUM(rwp->params.entropy_thold);
-}
-/*
- * call-seq:
- *   entropy_thold = threshold -> threshold
- */
-static VALUE ruby_whisper_params_set_entropy_thold(VALUE self, VALUE value) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  rwp->params.entropy_thold = RFLOAT_VALUE(value);
-  return value;
-}
-/*
- * call-seq:
- *   logprob_thold -> Float
- */
-static VALUE ruby_whisper_params_get_logprob_thold(VALUE self) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  return DBL2NUM(rwp->params.logprob_thold);
-}
-/*
- * call-seq:
- *   logprob_thold = threshold -> threshold
- */
-static VALUE ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  rwp->params.logprob_thold = RFLOAT_VALUE(value);
-  return value;
-}
-/*
- * call-seq:
- *   no_speech_thold -> Float
- */
-static VALUE ruby_whisper_params_get_no_speech_thold(VALUE self) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  return DBL2NUM(rwp->params.no_speech_thold);
-}
-/*
- * call-seq:
- *   no_speech_thold = threshold -> threshold
- */
-static VALUE ruby_whisper_params_set_no_speech_thold(VALUE self, VALUE value) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  rwp->params.no_speech_thold = RFLOAT_VALUE(value);
-  return value;
-}
-/*
- * Sets new segment callback, called for every newly generated text segment.
- *
- *   params.new_segment_callback = ->(context, _, n_new, user_data) {
- *     # ...
- *   }
- *
- * call-seq:
- *   new_segment_callback = callback -> callback
- */
-static VALUE ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  rwp->new_segment_callback_container->callback = value;
-  return value;
-}
-/*
- * Sets user data passed to the last argument of new segment callback.
- *
- * call-seq:
- *   new_segment_callback_user_data = user_data -> use_data
- */
-static VALUE ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  rwp->new_segment_callback_container->user_data = value;
-  return value;
-}
-/*
- * Sets progress callback, called on each progress update.
- *
- *   params.new_segment_callback = ->(context, _, n_new, user_data) {
- *     # ...
- *   }
- *
- * call-seq:
- *   progress_callback = callback -> callback
- */
-static VALUE ruby_whisper_params_set_progress_callback(VALUE self, VALUE value) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  rwp->progress_callback_container->callback = value;
-  return value;
-}
-/*
- * Sets user data passed to the last argument of progress callback.
- *
- * call-seq:
- *   progress_callback_user_data = user_data -> use_data
- */
-static VALUE ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  rwp->progress_callback_container->user_data = value;
-  return value;
-}
-/*
- * Sets abort callback, called to check if the process should be aborted.
- *
- *   params.abort_callback = ->(user_data) {
- *     # ...
- *   }
- *
- * call-seq:
- *   abort_callback = callback -> callback
- */
-static VALUE ruby_whisper_params_set_abort_callback(VALUE self, VALUE value) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  rwp->abort_callback_container->callback = value;
-  return value;
-}
-/*
- * Sets user data passed to the last argument of abort callback.
- *
- * call-seq:
- *   abort_callback_user_data = user_data -> use_data
- */
-static VALUE ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value) {
-  ruby_whisper_params *rwp;
-  Data_Get_Struct(self, ruby_whisper_params, rwp);
-  rwp->abort_callback_container->user_data = value;
-  return value;
-}
-
-// High level API
-
-typedef struct {
-  VALUE context;
-  int index;
-} ruby_whisper_segment;
-
-typedef struct {
-  VALUE context;
-} ruby_whisper_model;
-
-static void rb_whisper_segment_mark(ruby_whisper_segment *rws) {
-  rb_gc_mark(rws->context);
-}
-
-static VALUE ruby_whisper_segment_allocate(VALUE klass) {
-  ruby_whisper_segment *rws;
-  rws = ALLOC(ruby_whisper_segment);
-  return Data_Wrap_Struct(klass, rb_whisper_segment_mark, RUBY_DEFAULT_FREE, rws);
-}
-
-static VALUE rb_whisper_segment_initialize(VALUE context, int index) {
-  ruby_whisper_segment *rws;
-  const VALUE segment = ruby_whisper_segment_allocate(cSegment);
-  Data_Get_Struct(segment, ruby_whisper_segment, rws);
-  rws->context = context;
-  rws->index = index;
-  return segment;
-};
-
-/*
- * Yields each Whisper::Segment:
- *
- *   whisper.transcribe("path/to/audio.wav", params)
- *   whisper.each_segment do |segment|
- *     puts segment.text
- *   end
- *
- * Returns an Enumerator if no block given:
- *
- *   whisper.transcribe("path/to/audio.wav", params)
- *   enum = whisper.each_segment
- *   enum.to_a # => [#<Whisper::Segment>, ...]
- *
- * call-seq:
- *   each_segment {|segment| ... }
- *   each_segment -> Enumerator
- */
-static VALUE ruby_whisper_each_segment(VALUE self) {
-  if (!rb_block_given_p()) {
-    const VALUE method_name = rb_funcall(self, id___method__, 0);
-    return rb_funcall(self, id_to_enum, 1, method_name);
-  }
-
-  ruby_whisper *rw;
-  Data_Get_Struct(self, ruby_whisper, rw);
-
-  const int n_segments = whisper_full_n_segments(rw->context);
-  for (int i = 0; i < n_segments; ++i) {
-    rb_yield(rb_whisper_segment_initialize(self, i));
-  }
-
-  return self;
-}
-
-/*
- * Hook called on new segment. Yields each Whisper::Segment.
- *
- *   whisper.on_new_segment do |segment|
- *     # ...
- *   end
- *
- * call-seq:
- *   on_new_segment {|segment| ... }
- */
-static VALUE ruby_whisper_params_on_new_segment(VALUE self) {
-  ruby_whisper_params *rws;
-  Data_Get_Struct(self, ruby_whisper_params, rws);
-  const VALUE blk = rb_block_proc();
-  rb_ary_push(rws->new_segment_callback_container->callbacks, blk);
-  return Qnil;
-}
-
-/*
- * Hook called on progress update. Yields each progress Integer between 0 and 100.
- *
- *   whisper.on_progress do |progress|
- *     # ...
- *   end
- *
- * call-seq:
- *   on_progress {|progress| ... }
- */
-static VALUE ruby_whisper_params_on_progress(VALUE self) {
-  ruby_whisper_params *rws;
-  Data_Get_Struct(self, ruby_whisper_params, rws);
-  const VALUE blk = rb_block_proc();
-  rb_ary_push(rws->progress_callback_container->callbacks, blk);
-  return Qnil;
-}
-
-/*
- * Call block to determine whether abort or not. Return +true+ when you want to abort.
- *
- *   params.abort_on do
- *     if some_condition
- *       true # abort
- *     else
- *       false # continue
- *     end
- *   end
- *
- * call-seq:
- *   abort_on { ... }
- */
-static VALUE ruby_whisper_params_abort_on(VALUE self) {
-  ruby_whisper_params *rws;
-  Data_Get_Struct(self, ruby_whisper_params, rws);
-  const VALUE blk = rb_block_proc();
-  rb_ary_push(rws->abort_callback_container->callbacks, blk);
-  return Qnil;
-}
-
-/*
- * Start time in milliseconds.
- *
- * call-seq:
- *   start_time -> Integer
- */
-static VALUE ruby_whisper_segment_get_start_time(VALUE self) {
-  ruby_whisper_segment *rws;
-  Data_Get_Struct(self, ruby_whisper_segment, rws);
-  ruby_whisper *rw;
-  Data_Get_Struct(rws->context, ruby_whisper, rw);
-  const int64_t t0 = whisper_full_get_segment_t0(rw->context, rws->index);
-  // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
-  return INT2NUM(t0 * 10);
-}
-
-/*
- * End time in milliseconds.
- *
- * call-seq:
- *   end_time -> Integer
- */
-static VALUE ruby_whisper_segment_get_end_time(VALUE self) {
-  ruby_whisper_segment *rws;
-  Data_Get_Struct(self, ruby_whisper_segment, rws);
-  ruby_whisper *rw;
-  Data_Get_Struct(rws->context, ruby_whisper, rw);
-  const int64_t t1 = whisper_full_get_segment_t1(rw->context, rws->index);
-  // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
-  return INT2NUM(t1 * 10);
-}
-
-/*
- * Whether the next segment is predicted as a speaker turn.
- *
- * call-seq:
- *   speaker_turn_next? -> bool
- */
-static VALUE ruby_whisper_segment_get_speaker_turn_next(VALUE self) {
-  ruby_whisper_segment *rws;
-  Data_Get_Struct(self, ruby_whisper_segment, rws);
-  ruby_whisper *rw;
-  Data_Get_Struct(rws->context, ruby_whisper, rw);
-  return whisper_full_get_segment_speaker_turn_next(rw->context, rws->index) ? Qtrue : Qfalse;
-}
-
-/*
- * call-seq:
- *   text -> String
- */
-static VALUE ruby_whisper_segment_get_text(VALUE self) {
-  ruby_whisper_segment *rws;
-  Data_Get_Struct(self, ruby_whisper_segment, rws);
-  ruby_whisper *rw;
-  Data_Get_Struct(rws->context, ruby_whisper, rw);
-  const char * text = whisper_full_get_segment_text(rw->context, rws->index);
-  return rb_str_new2(text);
-}
-
-/*
- * call-seq:
- *   no_speech_prob -> Float
- */
-static VALUE ruby_whisper_segment_get_no_speech_prob(VALUE self) {
-  ruby_whisper_segment *rws;
-  Data_Get_Struct(self, ruby_whisper_segment, rws);
-  ruby_whisper *rw;
-  Data_Get_Struct(rws->context, ruby_whisper, rw);
-  return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index));
-}
-
-static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
-  rb_gc_mark(rwm->context);
-}
-
-static VALUE ruby_whisper_model_allocate(VALUE klass) {
-  ruby_whisper_model *rwm;
-  rwm = ALLOC(ruby_whisper_model);
-  return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
-}
-
-static VALUE rb_whisper_model_initialize(VALUE context) {
-  ruby_whisper_model *rwm;
-  const VALUE model = ruby_whisper_model_allocate(cModel);
-  Data_Get_Struct(model, ruby_whisper_model, rwm);
-  rwm->context = context;
-  return model;
-};
-
-/*
- * call-seq:
- *   model -> Whisper::Model
- */
-static VALUE ruby_whisper_get_model(VALUE self) {
-  return rb_whisper_model_initialize(self);
-}
-
-/*
- * call-seq:
- *   n_vocab -> Integer
- */
-static VALUE ruby_whisper_c_model_n_vocab(VALUE self) {
-  ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
-  ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_vocab(rw->context));
-}
-
-/*
- * call-seq:
- *   n_audio_ctx -> Integer
- */
-static VALUE ruby_whisper_c_model_n_audio_ctx(VALUE self) {
-  ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
-  ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_audio_ctx(rw->context));
-}
-
-/*
- * call-seq:
- *   n_audio_state -> Integer
- */
-static VALUE ruby_whisper_c_model_n_audio_state(VALUE self) {
-  ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
-  ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_audio_state(rw->context));
-}
-
-/*
- * call-seq:
- *   n_audio_head -> Integer
- */
-static VALUE ruby_whisper_c_model_n_audio_head(VALUE self) {
-  ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
-  ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_audio_head(rw->context));
-}
-
-/*
- * call-seq:
- *   n_audio_layer -> Integer
- */
-static VALUE ruby_whisper_c_model_n_audio_layer(VALUE self) {
-  ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
-  ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_audio_layer(rw->context));
-}
-
-/*
- * call-seq:
- *   n_text_ctx -> Integer
- */
-static VALUE ruby_whisper_c_model_n_text_ctx(VALUE self) {
-  ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
-  ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_text_ctx(rw->context));
-}
-
-/*
- * call-seq:
- *   n_text_state -> Integer
- */
-static VALUE ruby_whisper_c_model_n_text_state(VALUE self) {
-  ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
-  ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_text_state(rw->context));
-}
-
-/*
- * call-seq:
- *   n_text_head -> Integer
- */
-static VALUE ruby_whisper_c_model_n_text_head(VALUE self) {
-  ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
-  ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_text_head(rw->context));
-}
-
-/*
- * call-seq:
- *   n_text_layer -> Integer
- */
-static VALUE ruby_whisper_c_model_n_text_layer(VALUE self) {
-  ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
-  ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_text_layer(rw->context));
-}
-
-/*
- * call-seq:
- *   n_mels -> Integer
- */
-static VALUE ruby_whisper_c_model_n_mels(VALUE self) {
-  ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
-  ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
-  return INT2NUM(whisper_model_n_mels(rw->context));
-}
-
-/*
- * call-seq:
- *   ftype -> Integer
- */
-static VALUE ruby_whisper_c_model_ftype(VALUE self) {
-  ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
-  ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
-  return INT2NUM(whisper_model_ftype(rw->context));
-}
-
-/*
- * call-seq:
- *   type -> String
- */
-static VALUE ruby_whisper_c_model_type(VALUE self) {
-  ruby_whisper_model *rwm;
-  Data_Get_Struct(self, ruby_whisper_model, rwm);
-  ruby_whisper *rw;
-  Data_Get_Struct(rwm->context, ruby_whisper, rw);
-  return rb_str_new2(whisper_model_type_readable(rw->context));
-}
-
-static VALUE ruby_whisper_error_initialize(VALUE self, VALUE code) {
-  const int c_code = NUM2INT(code);
-  const char *raw_message;
-  switch (c_code) {
-  case -2:
-    raw_message = "failed to compute log mel spectrogram";
-    break;
-  case -3:
-    raw_message = "failed to auto-detect language";
-    break;
-  case -4:
-    raw_message = "too many decoders requested";
-    break;
-  case -5:
-    raw_message = "audio_ctx is larger than the maximum allowed";
-    break;
-  case -6:
-    raw_message = "failed to encode";
-    break;
-  case -7:
-    raw_message = "whisper_kv_cache_init() failed for self-attention cache";
-    break;
-  case -8:
-    raw_message = "failed to decode";
-    break;
-  case -9:
-    raw_message = "failed to decode";
-    break;
-  default:
-    raw_message = "unknown error";
-    break;
-  }
-  const VALUE message = rb_str_new2(raw_message);
-  rb_call_super(1, &message);
-  rb_iv_set(self, "@code", code);
-
-  return self;
-}
-
-
-void Init_whisper() {
-  id_to_s = rb_intern("to_s");
-  id_call = rb_intern("call");
-  id___method__ = rb_intern("__method__");
-  id_to_enum = rb_intern("to_enum");
-  id_length = rb_intern("length");
-  id_next = rb_intern("next");
-  id_new = rb_intern("new");
-  id_to_path = rb_intern("to_path");
-  id_URI = rb_intern("URI");
-  id_pre_converted_models = rb_intern("pre_converted_models");
-
-  mWhisper = rb_define_module("Whisper");
-  cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
-  cParams  = rb_define_class_under(mWhisper, "Params", rb_cObject);
-  eError   = rb_define_class_under(mWhisper, "Error", rb_eStandardError);
-
-  rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
-  rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO));
-  rb_define_const(mWhisper, "LOG_LEVEL_WARN", INT2NUM(GGML_LOG_LEVEL_WARN));
-  rb_define_const(mWhisper, "LOG_LEVEL_ERROR", INT2NUM(GGML_LOG_LEVEL_ERROR));
-  rb_define_const(mWhisper, "LOG_LEVEL_DEBUG", INT2NUM(GGML_LOG_LEVEL_DEBUG));
-  rb_define_const(mWhisper, "LOG_LEVEL_CONT", INT2NUM(GGML_LOG_LEVEL_CONT));
-
-  rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0);
-  rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1);
-  rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
-  rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1);
-  rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2);
-  rb_define_singleton_method(mWhisper, "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1);
-
-  rb_define_alloc_func(cContext, ruby_whisper_allocate);
-  rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
-
-  rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1);
-  rb_define_method(cContext, "model_n_vocab", ruby_whisper_model_n_vocab, 0);
-  rb_define_method(cContext, "model_n_audio_ctx", ruby_whisper_model_n_audio_ctx, 0);
-  rb_define_method(cContext, "model_n_audio_state", ruby_whisper_model_n_audio_state, 0);
-  rb_define_method(cContext, "model_n_audio_head", ruby_whisper_model_n_audio_head, 0);
-  rb_define_method(cContext, "model_n_audio_layer", ruby_whisper_model_n_audio_layer, 0);
-  rb_define_method(cContext, "model_n_text_ctx", ruby_whisper_model_n_text_ctx, 0);
-  rb_define_method(cContext, "model_n_text_state", ruby_whisper_model_n_text_state, 0);
-  rb_define_method(cContext, "model_n_text_head", ruby_whisper_model_n_text_head, 0);
-  rb_define_method(cContext, "model_n_text_layer", ruby_whisper_model_n_text_layer, 0);
-  rb_define_method(cContext, "model_n_mels", ruby_whisper_model_n_mels, 0);
-  rb_define_method(cContext, "model_ftype", ruby_whisper_model_ftype, 0);
-  rb_define_method(cContext, "model_type", ruby_whisper_model_type, 0);
-  rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0);
-  rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0);
-  rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1);
-  rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
-  rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
-  rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
-  rb_define_method(cContext, "full_get_segment_no_speech_prob", ruby_whisper_full_get_segment_no_speech_prob, 1);
-  rb_define_method(cContext, "full", ruby_whisper_full, -1);
-  rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
-
-  rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
-
-  rb_define_method(cParams, "language=", ruby_whisper_params_set_language, 1);
-  rb_define_method(cParams, "language", ruby_whisper_params_get_language, 0);
-  rb_define_method(cParams, "translate=", ruby_whisper_params_set_translate, 1);
-  rb_define_method(cParams, "translate", ruby_whisper_params_get_translate, 0);
-  rb_define_method(cParams, "no_context=", ruby_whisper_params_set_no_context, 1);
-  rb_define_method(cParams, "no_context", ruby_whisper_params_get_no_context, 0);
-  rb_define_method(cParams, "single_segment=", ruby_whisper_params_set_single_segment, 1);
-  rb_define_method(cParams, "single_segment", ruby_whisper_params_get_single_segment, 0);
-  rb_define_method(cParams, "print_special", ruby_whisper_params_get_print_special, 0);
-  rb_define_method(cParams, "print_special=", ruby_whisper_params_set_print_special, 1);
-  rb_define_method(cParams, "print_progress", ruby_whisper_params_get_print_progress, 0);
-  rb_define_method(cParams, "print_progress=", ruby_whisper_params_set_print_progress, 1);
-  rb_define_method(cParams, "print_realtime", ruby_whisper_params_get_print_realtime, 0);
-  rb_define_method(cParams, "print_realtime=", ruby_whisper_params_set_print_realtime, 1);
-  rb_define_method(cParams, "print_timestamps", ruby_whisper_params_get_print_timestamps, 0);
-  rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1);
-  rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0);
-  rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1);
-  rb_define_method(cParams, "suppress_nst", ruby_whisper_params_get_suppress_nst, 0);
-  rb_define_method(cParams, "suppress_nst=", ruby_whisper_params_set_suppress_nst, 1);
-  rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0);
-  rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
-  rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
-  rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1);
-  rb_define_method(cParams, "initial_prompt", ruby_whisper_params_get_initial_prompt, 0);
-  rb_define_method(cParams, "initial_prompt=", ruby_whisper_params_set_initial_prompt, 1);
-  rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0);
-  rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1);
-
-  rb_define_method(cParams, "offset", ruby_whisper_params_get_offset, 0);
-  rb_define_method(cParams, "offset=", ruby_whisper_params_set_offset, 1);
-  rb_define_method(cParams, "duration", ruby_whisper_params_get_duration, 0);
-  rb_define_method(cParams, "duration=", ruby_whisper_params_set_duration, 1);
-
-  rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0);
-  rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1);
-  rb_define_method(cParams, "temperature", ruby_whisper_params_get_temperature, 0);
-  rb_define_method(cParams, "temperature=", ruby_whisper_params_set_temperature, 1);
-  rb_define_method(cParams, "max_initial_ts", ruby_whisper_params_get_max_initial_ts, 0);
-  rb_define_method(cParams, "max_initial_ts=", ruby_whisper_params_set_max_initial_ts, 1);
-  rb_define_method(cParams, "length_penalty", ruby_whisper_params_get_length_penalty, 0);
-  rb_define_method(cParams, "length_penalty=", ruby_whisper_params_set_length_penalty, 1);
-  rb_define_method(cParams, "temperature_inc", ruby_whisper_params_get_temperature_inc, 0);
-  rb_define_method(cParams, "temperature_inc=", ruby_whisper_params_set_temperature_inc, 1);
-  rb_define_method(cParams, "entropy_thold", ruby_whisper_params_get_entropy_thold, 0);
-  rb_define_method(cParams, "entropy_thold=", ruby_whisper_params_set_entropy_thold, 1);
-  rb_define_method(cParams, "logprob_thold", ruby_whisper_params_get_logprob_thold, 0);
-  rb_define_method(cParams, "logprob_thold=", ruby_whisper_params_set_logprob_thold, 1);
-  rb_define_method(cParams, "no_speech_thold", ruby_whisper_params_get_no_speech_thold, 0);
-  rb_define_method(cParams, "no_speech_thold=", ruby_whisper_params_set_no_speech_thold, 1);
-
-  rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1);
-  rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1);
-  rb_define_method(cParams, "progress_callback=", ruby_whisper_params_set_progress_callback, 1);
-  rb_define_method(cParams, "progress_callback_user_data=", ruby_whisper_params_set_progress_callback_user_data, 1);
-  rb_define_method(cParams, "abort_callback=", ruby_whisper_params_set_abort_callback, 1);
-  rb_define_method(cParams, "abort_callback_user_data=", ruby_whisper_params_set_abort_callback_user_data, 1);
-
-  rb_define_attr(eError, "code", true, false);
-  rb_define_method(eError, "initialize", ruby_whisper_error_initialize, 1);
-
-  // High leve
-  cSegment  = rb_define_class_under(mWhisper, "Segment", rb_cObject);
-
-  rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate);
-  rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);
-  rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
-  rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
-  rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
-  rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0);
-  rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
-  rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
-  rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0);
-  rb_define_method(cSegment, "no_speech_prob", ruby_whisper_segment_get_no_speech_prob, 0);
-
-  cModel = rb_define_class_under(mWhisper, "Model", rb_cObject);
-  rb_define_alloc_func(cModel, ruby_whisper_model_allocate);
-  rb_define_method(cContext, "model", ruby_whisper_get_model, 0);
-  rb_define_method(cModel, "n_vocab", ruby_whisper_c_model_n_vocab, 0);
-  rb_define_method(cModel, "n_audio_ctx", ruby_whisper_c_model_n_audio_ctx, 0);
-  rb_define_method(cModel, "n_audio_state", ruby_whisper_c_model_n_audio_state, 0);
-  rb_define_method(cModel, "n_audio_head", ruby_whisper_c_model_n_audio_head, 0);
-  rb_define_method(cModel, "n_audio_layer", ruby_whisper_c_model_n_audio_layer, 0);
-  rb_define_method(cModel, "n_text_ctx", ruby_whisper_c_model_n_text_ctx, 0);
-  rb_define_method(cModel, "n_text_state", ruby_whisper_c_model_n_text_state, 0);
-  rb_define_method(cModel, "n_text_head", ruby_whisper_c_model_n_text_head, 0);
-  rb_define_method(cModel, "n_text_layer", ruby_whisper_c_model_n_text_layer, 0);
-  rb_define_method(cModel, "n_mels", ruby_whisper_c_model_n_mels, 0);
-  rb_define_method(cModel, "ftype", ruby_whisper_c_model_ftype, 0);
-  rb_define_method(cModel, "type", ruby_whisper_c_model_type, 0);
-
-  rb_require("whisper/model/uri");
-}
-#ifdef __cplusplus
-}
-#endif
index 21e36c491cf4128c2fa54fe062ecff8b4a8115ce..bbf3435e52c4629ca8eecb471699210ccd3d9ea2 100644 (file)
@@ -22,4 +22,13 @@ typedef struct {
   ruby_whisper_callback_container *abort_callback_container;
 } ruby_whisper_params;
 
+typedef struct {
+  VALUE context;
+  int index;
+} ruby_whisper_segment;
+
+typedef struct {
+  VALUE context;
+} ruby_whisper_model;
+
 #endif
diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c
new file mode 100644 (file)
index 0000000..df37521
--- /dev/null
@@ -0,0 +1,613 @@
+#include <ruby.h>
+#include <ruby/memory_view.h>
+#include "ruby_whisper.h"
+
+extern ID id_to_s;
+extern ID id___method__;
+extern ID id_to_enum;
+extern ID id_length;
+extern ID id_next;
+extern ID id_new;
+extern ID id_to_path;
+extern ID id_URI;
+extern ID id_pre_converted_models;
+
+extern VALUE cContext;
+extern VALUE eError;
+extern VALUE cModel;
+
+extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self);
+extern VALUE rb_whisper_model_initialize(VALUE context);
+extern VALUE rb_whisper_segment_initialize(VALUE context, int index);
+extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);
+
+static void
+ruby_whisper_free(ruby_whisper *rw)
+{
+  if (rw->context) {
+    whisper_free(rw->context);
+    rw->context = NULL;
+  }
+}
+
+void
+rb_whisper_mark(ruby_whisper *rw)
+{
+  // call rb_gc_mark on any ruby references in rw
+}
+
+void
+rb_whisper_free(ruby_whisper *rw)
+{
+  ruby_whisper_free(rw);
+  free(rw);
+}
+
+static VALUE
+ruby_whisper_allocate(VALUE klass)
+{
+  ruby_whisper *rw;
+  rw = ALLOC(ruby_whisper);
+  rw->context = NULL;
+  return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
+}
+
+/*
+ * call-seq:
+ *   new("base.en") -> Whisper::Context
+ *   new("path/to/model.bin") -> Whisper::Context
+ *   new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context
+ */
+static VALUE
+ruby_whisper_initialize(int argc, VALUE *argv, VALUE self)
+{
+  ruby_whisper *rw;
+  VALUE whisper_model_file_path;
+
+  // TODO: we can support init from buffer here too maybe another ruby object to expose
+  rb_scan_args(argc, argv, "01", &whisper_model_file_path);
+  Data_Get_Struct(self, ruby_whisper, rw);
+
+  VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
+  VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path);
+  if (!NIL_P(pre_converted_model)) {
+    whisper_model_file_path = pre_converted_model;
+  }
+  if (TYPE(whisper_model_file_path) == T_STRING) {
+    const char * whisper_model_file_path_str = StringValueCStr(whisper_model_file_path);
+    if (strncmp("http://", whisper_model_file_path_str, 7) == 0 || strncmp("https://", whisper_model_file_path_str, 8) == 0) {
+      VALUE uri_class = rb_const_get(cModel, id_URI);
+      whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
+    }
+  }
+  if (rb_obj_is_kind_of(whisper_model_file_path, rb_path2class("URI::HTTP"))) {
+    VALUE uri_class = rb_const_get(cModel, id_URI);
+    whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
+  }
+  if (rb_respond_to(whisper_model_file_path, id_to_path)) {
+    whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0);
+  }
+  if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
+    rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
+  }
+  rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
+  if (rw->context == NULL) {
+    rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
+  }
+  return self;
+}
+
+/*
+ * call-seq:
+ *   model_n_vocab -> Integer
+ */
+VALUE ruby_whisper_model_n_vocab(VALUE self)
+{
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_vocab(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_n_audio_ctx -> Integer
+ */
+VALUE ruby_whisper_model_n_audio_ctx(VALUE self)
+{
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_audio_ctx(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_n_audio_state -> Integer
+ */
+VALUE ruby_whisper_model_n_audio_state(VALUE self)
+{
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_audio_state(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_n_audio_head -> Integer
+ */
+VALUE ruby_whisper_model_n_audio_head(VALUE self)
+{
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_audio_head(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_n_audio_layer -> Integer
+ */
+VALUE ruby_whisper_model_n_audio_layer(VALUE self)
+{
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_audio_layer(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_n_text_ctx -> Integer
+ */
+VALUE ruby_whisper_model_n_text_ctx(VALUE self)
+{
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_text_ctx(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_n_text_state -> Integer
+ */
+VALUE ruby_whisper_model_n_text_state(VALUE self)
+{
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_text_state(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_n_text_head -> Integer
+ */
+VALUE ruby_whisper_model_n_text_head(VALUE self)
+{
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_text_head(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_n_text_layer -> Integer
+ */
+VALUE ruby_whisper_model_n_text_layer(VALUE self)
+{
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_text_layer(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_n_mels -> Integer
+ */
+VALUE ruby_whisper_model_n_mels(VALUE self)
+{
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_mels(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_ftype -> Integer
+ */
+VALUE ruby_whisper_model_ftype(VALUE self)
+{
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_ftype(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_type -> String
+ */
+VALUE ruby_whisper_model_type(VALUE self)
+{
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return rb_str_new2(whisper_model_type_readable(rw->context));
+}
+
+/*
+ * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
+ * Not thread safe for same context
+ * Uses the specified decoding strategy to obtain the text.
+ *
+ * call-seq:
+ *   full(params, samples, n_samples) -> nil
+ *   full(params, samples) -> nil
+ *
+ * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
+ */
+VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
+{
+  if (argc < 2 || argc > 3) {
+    rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
+  }
+
+  ruby_whisper *rw;
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  VALUE params = argv[0];
+  Data_Get_Struct(params, ruby_whisper_params, rwp);
+  VALUE samples = argv[1];
+  int n_samples;
+  rb_memory_view_t view;
+  const bool memory_view_available_p = rb_memory_view_available_p(samples);
+  if (argc == 3) {
+    n_samples = NUM2INT(argv[2]);
+    if (TYPE(samples) == T_ARRAY) {
+      if (RARRAY_LEN(samples) < n_samples) {
+        rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
+      }
+    }
+    // Should check when samples.respond_to?(:length)?
+  } else {
+    if (TYPE(samples) == T_ARRAY) {
+      n_samples = RARRAY_LEN(samples);
+    } else if (memory_view_available_p) {
+      if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
+        view.obj = Qnil;
+        rb_raise(rb_eArgError, "unable to get a memory view");
+      }
+      n_samples = view.byte_size / view.item_size;
+    } else if (rb_respond_to(samples, id_length)) {
+      n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
+    } else {
+      rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
+    }
+  }
+  float * c_samples = (float *)malloc(n_samples * sizeof(float));
+  if (memory_view_available_p)  {
+    c_samples = (float *)view.data;
+  } else {
+    if (TYPE(samples) == T_ARRAY) {
+      for (int i = 0; i < n_samples; i++) {
+        c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
+      }
+    } else {
+      // TODO: use rb_block_call
+      VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
+      for (int i = 0; i < n_samples; i++) {
+        // TODO: check if iter is exhausted and raise ArgumentError appropriately
+        VALUE sample = rb_funcall(iter, id_next, 0);
+        c_samples[i] = RFLOAT_VALUE(sample);
+      }
+    }
+  }
+  register_callbacks(rwp, &self);
+  const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
+  if (0 == result) {
+    return self;
+  } else {
+    rb_exc_raise(rb_funcall(eError, id_new, 1, result));
+  }
+}
+
+/*
+ * Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
+ * Result is stored in the default state of the context
+ * Not thread safe if executed in parallel on the same context.
+ * It seems this approach can offer some speedup in some cases.
+ * However, the transcription accuracy can be worse at the beginning and end of each chunk.
+ *
+ * call-seq:
+ *   full_parallel(params, samples) -> nil
+ *   full_parallel(params, samples, n_samples) -> nil
+ *   full_parallel(params, samples, n_samples, n_processors) -> nil
+ *   full_parallel(params, samples, nil, n_processors) -> nil
+ */
+static VALUE
+ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
+{
+  if (argc < 2 || argc > 4) {
+    rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
+  }
+
+  ruby_whisper *rw;
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  VALUE params = argv[0];
+  Data_Get_Struct(params, ruby_whisper_params, rwp);
+  VALUE samples = argv[1];
+  int n_samples;
+  int n_processors;
+  rb_memory_view_t view;
+  const bool memory_view_available_p = rb_memory_view_available_p(samples);
+  switch (argc) {
+  case 2:
+    n_processors = 1;
+    break;
+  case 3:
+    n_processors = 1;
+    break;
+  case 4:
+    n_processors = NUM2INT(argv[3]);
+    break;
+  }
+  if (argc >= 3 && !NIL_P(argv[2])) {
+    n_samples = NUM2INT(argv[2]);
+    if (TYPE(samples) == T_ARRAY) {
+      if (RARRAY_LEN(samples) < n_samples) {
+        rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
+      }
+    }
+    // Should check when samples.respond_to?(:length)?
+  } else if (memory_view_available_p) {
+    if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
+      view.obj = Qnil;
+      rb_raise(rb_eArgError, "unable to get a memory view");
+    }
+    n_samples = view.byte_size / view.item_size;
+  } else {
+    if (TYPE(samples) == T_ARRAY) {
+      n_samples = RARRAY_LEN(samples);
+    } else if (rb_respond_to(samples, id_length)) {
+      n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
+    } else {
+      rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
+    }
+  }
+  float * c_samples = (float *)malloc(n_samples * sizeof(float));
+  if (memory_view_available_p) {
+    c_samples = (float *)view.data;
+  } else {
+    if (TYPE(samples) == T_ARRAY) {
+      for (int i = 0; i < n_samples; i++) {
+        c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
+      }
+    } else {
+      // FIXME: use rb_block_call
+      VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
+      for (int i = 0; i < n_samples; i++) {
+        // TODO: check if iter is exhausted and raise ArgumentError
+        VALUE sample = rb_funcall(iter, id_next, 0);
+        c_samples[i] = RFLOAT_VALUE(sample);
+      }
+    }
+  }
+  register_callbacks(rwp, &self);
+  const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
+  if (0 == result) {
+    return self;
+  } else {
+    rb_exc_raise(rb_funcall(eError, id_new, 1, result));
+  }
+}
+
+/*
+ * Number of segments.
+ *
+ * call-seq:
+ *   full_n_segments -> Integer
+ */
+static VALUE
+ruby_whisper_full_n_segments(VALUE self)
+{
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_full_n_segments(rw->context));
+}
+
+/*
+ * Language ID, which can be converted to string by Whisper.lang_str and Whisper.lang_str_full.
+ *
+ * call-seq:
+ *   full_lang_id -> Integer
+ */
+static VALUE
+ruby_whisper_full_lang_id(VALUE self)
+{
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_full_lang_id(rw->context));
+}
+
+static int ruby_whisper_full_check_segment_index(const ruby_whisper * rw, const VALUE i_segment)
+{
+  const int c_i_segment = NUM2INT(i_segment);
+  if (c_i_segment < 0 || c_i_segment >= whisper_full_n_segments(rw->context)) {
+    rb_raise(rb_eIndexError, "segment index %d out of range", c_i_segment);
+  }
+  return c_i_segment;
+}
+
+/*
+ * Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
+ *
+ *   full_get_segment_t0(3) # => 1668 (16680 ms)
+ *
+ * call-seq:
+ *   full_get_segment_t0(segment_index) -> Integer
+ */
+static VALUE
+ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment)
+{
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
+  const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment);
+  return INT2NUM(t0);
+}
+
+/*
+ * End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
+ *
+ *   full_get_segment_t1(3) # => 1668 (16680 ms)
+ *
+ * call-seq:
+ *   full_get_segment_t1(segment_index) -> Integer
+ */
+static VALUE
+ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment)
+{
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
+  const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment);
+  return INT2NUM(t1);
+}
+
+/*
+ * Whether the next segment indexed by +segment_index+ is predicated as a speaker turn.
+ *
+ *   full_get_segment_speacker_turn_next(3) # => true
+ *
+ * call-seq:
+ *   full_get_segment_speacker_turn_next(segment_index) -> bool
+ */
+static VALUE
+ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment)
+{
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
+  const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment);
+  return speaker_turn_next ? Qtrue : Qfalse;
+}
+
+/*
+ * Text of a segment indexed by +segment_index+.
+ *
+ *   full_get_segment_text(3) # => "ask not what your country can do for you, ..."
+ *
+ * call-seq:
+ *   full_get_segment_text(segment_index) -> String
+ */
+static VALUE
+ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment)
+{
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
+  const char * text = whisper_full_get_segment_text(rw->context, c_i_segment);
+  return rb_str_new2(text);
+}
+
+/*
+ * call-seq:
+ *   full_get_segment_no_speech_prob(segment_index) -> Float
+ */
+static VALUE
+ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment)
+{
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
+  const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment);
+  return DBL2NUM(no_speech_prob);
+}
+
+// High level API
+
+static VALUE
+ruby_whisper_full_get_segment(VALUE self, VALUE i_segment)
+{
+  return rb_whisper_segment_initialize(self, NUM2INT(i_segment));
+}
+
+/*
+ * Yields each Whisper::Segment:
+ *
+ *   whisper.transcribe("path/to/audio.wav", params)
+ *   whisper.each_segment do |segment|
+ *     puts segment.text
+ *   end
+ *
+ * Returns an Enumerator if no block given:
+ *
+ *   whisper.transcribe("path/to/audio.wav", params)
+ *   enum = whisper.each_segment
+ *   enum.to_a # => [#<Whisper::Segment>, ...]
+ *
+ * call-seq:
+ *   each_segment {|segment| ... }
+ *   each_segment -> Enumerator
+ */
+static VALUE
+ruby_whisper_each_segment(VALUE self)
+{
+  if (!rb_block_given_p()) {
+    const VALUE method_name = rb_funcall(self, id___method__, 0);
+    return rb_funcall(self, id_to_enum, 1, method_name);
+  }
+
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+
+  const int n_segments = whisper_full_n_segments(rw->context);
+  for (int i = 0; i < n_segments; ++i) {
+    rb_yield(rb_whisper_segment_initialize(self, i));
+  }
+
+  return self;
+}
+
+/*
+ * call-seq:
+ *   model -> Whisper::Model
+ */
+static VALUE
+ruby_whisper_get_model(VALUE self)
+{
+  return rb_whisper_model_initialize(self);
+}
+
+void
+init_ruby_whisper_context(VALUE *mWhisper)
+{
+  cContext = rb_define_class_under(*mWhisper, "Context", rb_cObject);
+
+  rb_define_alloc_func(cContext, ruby_whisper_allocate);
+  rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
+
+  rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1);
+  rb_define_method(cContext, "model_n_vocab", ruby_whisper_model_n_vocab, 0);
+  rb_define_method(cContext, "model_n_audio_ctx", ruby_whisper_model_n_audio_ctx, 0);
+  rb_define_method(cContext, "model_n_audio_state", ruby_whisper_model_n_audio_state, 0);
+  rb_define_method(cContext, "model_n_audio_head", ruby_whisper_model_n_audio_head, 0);
+  rb_define_method(cContext, "model_n_audio_layer", ruby_whisper_model_n_audio_layer, 0);
+  rb_define_method(cContext, "model_n_text_ctx", ruby_whisper_model_n_text_ctx, 0);
+  rb_define_method(cContext, "model_n_text_state", ruby_whisper_model_n_text_state, 0);
+  rb_define_method(cContext, "model_n_text_head", ruby_whisper_model_n_text_head, 0);
+  rb_define_method(cContext, "model_n_text_layer", ruby_whisper_model_n_text_layer, 0);
+  rb_define_method(cContext, "model_n_mels", ruby_whisper_model_n_mels, 0);
+  rb_define_method(cContext, "model_ftype", ruby_whisper_model_ftype, 0);
+  rb_define_method(cContext, "model_type", ruby_whisper_model_type, 0);
+  rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0);
+  rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0);
+  rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1);
+  rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
+  rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
+  rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
+  rb_define_method(cContext, "full_get_segment_no_speech_prob", ruby_whisper_full_get_segment_no_speech_prob, 1);
+  rb_define_method(cContext, "full", ruby_whisper_full, -1);
+  rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
+
+  // High leve
+  rb_define_method(cContext, "full_get_segment", ruby_whisper_full_get_segment, 1);
+  rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);
+
+  rb_define_method(cContext, "model", ruby_whisper_get_model, 0);
+}
diff --git a/bindings/ruby/ext/ruby_whisper_error.c b/bindings/ruby/ext/ruby_whisper_error.c
new file mode 100644 (file)
index 0000000..b4dbec0
--- /dev/null
@@ -0,0 +1,52 @@
+#include <ruby.h>
+
+extern VALUE eError;
+
+VALUE ruby_whisper_error_initialize(VALUE self, VALUE code)
+{
+  const int c_code = NUM2INT(code);
+  const char *raw_message;
+  switch (c_code) {
+  case -2:
+    raw_message = "failed to compute log mel spectrogram";
+    break;
+  case -3:
+    raw_message = "failed to auto-detect language";
+    break;
+  case -4:
+    raw_message = "too many decoders requested";
+    break;
+  case -5:
+    raw_message = "audio_ctx is larger than the maximum allowed";
+    break;
+  case -6:
+    raw_message = "failed to encode";
+    break;
+  case -7:
+    raw_message = "whisper_kv_cache_init() failed for self-attention cache";
+    break;
+  case -8:
+    raw_message = "failed to decode";
+    break;
+  case -9:
+    raw_message = "failed to decode";
+    break;
+  default:
+    raw_message = "unknown error";
+    break;
+  }
+  const VALUE message = rb_str_new2(raw_message);
+  rb_call_super(1, &message);
+  rb_iv_set(self, "@code", code);
+
+  return self;
+}
+
+void
+init_ruby_whisper_error(VALUE *mWhisper)
+{
+  eError = rb_define_class_under(*mWhisper, "Error", rb_eStandardError);
+
+  rb_define_attr(eError, "code", true, false);
+  rb_define_method(eError, "initialize", ruby_whisper_error_initialize, 1);
+}
diff --git a/bindings/ruby/ext/ruby_whisper_model.c b/bindings/ruby/ext/ruby_whisper_model.c
new file mode 100644 (file)
index 0000000..1e0648f
--- /dev/null
@@ -0,0 +1,210 @@
+#include <ruby.h>
+#include "ruby_whisper.h"
+
+extern VALUE cModel;
+
+static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
+  rb_gc_mark(rwm->context);
+}
+
+static VALUE ruby_whisper_model_allocate(VALUE klass) {
+  ruby_whisper_model *rwm;
+  rwm = ALLOC(ruby_whisper_model);
+  return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
+}
+
+VALUE rb_whisper_model_initialize(VALUE context) {
+  ruby_whisper_model *rwm;
+  const VALUE model = ruby_whisper_model_allocate(cModel);
+  Data_Get_Struct(model, ruby_whisper_model, rwm);
+  rwm->context = context;
+  return model;
+};
+
+/*
+ * call-seq:
+ *   n_vocab -> Integer
+ */
+static VALUE
+ruby_whisper_model_n_vocab(VALUE self)
+{
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_vocab(rw->context));
+}
+
+/*
+ * call-seq:
+ *   n_audio_ctx -> Integer
+ */
+static VALUE
+ruby_whisper_model_n_audio_ctx(VALUE self)
+{
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_audio_ctx(rw->context));
+}
+
+/*
+ * call-seq:
+ *   n_audio_state -> Integer
+ */
+static VALUE
+ruby_whisper_model_n_audio_state(VALUE self)
+{
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_audio_state(rw->context));
+}
+
+/*
+ * call-seq:
+ *   n_audio_head -> Integer
+ */
+static VALUE
+ruby_whisper_model_n_audio_head(VALUE self)
+{
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_audio_head(rw->context));
+}
+
+/*
+ * call-seq:
+ *   n_audio_layer -> Integer
+ */
+static VALUE
+ruby_whisper_model_n_audio_layer(VALUE self)
+{
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_audio_layer(rw->context));
+}
+
+/*
+ * call-seq:
+ *   n_text_ctx -> Integer
+ */
+static VALUE
+ruby_whisper_model_n_text_ctx(VALUE self)
+{
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_text_ctx(rw->context));
+}
+
+/*
+ * call-seq:
+ *   n_text_state -> Integer
+ */
+static VALUE
+ruby_whisper_model_n_text_state(VALUE self)
+{
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_text_state(rw->context));
+}
+
+/*
+ * call-seq:
+ *   n_text_head -> Integer
+ */
+static VALUE
+ruby_whisper_model_n_text_head(VALUE self)
+{
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_text_head(rw->context));
+}
+
+/*
+ * call-seq:
+ *   n_text_layer -> Integer
+ */
+static VALUE
+ruby_whisper_model_n_text_layer(VALUE self)
+{
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_text_layer(rw->context));
+}
+
+/*
+ * call-seq:
+ *   n_mels -> Integer
+ */
+static VALUE
+ruby_whisper_model_n_mels(VALUE self)
+{
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_mels(rw->context));
+}
+
+/*
+ * call-seq:
+ *   ftype -> Integer
+ */
+static VALUE
+ruby_whisper_model_ftype(VALUE self)
+{
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_ftype(rw->context));
+}
+
+/*
+ * call-seq:
+ *   type -> String
+ */
+static VALUE
+ruby_whisper_model_type(VALUE self)
+{
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return rb_str_new2(whisper_model_type_readable(rw->context));
+}
+
+void
+init_ruby_whisper_model(VALUE *mWhisper)
+{
+  cModel = rb_define_class_under(*mWhisper, "Model", rb_cObject);
+
+  rb_define_alloc_func(cModel, ruby_whisper_model_allocate);
+  rb_define_method(cModel, "n_vocab", ruby_whisper_model_n_vocab, 0);
+  rb_define_method(cModel, "n_audio_ctx", ruby_whisper_model_n_audio_ctx, 0);
+  rb_define_method(cModel, "n_audio_state", ruby_whisper_model_n_audio_state, 0);
+  rb_define_method(cModel, "n_audio_head", ruby_whisper_model_n_audio_head, 0);
+  rb_define_method(cModel, "n_audio_layer", ruby_whisper_model_n_audio_layer, 0);
+  rb_define_method(cModel, "n_text_ctx", ruby_whisper_model_n_text_ctx, 0);
+  rb_define_method(cModel, "n_text_state", ruby_whisper_model_n_text_state, 0);
+  rb_define_method(cModel, "n_text_head", ruby_whisper_model_n_text_head, 0);
+  rb_define_method(cModel, "n_text_layer", ruby_whisper_model_n_text_layer, 0);
+  rb_define_method(cModel, "n_mels", ruby_whisper_model_n_mels, 0);
+  rb_define_method(cModel, "ftype", ruby_whisper_model_ftype, 0);
+  rb_define_method(cModel, "type", ruby_whisper_model_type, 0);
+}
diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c
new file mode 100644 (file)
index 0000000..0446db3
--- /dev/null
@@ -0,0 +1,1077 @@
+#include <ruby.h>
+#include "ruby_whisper.h"
+
+#define BOOL_PARAMS_SETTER(self, prop, value) \
+  ruby_whisper_params *rwp; \
+  Data_Get_Struct(self, ruby_whisper_params, rwp); \
+  if (value == Qfalse || value == Qnil) { \
+    rwp->params.prop = false; \
+  } else { \
+    rwp->params.prop = true; \
+  } \
+  return value; \
+
+#define BOOL_PARAMS_GETTER(self,  prop) \
+  ruby_whisper_params *rwp; \
+  Data_Get_Struct(self, ruby_whisper_params, rwp); \
+  if (rwp->params.prop) { \
+    return Qtrue; \
+  } else { \
+    return Qfalse; \
+  }
+
+#define DEFINE_PARAM(param_name, nth) \
+  id_ ## param_name = rb_intern(#param_name); \
+  param_names[nth] = id_ ## param_name; \
+  rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
+  rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
+
+#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 30
+
+extern VALUE cParams;
+
+extern ID id_call;
+
+extern VALUE rb_whisper_segment_initialize(VALUE context, int index);
+
+static ID param_names[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT];
+static ID id_language;
+static ID id_translate;
+static ID id_no_context;
+static ID id_single_segment;
+static ID id_print_special;
+static ID id_print_progress;
+static ID id_print_realtime;
+static ID id_print_timestamps;
+static ID id_suppress_blank;
+static ID id_suppress_nst;
+static ID id_token_timestamps;
+static ID id_split_on_word;
+static ID id_initial_prompt;
+static ID id_diarize;
+static ID id_offset;
+static ID id_duration;
+static ID id_max_text_tokens;
+static ID id_temperature;
+static ID id_max_initial_ts;
+static ID id_length_penalty;
+static ID id_temperature_inc;
+static ID id_entropy_thold;
+static ID id_logprob_thold;
+static ID id_no_speech_thold;
+static ID id_new_segment_callback;
+static ID id_new_segment_callback_user_data;
+static ID id_progress_callback;
+static ID id_progress_callback_user_data;
+static ID id_abort_callback;
+static ID id_abort_callback_user_data;
+
+static void
+rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc)
+{
+  rb_gc_mark(rwc->user_data);
+  rb_gc_mark(rwc->callback);
+  rb_gc_mark(rwc->callbacks);
+}
+
+static ruby_whisper_callback_container*
+rb_whisper_callback_container_allocate() {
+  ruby_whisper_callback_container *container;
+  container = ALLOC(ruby_whisper_callback_container);
+  container->context = NULL;
+  container->user_data = Qnil;
+  container->callback = Qnil;
+  container->callbacks = rb_ary_new();
+  return container;
+}
+
+static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) {
+  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
+
+  // Currently, doesn't support state because
+  // those require to resolve GC-related problems.
+  if (!NIL_P(container->callback)) {
+    rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
+  }
+  const long callbacks_len = RARRAY_LEN(container->callbacks);
+  if (0 == callbacks_len) {
+    return;
+  }
+  const int n_segments = whisper_full_n_segments_from_state(state);
+  for (int i = n_new; i > 0; i--) {
+    int i_segment = n_segments - i;
+    VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
+    for (int j = 0; j < callbacks_len; j++) {
+      VALUE cb = rb_ary_entry(container->callbacks, j);
+      rb_funcall(cb, id_call, 1, segment);
+    }
+  }
+}
+
+static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) {
+  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
+  const VALUE progress = INT2NUM(progress_cur);
+  // Currently, doesn't support state because
+  // those require to resolve GC-related problems.
+  if (!NIL_P(container->callback)) {
+    rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
+  }
+  const long callbacks_len = RARRAY_LEN(container->callbacks);
+  if (0 == callbacks_len) {
+    return;
+  }
+  for (int j = 0; j < callbacks_len; j++) {
+    VALUE cb = rb_ary_entry(container->callbacks, j);
+    rb_funcall(cb, id_call, 1, progress);
+  }
+}
+
+static bool abort_callback(void * user_data) {
+  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
+  if (!NIL_P(container->callback)) {
+    VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
+    if (!NIL_P(result) && Qfalse != result) {
+      return true;
+    }
+  }
+  const long callbacks_len = RARRAY_LEN(container->callbacks);
+  if (0 == callbacks_len) {
+    return false;
+  }
+  for (int j = 0; j < callbacks_len; j++) {
+    VALUE cb = rb_ary_entry(container->callbacks, j);
+    VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
+    if (!NIL_P(result) && Qfalse != result) {
+      return true;
+    }
+  }
+  return false;
+}
+
+void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
+  if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
+    rwp->new_segment_callback_container->context = context;
+    rwp->params.new_segment_callback = new_segment_callback;
+    rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
+  }
+
+  if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
+    rwp->progress_callback_container->context = context;
+    rwp->params.progress_callback = progress_callback;
+    rwp->params.progress_callback_user_data = rwp->progress_callback_container;
+  }
+
+  if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
+    rwp->abort_callback_container->context = context;
+    rwp->params.abort_callback = abort_callback;
+    rwp->params.abort_callback_user_data = rwp->abort_callback_container;
+  }
+}
+
+void
+rb_whisper_params_mark(ruby_whisper_params *rwp)
+{
+  rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
+  rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
+  rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
+}
+
+void
+ruby_whisper_params_free(ruby_whisper_params *rwp)
+{
+}
+
+void
+rb_whisper_params_free(ruby_whisper_params *rwp)
+{
+  // How to free user_data and callback only when not referred to by others?
+  ruby_whisper_params_free(rwp);
+  free(rwp);
+}
+
+static VALUE
+ruby_whisper_params_allocate(VALUE klass)
+{
+  ruby_whisper_params *rwp;
+  rwp = ALLOC(ruby_whisper_params);
+  rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+  rwp->diarize = false;
+  rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
+  rwp->progress_callback_container = rb_whisper_callback_container_allocate();
+  rwp->abort_callback_container = rb_whisper_callback_container_allocate();
+  return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
+}
+
+/*
+ * params.language = "auto" | "en", etc...
+ *
+ * call-seq:
+ *   language = lang_name -> lang_name
+ */
+static VALUE
+ruby_whisper_params_set_language(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  if (value == Qfalse || value == Qnil) {
+    rwp->params.language = "auto";
+  } else {
+    rwp->params.language = StringValueCStr(value);
+  }
+  return value;
+}
+/*
+ * call-seq:
+ *   language -> String
+ */
+static VALUE
+ruby_whisper_params_get_language(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  if (rwp->params.language) {
+    return rb_str_new2(rwp->params.language);
+  } else {
+    return rb_str_new2("auto");
+  }
+}
+/*
+ * call-seq:
+ *   translate = do_translate -> do_translate
+ */
+static VALUE
+ruby_whisper_params_set_translate(VALUE self, VALUE value)
+{
+  BOOL_PARAMS_SETTER(self, translate, value)
+}
+/*
+ * call-seq:
+ *   translate -> bool
+ */
+static VALUE
+ruby_whisper_params_get_translate(VALUE self)
+{
+  BOOL_PARAMS_GETTER(self, translate)
+}
+/*
+ * call-seq:
+ *   no_context = dont_use_context -> dont_use_context
+ */
+static VALUE
+ruby_whisper_params_set_no_context(VALUE self, VALUE value)
+{
+  BOOL_PARAMS_SETTER(self, no_context, value)
+}
+/*
+ * If true, does not use past transcription (if any) as initial prompt for the decoder.
+ *
+ * call-seq:
+ *   no_context -> bool
+ */
+static VALUE
+ruby_whisper_params_get_no_context(VALUE self)
+{
+  BOOL_PARAMS_GETTER(self, no_context)
+}
+/*
+ * call-seq:
+ *   single_segment = force_single -> force_single
+ */
+static VALUE
+ruby_whisper_params_set_single_segment(VALUE self, VALUE value)
+{
+  BOOL_PARAMS_SETTER(self, single_segment, value)
+}
+/*
+ * If true, forces single segment output (useful for streaming).
+ *
+ * call-seq:
+ *   single_segment -> bool
+ */
+static VALUE
+ruby_whisper_params_get_single_segment(VALUE self)
+{
+  BOOL_PARAMS_GETTER(self, single_segment)
+}
+/*
+ * call-seq:
+ *   print_special = force_print -> force_print
+ */
+static VALUE
+ruby_whisper_params_set_print_special(VALUE self, VALUE value)
+{
+  BOOL_PARAMS_SETTER(self, print_special, value)
+}
+/*
+ * If true, prints special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.).
+ *
+ * call-seq:
+ *   print_special -> bool
+ */
+static VALUE
+ruby_whisper_params_get_print_special(VALUE self)
+{
+  BOOL_PARAMS_GETTER(self, print_special)
+}
+/*
+ * call-seq:
+ *   print_progress = force_print -> force_print
+ */
+static VALUE
+ruby_whisper_params_set_print_progress(VALUE self, VALUE value)
+{
+  BOOL_PARAMS_SETTER(self, print_progress, value)
+}
+/*
+ * If true, prints progress information.
+ *
+ * call-seq:
+ *   print_progress -> bool
+ */
+static VALUE
+ruby_whisper_params_get_print_progress(VALUE self)
+{
+  BOOL_PARAMS_GETTER(self, print_progress)
+}
+/*
+ * call-seq:
+ *   print_realtime = force_print -> force_print
+ */
+static VALUE
+ruby_whisper_params_set_print_realtime(VALUE self, VALUE value)
+{
+  BOOL_PARAMS_SETTER(self, print_realtime, value)
+}
+/*
+ * If true, prints results from within whisper.cpp. (avoid it, use callback instead)
+ * call-seq:
+ *   print_realtime -> bool
+ */
+static VALUE
+ruby_whisper_params_get_print_realtime(VALUE self)
+{
+  BOOL_PARAMS_GETTER(self, print_realtime)
+}
+/*
+ * call-seq:
+ *   print_timestamps = force_print -> force_print
+ */
+static VALUE
+ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value)
+{
+  BOOL_PARAMS_SETTER(self, print_timestamps, value)
+}
+/*
+ * If true, prints timestamps for each text segment when printing realtime.
+ *
+ * call-seq:
+ *   print_timestamps -> bool
+ */
+static VALUE
+ruby_whisper_params_get_print_timestamps(VALUE self)
+{
+  BOOL_PARAMS_GETTER(self, print_timestamps)
+}
+/*
+ * call-seq:
+ *   suppress_blank = force_suppress -> force_suppress
+ */
+static VALUE
+ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value)
+{
+  BOOL_PARAMS_SETTER(self, suppress_blank, value)
+}
+/*
+ * If true, suppresses blank outputs.
+ *
+ * call-seq:
+ *   suppress_blank -> bool
+ */
+static VALUE
+ruby_whisper_params_get_suppress_blank(VALUE self)
+{
+  BOOL_PARAMS_GETTER(self, suppress_blank)
+}
+/*
+ * call-seq:
+ *   suppress_nst = force_suppress -> force_suppress
+ */
+static VALUE
+ruby_whisper_params_set_suppress_nst(VALUE self, VALUE value)
+{
+  BOOL_PARAMS_SETTER(self, suppress_nst, value)
+}
+/*
+ * If true, suppresses non-speech-tokens.
+ *
+ * call-seq:
+ *   suppress_nst -> bool
+ */
+static VALUE
+ruby_whisper_params_get_suppress_nst(VALUE self)
+{
+  BOOL_PARAMS_GETTER(self, suppress_nst)
+}
+/*
+ * If true, enables token-level timestamps.
+ *
+ * call-seq:
+ *   token_timestamps -> bool
+ */
+static VALUE
+ruby_whisper_params_get_token_timestamps(VALUE self)
+{
+  BOOL_PARAMS_GETTER(self, token_timestamps)
+}
+/*
+ * call-seq:
+ *   token_timestamps = force_timestamps -> force_timestamps
+ */
+static VALUE
+ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value)
+{
+  BOOL_PARAMS_SETTER(self, token_timestamps, value)
+}
+/*
+ * If true, split on word rather than on token (when used with max_len).
+ *
+ * call-seq:
+ *   translate -> bool
+ */
+static VALUE
+ruby_whisper_params_get_split_on_word(VALUE self)
+{
+  BOOL_PARAMS_GETTER(self, split_on_word)
+}
+/*
+ * call-seq:
+ *   split_on_word = force_split -> force_split
+ */
+static VALUE
+ruby_whisper_params_set_split_on_word(VALUE self, VALUE value)
+{
+  BOOL_PARAMS_SETTER(self, split_on_word, value)
+}
+/*
+ * Tokens to provide to the whisper decoder as initial prompt
+ * these are prepended to any existing text context from a previous call
+ * use whisper_tokenize() to convert text to tokens.
+ * Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224).
+ *
+ * call-seq:
+ *   initial_prompt -> String
+ */
+static VALUE
+ruby_whisper_params_get_initial_prompt(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  return rwp->params.initial_prompt == NULL ? Qnil : rb_str_new2(rwp->params.initial_prompt);
+}
+/*
+ * call-seq:
+ *   initial_prompt = prompt -> prompt
+ */
+static VALUE
+ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  rwp->params.initial_prompt = StringValueCStr(value);
+  return value;
+}
+/*
+ * If true, enables diarization.
+ *
+ * call-seq:
+ *   diarize -> bool
+ */
+static VALUE
+ruby_whisper_params_get_diarize(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  if (rwp->diarize) {
+    return Qtrue;
+  } else {
+    return Qfalse;
+  }
+}
+/*
+ * call-seq:
+ *   diarize = force_diarize -> force_diarize
+ */
+static VALUE
+ruby_whisper_params_set_diarize(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  if (value == Qfalse || value == Qnil) {
+    rwp->diarize = false;
+  } else {
+    rwp->diarize = true;
+  } \
+  return value;
+}
+
+/*
+ * Start offset in ms.
+ *
+ * call-seq:
+ *   offset -> Integer
+ */
+static VALUE
+ruby_whisper_params_get_offset(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  return INT2NUM(rwp->params.offset_ms);
+}
+/*
+ * call-seq:
+ *   offset = offset_ms -> offset_ms
+ */
+static VALUE
+ruby_whisper_params_set_offset(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  rwp->params.offset_ms = NUM2INT(value);
+  return value;
+}
+/*
+ * Audio duration to process in ms.
+ *
+ * call-seq:
+ *   duration -> Integer
+ */
+static VALUE
+ruby_whisper_params_get_duration(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  return INT2NUM(rwp->params.duration_ms);
+}
+/*
+ * call-seq:
+ *   duration = duration_ms -> duration_ms
+ */
+static VALUE
+ruby_whisper_params_set_duration(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  rwp->params.duration_ms = NUM2INT(value);
+  return value;
+}
+
+/*
+ * Max tokens to use from past text as prompt for the decoder.
+ *
+ * call-seq:
+ *   max_text_tokens -> Integer
+ */
+static VALUE
+ruby_whisper_params_get_max_text_tokens(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  return INT2NUM(rwp->params.n_max_text_ctx);
+}
+/*
+ * call-seq:
+ *   max_text_tokens = n_tokens -> n_tokens
+ */
+static VALUE
+ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  rwp->params.n_max_text_ctx = NUM2INT(value);
+  return value;
+}
+/*
+ * call-seq:
+ *   temperature -> Float
+ */
+static VALUE
+ruby_whisper_params_get_temperature(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  return DBL2NUM(rwp->params.temperature);
+}
+/*
+ * call-seq:
+ *   temperature = temp -> temp
+ */
+static VALUE
+ruby_whisper_params_set_temperature(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  rwp->params.temperature = RFLOAT_VALUE(value);
+  return value;
+}
+/*
+ * See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
+ *
+ * call-seq:
+ *   max_initial_ts -> Flaot
+ */
+static VALUE
+ruby_whisper_params_get_max_initial_ts(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  return DBL2NUM(rwp->params.max_initial_ts);
+}
+/*
+ * call-seq:
+ *   max_initial_ts = timestamp -> timestamp
+ */
+static VALUE
+ruby_whisper_params_set_max_initial_ts(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  rwp->params.max_initial_ts = RFLOAT_VALUE(value);
+  return value;
+}
+/*
+ * call-seq:
+ *   length_penalty -> Float
+ */
+static VALUE
+ruby_whisper_params_get_length_penalty(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  return DBL2NUM(rwp->params.length_penalty);
+}
+/*
+ * call-seq:
+ *   length_penalty = penalty -> penalty
+ */
+static VALUE
+ruby_whisper_params_set_length_penalty(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  rwp->params.length_penalty = RFLOAT_VALUE(value);
+  return value;
+}
+/*
+ * call-seq:
+ *   temperature_inc -> Float
+ */
+static VALUE
+ruby_whisper_params_get_temperature_inc(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  return DBL2NUM(rwp->params.temperature_inc);
+}
+/*
+ * call-seq:
+ *   temperature_inc = inc -> inc
+ */
+static VALUE
+ruby_whisper_params_set_temperature_inc(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  rwp->params.temperature_inc = RFLOAT_VALUE(value);
+  return value;
+}
+/*
+ * Similar to OpenAI's "compression_ratio_threshold"
+ *
+ * call-seq:
+ *   entropy_thold -> Float
+ */
+static VALUE
+ruby_whisper_params_get_entropy_thold(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  return DBL2NUM(rwp->params.entropy_thold);
+}
+/*
+ * call-seq:
+ *   entropy_thold = threshold -> threshold
+ */
+static VALUE
+ruby_whisper_params_set_entropy_thold(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  rwp->params.entropy_thold = RFLOAT_VALUE(value);
+  return value;
+}
+/*
+ * call-seq:
+ *   logprob_thold -> Float
+ */
+static VALUE
+ruby_whisper_params_get_logprob_thold(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  return DBL2NUM(rwp->params.logprob_thold);
+}
+/*
+ * call-seq:
+ *   logprob_thold = threshold -> threshold
+ */
+static VALUE
+ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  rwp->params.logprob_thold = RFLOAT_VALUE(value);
+  return value;
+}
+/*
+ * call-seq:
+ *   no_speech_thold -> Float
+ */
+static VALUE
+ruby_whisper_params_get_no_speech_thold(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  return DBL2NUM(rwp->params.no_speech_thold);
+}
+/*
+ * call-seq:
+ *   no_speech_thold = threshold -> threshold
+ */
+static VALUE
+ruby_whisper_params_set_no_speech_thold(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  rwp->params.no_speech_thold = RFLOAT_VALUE(value);
+  return value;
+}
+static VALUE
+ruby_whisper_params_get_new_segment_callback(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  return rwp->new_segment_callback_container->callback;
+}
+/*
+ * Sets new segment callback, called for every newly generated text segment.
+ *
+ *   params.new_segment_callback = ->(context, _, n_new, user_data) {
+ *     # ...
+ *   }
+ *
+ * call-seq:
+ *   new_segment_callback = callback -> callback
+ */
+static VALUE
+ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  rwp->new_segment_callback_container->callback = value;
+  return value;
+}
+static VALUE
+ruby_whisper_params_get_new_segment_callback_user_data(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  return rwp->new_segment_callback_container->user_data;
+}
+/*
+ * Sets user data passed to the last argument of new segment callback.
+ *
+ * call-seq:
+ *   new_segment_callback_user_data = user_data -> use_data
+ */
+static VALUE
+ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  rwp->new_segment_callback_container->user_data = value;
+  return value;
+}
+static VALUE
+ruby_whisper_params_get_progress_callback(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  return rwp->progress_callback_container->callback;
+}
+/*
+ * Sets progress callback, called on each progress update.
+ *
+ *   params.new_segment_callback = ->(context, _, progress, user_data) {
+ *     # ...
+ *   }
+ *
+ * +progress+ is an Integer between 0 and 100.
+ *
+ * call-seq:
+ *   progress_callback = callback -> callback
+ */
+static VALUE
+ruby_whisper_params_set_progress_callback(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  rwp->progress_callback_container->callback = value;
+  return value;
+}
+static VALUE
+ruby_whisper_params_get_progress_callback_user_data(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  return rwp->progress_callback_container->user_data;
+}
+/*
+ * Sets user data passed to the last argument of progress callback.
+ *
+ * call-seq:
+ *   progress_callback_user_data = user_data -> use_data
+ */
+static VALUE
+ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  rwp->progress_callback_container->user_data = value;
+  return value;
+}
+static VALUE
+ruby_whisper_params_get_abort_callback(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  return rwp->abort_callback_container->callback;
+}
+/*
+ * Sets abort callback, called to check if the process should be aborted.
+ *
+ *   params.abort_callback = ->(user_data) {
+ *     # ...
+ *   }
+ *
+ * call-seq:
+ *   abort_callback = callback -> callback
+ */
+static VALUE
+ruby_whisper_params_set_abort_callback(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  rwp->abort_callback_container->callback = value;
+  return value;
+}
+static VALUE
+ruby_whisper_params_get_abort_callback_user_data(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  return rwp->abort_callback_container->user_data;
+}
+/*
+ * Sets user data passed to the last argument of abort callback.
+ *
+ * call-seq:
+ *   abort_callback_user_data = user_data -> use_data
+ */
+static VALUE
+ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  rwp->abort_callback_container->user_data = value;
+  return value;
+}
+
+#define SET_PARAM_IF_SAME(param_name) \
+  if (id == id_ ## param_name) { \
+    ruby_whisper_params_set_ ## param_name(self, value); \
+    continue; \
+  }
+
+static VALUE
+ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
+{
+
+  VALUE kw_hash;
+  VALUE values[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT] = {Qundef};
+  VALUE value;
+  ruby_whisper_params *rwp;
+  ID id;
+  int i;
+
+  rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
+  if (NIL_P(kw_hash)) {
+    return self;
+  }
+
+  rb_get_kwargs(kw_hash, &param_names, 0, RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT, &values);
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+
+  for (i = 0; i < RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT; i++) {
+    id = param_names[i];
+    value = values[i];
+    if (value == Qundef) {
+      continue;
+    }
+    if (id == id_diarize) {
+      rwp->diarize = value;
+      continue;
+    } else {
+      SET_PARAM_IF_SAME(language)
+      SET_PARAM_IF_SAME(translate)
+      SET_PARAM_IF_SAME(no_context)
+      SET_PARAM_IF_SAME(single_segment)
+      SET_PARAM_IF_SAME(print_special)
+      SET_PARAM_IF_SAME(print_progress)
+      SET_PARAM_IF_SAME(print_realtime)
+      SET_PARAM_IF_SAME(print_timestamps)
+      SET_PARAM_IF_SAME(suppress_blank)
+      SET_PARAM_IF_SAME(suppress_nst)
+      SET_PARAM_IF_SAME(token_timestamps)
+      SET_PARAM_IF_SAME(split_on_word)
+      SET_PARAM_IF_SAME(initial_prompt)
+      SET_PARAM_IF_SAME(offset)
+      SET_PARAM_IF_SAME(duration)
+      SET_PARAM_IF_SAME(max_text_tokens)
+      SET_PARAM_IF_SAME(temperature)
+      SET_PARAM_IF_SAME(max_initial_ts)
+      SET_PARAM_IF_SAME(length_penalty)
+      SET_PARAM_IF_SAME(temperature_inc)
+      SET_PARAM_IF_SAME(entropy_thold)
+      SET_PARAM_IF_SAME(logprob_thold)
+      SET_PARAM_IF_SAME(no_speech_thold)
+      SET_PARAM_IF_SAME(new_segment_callback)
+      SET_PARAM_IF_SAME(new_segment_callback_user_data)
+      SET_PARAM_IF_SAME(progress_callback)
+      SET_PARAM_IF_SAME(progress_callback_user_data)
+      SET_PARAM_IF_SAME(abort_callback)
+      SET_PARAM_IF_SAME(abort_callback_user_data)
+    }
+  }
+
+  return self;
+}
+
+#undef SET_PARAM_IF_SAME
+
+/*
+ * Hook called on new segment. Yields each Whisper::Segment.
+ *
+ *   whisper.on_new_segment do |segment|
+ *     # ...
+ *   end
+ *
+ * call-seq:
+ *   on_new_segment {|segment| ... }
+ */
+static VALUE
+ruby_whisper_params_on_new_segment(VALUE self)
+{
+  ruby_whisper_params *rws;
+  Data_Get_Struct(self, ruby_whisper_params, rws);
+  const VALUE blk = rb_block_proc();
+  rb_ary_push(rws->new_segment_callback_container->callbacks, blk);
+  return Qnil;
+}
+
+/*
+ * Hook called on progress update. Yields each progress Integer between 0 and 100.
+ *
+ *   whisper.on_progress do |progress|
+ *     # ...
+ *   end
+ *
+ * call-seq:
+ *   on_progress {|progress| ... }
+ */
+static VALUE
+ruby_whisper_params_on_progress(VALUE self)
+{
+  ruby_whisper_params *rws;
+  Data_Get_Struct(self, ruby_whisper_params, rws);
+  const VALUE blk = rb_block_proc();
+  rb_ary_push(rws->progress_callback_container->callbacks, blk);
+  return Qnil;
+}
+
+/*
+ * Call block to determine whether abort or not. Return +true+ when you want to abort.
+ *
+ *   params.abort_on do
+ *     if some_condition
+ *       true # abort
+ *     else
+ *       false # continue
+ *     end
+ *   end
+ *
+ * call-seq:
+ *   abort_on { ... }
+ */
+static VALUE
+ruby_whisper_params_abort_on(VALUE self)
+{
+  ruby_whisper_params *rws;
+  Data_Get_Struct(self, ruby_whisper_params, rws);
+  const VALUE blk = rb_block_proc();
+  rb_ary_push(rws->abort_callback_container->callbacks, blk);
+  return Qnil;
+}
+
+void
+init_ruby_whisper_params(VALUE *mWhisper)
+{
+  cParams  = rb_define_class_under(*mWhisper, "Params", rb_cObject);
+
+  rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
+  rb_define_method(cParams, "initialize", ruby_whisper_params_initialize, -1);
+
+  DEFINE_PARAM(language, 0)
+  DEFINE_PARAM(translate, 1)
+  DEFINE_PARAM(no_context, 2)
+  DEFINE_PARAM(single_segment, 3)
+  DEFINE_PARAM(print_special, 4)
+  DEFINE_PARAM(print_progress, 5)
+  DEFINE_PARAM(print_realtime, 6)
+  DEFINE_PARAM(print_timestamps, 7)
+  DEFINE_PARAM(suppress_blank, 8)
+  DEFINE_PARAM(suppress_nst, 9)
+  DEFINE_PARAM(token_timestamps, 10)
+  DEFINE_PARAM(split_on_word, 11)
+  DEFINE_PARAM(initial_prompt, 12)
+  DEFINE_PARAM(diarize, 13)
+  DEFINE_PARAM(offset, 14)
+  DEFINE_PARAM(duration, 15)
+  DEFINE_PARAM(max_text_tokens, 16)
+  DEFINE_PARAM(temperature, 17)
+  DEFINE_PARAM(max_initial_ts, 18)
+  DEFINE_PARAM(length_penalty, 19)
+  DEFINE_PARAM(temperature_inc, 20)
+  DEFINE_PARAM(entropy_thold, 21)
+  DEFINE_PARAM(logprob_thold, 22)
+  DEFINE_PARAM(no_speech_thold, 23)
+  DEFINE_PARAM(new_segment_callback, 24)
+  DEFINE_PARAM(new_segment_callback_user_data, 25)
+  DEFINE_PARAM(progress_callback, 26)
+  DEFINE_PARAM(progress_callback_user_data, 27)
+  DEFINE_PARAM(abort_callback, 28)
+  DEFINE_PARAM(abort_callback_user_data, 29)
+
+  rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
+  rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
+  rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
+}
diff --git a/bindings/ruby/ext/ruby_whisper_segment.c b/bindings/ruby/ext/ruby_whisper_segment.c
new file mode 100644 (file)
index 0000000..3440ff9
--- /dev/null
@@ -0,0 +1,123 @@
+#include <ruby.h>
+#include "ruby_whisper.h"
+
+extern VALUE cSegment;
+
+static void
+rb_whisper_segment_mark(ruby_whisper_segment *rws)
+{
+  rb_gc_mark(rws->context);
+}
+
+VALUE
+ruby_whisper_segment_allocate(VALUE klass)
+{
+  ruby_whisper_segment *rws;
+  rws = ALLOC(ruby_whisper_segment);
+  return Data_Wrap_Struct(klass, rb_whisper_segment_mark, RUBY_DEFAULT_FREE, rws);
+}
+
+VALUE
+rb_whisper_segment_initialize(VALUE context, int index)
+{
+  ruby_whisper_segment *rws;
+  const VALUE segment = ruby_whisper_segment_allocate(cSegment);
+  Data_Get_Struct(segment, ruby_whisper_segment, rws);
+  rws->context = context;
+  rws->index = index;
+  return segment;
+};
+
+/*
+ * Start time in milliseconds.
+ *
+ * call-seq:
+ *   start_time -> Integer
+ */
+static VALUE
+ruby_whisper_segment_get_start_time(VALUE self)
+{
+  ruby_whisper_segment *rws;
+  Data_Get_Struct(self, ruby_whisper_segment, rws);
+  ruby_whisper *rw;
+  Data_Get_Struct(rws->context, ruby_whisper, rw);
+  const int64_t t0 = whisper_full_get_segment_t0(rw->context, rws->index);
+  // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
+  return INT2NUM(t0 * 10);
+}
+
+/*
+ * End time in milliseconds.
+ *
+ * call-seq:
+ *   end_time -> Integer
+ */
+static VALUE
+ruby_whisper_segment_get_end_time(VALUE self)
+{
+  ruby_whisper_segment *rws;
+  Data_Get_Struct(self, ruby_whisper_segment, rws);
+  ruby_whisper *rw;
+  Data_Get_Struct(rws->context, ruby_whisper, rw);
+  const int64_t t1 = whisper_full_get_segment_t1(rw->context, rws->index);
+  // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
+  return INT2NUM(t1 * 10);
+}
+
+/*
+ * Whether the next segment is predicted as a speaker turn.
+ *
+ * call-seq:
+ *   speaker_turn_next? -> bool
+ */
+static VALUE
+ruby_whisper_segment_get_speaker_turn_next(VALUE self)
+{
+  ruby_whisper_segment *rws;
+  Data_Get_Struct(self, ruby_whisper_segment, rws);
+  ruby_whisper *rw;
+  Data_Get_Struct(rws->context, ruby_whisper, rw);
+  return whisper_full_get_segment_speaker_turn_next(rw->context, rws->index) ? Qtrue : Qfalse;
+}
+
+/*
+ * call-seq:
+ *   text -> String
+ */
+static VALUE
+ruby_whisper_segment_get_text(VALUE self)
+{
+  ruby_whisper_segment *rws;
+  Data_Get_Struct(self, ruby_whisper_segment, rws);
+  ruby_whisper *rw;
+  Data_Get_Struct(rws->context, ruby_whisper, rw);
+  const char * text = whisper_full_get_segment_text(rw->context, rws->index);
+  return rb_str_new2(text);
+}
+
+/*
+ * call-seq:
+ *   no_speech_prob -> Float
+ */
+static VALUE
+ruby_whisper_segment_get_no_speech_prob(VALUE self)
+{
+  ruby_whisper_segment *rws;
+  Data_Get_Struct(self, ruby_whisper_segment, rws);
+  ruby_whisper *rw;
+  Data_Get_Struct(rws->context, ruby_whisper, rw);
+  return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index));
+}
+
+void
+init_ruby_whisper_segment(VALUE *mWhisper, VALUE *cContext)
+{
+  cSegment  = rb_define_class_under(*mWhisper, "Segment", rb_cObject);
+
+  rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate);
+  rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0);
+  rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
+  rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
+  rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0);
+  rb_define_method(cSegment, "no_speech_prob", ruby_whisper_segment_get_no_speech_prob, 0);
+}
diff --git a/bindings/ruby/ext/ruby_whisper_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_transcribe.cpp
new file mode 100644 (file)
index 0000000..d50ed06
--- /dev/null
@@ -0,0 +1,159 @@
+#include <ruby.h>
+#include "ruby_whisper.h"
+#define DR_WAV_IMPLEMENTATION
+#include "dr_wav.h"
+#include <string>
+#include <vector>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+extern ID id_to_s;
+extern ID id_call;
+
+extern void
+register_callbacks(ruby_whisper_params * rwp, VALUE * self);
+
+/*
+ * transcribe a single file
+ * can emit to a block results
+ *
+ *   params = Whisper::Params.new
+ *   params.duration = 60_000
+ *   whisper.transcribe "path/to/audio.wav", params do |text|
+ *     puts text
+ *   end
+ *
+ * call-seq:
+ *   transcribe(path_to_audio, params) {|text| ...}
+ **/
+VALUE
+ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
+  ruby_whisper *rw;
+  ruby_whisper_params *rwp;
+  VALUE wave_file_path, blk, params;
+
+  rb_scan_args(argc, argv, "02&", &wave_file_path, &params, &blk);
+  Data_Get_Struct(self, ruby_whisper, rw);
+  Data_Get_Struct(params, ruby_whisper_params, rwp);
+
+  if (!rb_respond_to(wave_file_path, id_to_s)) {
+    rb_raise(rb_eRuntimeError, "Expected file path to wave file");
+  }
+
+  std::string fname_inp = StringValueCStr(wave_file_path);
+
+  std::vector<float> pcmf32; // mono-channel F32 PCM
+  std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
+
+  // WAV input - this is directly from main.cpp example
+  {
+    drwav wav;
+    std::vector<uint8_t> wav_data; // used for pipe input from stdin
+
+    if (fname_inp == "-") {
+      {
+        uint8_t buf[1024];
+        while (true) {
+          const size_t n = fread(buf, 1, sizeof(buf), stdin);
+          if (n == 0) {
+            break;
+          }
+          wav_data.insert(wav_data.end(), buf, buf + n);
+        }
+      }
+
+      if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
+        fprintf(stderr, "error: failed to open WAV file from stdin\n");
+        return self;
+      }
+
+      fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
+    } else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
+      fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
+      return self;
+    }
+
+    if (wav.channels != 1 && wav.channels != 2) {
+      fprintf(stderr, "WAV file '%s' must be mono or stereo\n", fname_inp.c_str());
+      return self;
+    }
+
+    if (rwp->diarize && wav.channels != 2 && rwp->params.print_timestamps == false) {
+      fprintf(stderr, "WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str());
+      return self;
+    }
+
+    if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
+      fprintf(stderr, "WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
+      return self;
+    }
+
+    if (wav.bitsPerSample != 16) {
+      fprintf(stderr, "WAV file '%s' must be 16-bit\n", fname_inp.c_str());
+      return self;
+    }
+
+    const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
+
+    std::vector<int16_t> pcm16;
+    pcm16.resize(n*wav.channels);
+    drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
+    drwav_uninit(&wav);
+
+    // convert to mono, float
+    pcmf32.resize(n);
+    if (wav.channels == 1) {
+      for (uint64_t i = 0; i < n; i++) {
+        pcmf32[i] = float(pcm16[i])/32768.0f;
+      }
+    } else {
+      for (uint64_t i = 0; i < n; i++) {
+        pcmf32[i] = float((int32_t)pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
+      }
+    }
+
+    if (rwp->diarize) {
+      // convert to stereo, float
+      pcmf32s.resize(2);
+
+      pcmf32s[0].resize(n);
+      pcmf32s[1].resize(n);
+      for (uint64_t i = 0; i < n; i++) {
+        pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
+        pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
+      }
+    }
+  }
+  {
+    static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
+
+    rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
+      bool is_aborted = *(bool*)user_data;
+      return !is_aborted;
+    };
+    rwp->params.encoder_begin_callback_user_data = &is_aborted;
+  }
+
+  register_callbacks(rwp, &self);
+
+  if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
+    fprintf(stderr, "failed to process audio\n");
+    return self;
+  }
+  const int n_segments = whisper_full_n_segments(rw->context);
+  VALUE output = rb_str_new2("");
+  for (int i = 0; i < n_segments; ++i) {
+    const char * text = whisper_full_get_segment_text(rw->context, i);
+    output = rb_str_concat(output, rb_str_new2(text));
+  }
+  VALUE idCall = id_call;
+  if (blk != Qnil) {
+    rb_funcall(blk, idCall, 1, output);
+  }
+  return self;
+}
+#ifdef __cplusplus
+}
+#endif
index b43d90dd48621c1bce170c34e195300f1c0470fa..ce19f715a88cd1bf4d70251a941db105ef3e5910 100644 (file)
@@ -65,6 +65,13 @@ module Whisper
             end
           end
         end
+      rescue => err
+        if cache_path.exist?
+          warn err
+        # Use cache file
+        else
+          raise
+        end
       end
 
       def download(response)
index aff2ae73ee8a0b6d94554a8e22fb60249bd1353d..85d941cba966d7539a18b70f7490fdb7cada11f9 100644 (file)
@@ -20,13 +20,12 @@ module Whisper
   def self.lang_id: (string name) -> Integer
   def self.lang_str: (Integer id) -> String
   def self.lang_str_full: (Integer id) -> String
-  def self.log_set=: (log_callback) -> log_callback
-  def self.finalize_log_callback: (void) -> void # Second argument of ObjectSpace.define_finalizer
+  def self.log_set: (log_callback, Object? user_data) -> log_callback
 
   class Context
-    def initialize: (string | _ToPath | ::URI::HTTP ) -> void
-    def transcribe: (string, Params) -> void
-                  | (string, Params) { (String) -> void } -> void
+    def self.new: (string | _ToPath | ::URI::HTTP) -> instance
+    def transcribe: (string, Params) -> self
+                  | (string, Params) { (String) -> void } -> self
     def model_n_vocab: () -> Integer
     def model_n_audio_ctx: () -> Integer
     def model_n_audio_state: () -> Integer
@@ -35,6 +34,10 @@ module Whisper
     def model_n_mels: () -> Integer
     def model_ftype: () -> Integer
     def model_type: () -> String
+    def each_segment: { (Segment) -> void } -> void
+                    | () -> Enumerator[Segment]
+    def model: () -> Model
+    def full_get_segment: (Integer nth) -> Segment
     def full_n_segments: () -> Integer
     def full_lang_id: () -> Integer
     def full_get_segment_t0: (Integer) -> Integer
@@ -42,18 +45,46 @@ module Whisper
     def full_get_segment_speaker_turn_next: (Integer) -> (true | false)
     def full_get_segment_text: (Integer) -> String
     def full_get_segment_no_speech_prob: (Integer) -> Float
-    def full: (Params, Array[Float], ?Integer) -> void
-            | (Params, _Samples, ?Integer) -> void
-    def full_parallel: (Params, Array[Float], ?Integer) -> void
-                     | (Params, _Samples, ?Integer) -> void
-                     | (Params, _Samples, ?Integer?, Integer) -> void
-    def each_segment: { (Segment) -> void } -> void
-                    | () -> Enumerator[Segment]
-    def model: () -> Model
+    def full: (Params, Array[Float] samples, ?Integer n_samples) -> self
+            | (Params, _Samples, ?Integer n_samples) -> self
+    def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self
+                     | (Params, _Samples, ?Integer n_samples) -> self
+                     | (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self
   end
 
   class Params
-    def initialize: () -> void
+    def self.new: (
+      ?language: string,
+      ?translate: boolish,
+      ?no_context: boolish,
+      ?single_segment: boolish,
+      ?print_special: boolish,
+      ?print_progress: boolish,
+      ?print_realtime: boolish,
+      ?print_timestamps: boolish,
+      ?suppress_blank: boolish,
+      ?suppress_nst: boolish,
+      ?token_timestamps: boolish,
+      ?split_on_word: boolish,
+      ?initial_prompt: string | nil,
+      ?diarize: boolish,
+      ?offset: Integer,
+      ?duration: Integer,
+      ?max_text_tokens: Integer,
+      ?temperature: Float,
+      ?max_initial_ts: Float,
+      ?length_penalty: Float,
+      ?temperature_inc: Float,
+      ?entropy_thold: Float,
+      ?logprob_thold: Float,
+      ?no_speech_thold: Float,
+      ?new_segment_callback: new_segment_callback,
+      ?new_segment_callback_user_data: Object,
+      ?progress_callback: progress_callback,
+      ?progress_callback_user_data: Object,
+      ?abort_callback: abort_callback,
+      ?abort_callback_user_data: Object
+    ) -> instance
     def language=: (String) -> String # TODO: Enumerate lang names
     def language: () -> String
     def translate=: (boolish) -> boolish
@@ -79,7 +110,7 @@ module Whisper
     def split_on_word=: (boolish) -> boolish
     def split_on_word: () -> (true | false)
     def initial_prompt=: (_ToS) -> _ToS
-    def initial_prompt: () -> String
+    def initial_prompt: () -> (String | nil)
     def diarize=: (boolish) -> boolish
     def diarize: () -> (true | false)
     def offset=: (Integer) -> Integer
@@ -103,19 +134,25 @@ module Whisper
     def no_speech_thold=: (Float) -> Float
     def no_speech_thold: () -> Float
     def new_segment_callback=: (new_segment_callback) -> new_segment_callback
+    def new_segment_callback: () -> (new_segment_callback | nil)
     def new_segment_callback_user_data=: (Object) -> Object
+    def new_segment_callback_user_data: () -> Object
     def progress_callback=: (progress_callback) -> progress_callback
+    def progress_callback: () -> (progress_callback | nil)
     def progress_callback_user_data=: (Object) -> Object
+    def progress_callback_user_data: () -> Object
     def abort_callback=: (abort_callback) -> abort_callback
+    def abort_callback: () -> (abort_callback | nil)
     def abort_callback_user_data=: (Object) -> Object
+    def abort_callback_user_data: () -> Object
     def on_new_segment: { (Segment) -> void } -> void
-    def on_progress: { (Integer) -> void } -> void
-    def abort_on: { (Object) -> boolish } -> void
+    def on_progress: { (Integer progress) -> void } -> void
+    def abort_on: { (Object user_data) -> boolish } -> void
   end
 
   class Model
     def self.pre_converted_models: () -> Hash[String, Model::URI]
-    def initialize: () -> void
+    def self.new: () -> instance
     def n_vocab: () -> Integer
     def n_audio_ctx: () -> Integer
     def n_audio_state: () -> Integer
@@ -130,14 +167,13 @@ module Whisper
     def type: () -> String
 
     class URI
-      def initialize: (string | ::URI::HTTP) -> void
+      def self.new: (string | ::URI::HTTP) -> self
       def to_path: -> String
       def clear_cache: -> void
     end
   end
 
   class Segment
-    def initialize: () -> void
     def start_time: () -> Integer
     def end_time: () -> Integer
     def speaker_next_turn?: () -> (true | false)
@@ -148,6 +184,6 @@ module Whisper
   class Error < StandardError
     attr_reader code: Integer
 
-    def initialize: (Integer) -> void
+    def self.new: (Integer code) -> instance
   end
 end
index 7981bfaab5076db929337697463ace77475a0471..0cc49433930087d7295191681e41df686a4cd762 100644 (file)
@@ -1,6 +1,39 @@
 require_relative "helper"
 
 class TestParams < TestBase
+  PARAM_NAMES = [
+    :language,
+    :translate,
+    :no_context,
+    :single_segment,
+    :print_special,
+    :print_progress,
+    :print_realtime,
+    :print_timestamps,
+    :suppress_blank,
+    :suppress_nst,
+    :token_timestamps,
+    :split_on_word,
+    :initial_prompt,
+    :diarize,
+    :offset,
+    :duration,
+    :max_text_tokens,
+    :temperature,
+    :max_initial_ts,
+    :length_penalty,
+    :temperature_inc,
+    :entropy_thold,
+    :logprob_thold,
+    :no_speech_thold,
+    :new_segment_callback,
+    :new_segment_callback_user_data,
+    :progress_callback,
+    :progress_callback_user_data,
+    :abort_callback,
+    :abort_callback_user_data,
+  ]
+
   def setup
     @params  = Whisper::Params.new
   end
@@ -157,4 +190,57 @@ class TestParams < TestBase
     @params.no_speech_thold = 0.2
     assert_in_delta 0.2, @params.no_speech_thold
   end
+
+  def test_new_with_kw_args
+    params = Whisper::Params.new(language: "es")
+    assert_equal "es", params.language
+    assert_equal 1.0, params.max_initial_ts
+  end
+
+  def test_new_with_kw_args_non_existent
+    assert_raise ArgumentError do
+      Whisper::Params.new(non_existent: "value")
+    end
+  end
+
+  def test_new_with_kw_args_wrong_type
+    assert_raise TypeError do
+      Whisper::Params.new(language: 3)
+    end
+  end
+
+  data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
+  def test_new_with_kw_args_default_values(param)
+    default_value = @params.send(param)
+    value = case [param, default_value]
+            in [*, true | false]
+              !default_value
+            in [*, Integer | Float]
+              default_value + 1
+            in [:language, *]
+              "es"
+            in [:initial_prompt, *]
+              "Initial prompt"
+            in [/_callback\Z/, *]
+              proc {}
+            in [/_user_data\Z/, *]
+              Object.new
+            end
+    params = Whisper::Params.new(param => value)
+    if Float === value
+      assert_in_delta value, params.send(param)
+    else
+      assert_equal value, params.send(param)
+    end
+
+    PARAM_NAMES.reject {|name| name == param}.each do |name|
+      expected = @params.send(name)
+      actual = params.send(name)
+      if Float === expected
+        assert_in_delta expected, actual
+      else
+        assert_equal expected, actual
+      end
+    end
+  end
 end
index 5b0d189e85f44cc4e78a99b3fdd4d5a4e0fafbf4..76b92c73bf551c8270e82e546e7cafda596a3d48 100644 (file)
@@ -29,6 +29,12 @@ class TestWhisper < TestBase
       assert_equal 0, whisper.full_lang_id
     end
 
+    def test_full_get_segment
+      segment = whisper.full_get_segment(0)
+      assert_equal 0, segment.start_time
+      assert_match /ask not what your country can do for you, ask what you can do for your country/, segment.text
+    end
+
     def test_full_get_segment_t0
       assert_equal 0, whisper.full_get_segment_t0(0)
       assert_raise IndexError do