]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ruby : add encoder begin callback related methods (#3076) upstream/latest
authorKITAITI Makoto <redacted>
Fri, 25 Apr 2025 19:33:11 +0000 (04:33 +0900)
committerGitHub <redacted>
Fri, 25 Apr 2025 19:33:11 +0000 (04:33 +0900)
* Lazy run TestBase.whisper

* Fix indentation

* Remove disused GGML_HIP_UMA from Ruby

* Add encoder_begin_callback

* Comment out existing abort mechanism

* Add test for encoder_begin_callback

* Add signatures for encoder_begin_callback related methods

* Update gem date

bindings/ruby/ext/options.rb
bindings/ruby/ext/ruby_whisper.h
bindings/ruby/ext/ruby_whisper_params.c
bindings/ruby/ext/ruby_whisper_transcribe.cpp
bindings/ruby/lib/whisper/model/uri.rb
bindings/ruby/sig/whisper.rbs
bindings/ruby/tests/helper.rb
bindings/ruby/tests/test_callback.rb
bindings/ruby/whispercpp.gemspec

index 679b74d133a5d1b2e099d133fc94c8ff30206324..6fed318405906e05191e541d6fbd4ecacb1f63c9 100644 (file)
@@ -114,7 +114,6 @@ class Options
     bool "GGML_HIP_GRAPHS"
     bool "GGML_HIP_NO_VMM"
     bool "GGML_HIP_ROCWMMA_FATTN"
-    bool "GGML_HIP_UMA"
     ignored "GGML_INCLUDE_INSTALL_DIR"
     bool "GGML_KOMPUTE"
     bool "GGML_LASX"
index bbf3435e52c4629ca8eecb471699210ccd3d9ea2..6111a151784141797307ecad27ceca21a42bbb58 100644 (file)
@@ -19,6 +19,7 @@ typedef struct {
   bool diarize;
   ruby_whisper_callback_container *new_segment_callback_container;
   ruby_whisper_callback_container *progress_callback_container;
+  ruby_whisper_callback_container *encoder_begin_callback_container;
   ruby_whisper_callback_container *abort_callback_container;
 } ruby_whisper_params;
 
index caeb34f2274c8e10a74e99bee902206b5e8366a5..c07f2372f16722b0694a89388605d5889d842ac8 100644 (file)
@@ -26,7 +26,7 @@
   rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
   rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
 
-#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 30
+#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 32
 
 extern VALUE cParams;
 
@@ -63,6 +63,8 @@ static ID id_new_segment_callback;
 static ID id_new_segment_callback_user_data;
 static ID id_progress_callback;
 static ID id_progress_callback_user_data;
+static ID id_encoder_begin_callback;
+static ID id_encoder_begin_callback_user_data;
 static ID id_abort_callback;
 static ID id_abort_callback_user_data;
 
@@ -126,6 +128,33 @@ static void progress_callback(struct whisper_context *ctx, struct whisper_state
   }
 }
 
+static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_state *state, void *user_data) {
+  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
+  bool is_aborted = false;
+  VALUE result;
+
+  // Currently, doesn't support state because
+  // those require to resolve GC-related problems.
+  if (!NIL_P(container->callback)) {
+    result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data);
+    if (result == Qfalse) {
+      is_aborted = true;
+    }
+  }
+  const long callbacks_len = RARRAY_LEN(container->callbacks);
+  if (0 == callbacks_len) {
+    return !is_aborted;
+  }
+  for (int j = 0; j < callbacks_len; j++) {
+    VALUE cb = rb_ary_entry(container->callbacks, j);
+    result = rb_funcall(cb, id_call, 0);
+    if (result == Qfalse) {
+      is_aborted = true;
+    }
+  }
+  return !is_aborted;
+}
+
 static bool abort_callback(void * user_data) {
   const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
   if (!NIL_P(container->callback)) {
@@ -161,6 +190,12 @@ void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
     rwp->params.progress_callback_user_data = rwp->progress_callback_container;
   }
 
+  if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) {
+    rwp->encoder_begin_callback_container->context = context;
+    rwp->params.encoder_begin_callback = encoder_begin_callback;
+    rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container;
+  }
+
   if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
     rwp->abort_callback_container->context = context;
     rwp->params.abort_callback = abort_callback;
@@ -173,6 +208,7 @@ rb_whisper_params_mark(ruby_whisper_params *rwp)
 {
   rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
   rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
+  rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container);
   rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
 }
 
@@ -198,6 +234,7 @@ ruby_whisper_params_allocate(VALUE klass)
   rwp->diarize = false;
   rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
   rwp->progress_callback_container = rb_whisper_callback_container_allocate();
+  rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate();
   rwp->abort_callback_container = rb_whisper_callback_container_allocate();
   return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
 }
@@ -849,6 +886,57 @@ ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value)
   rwp->progress_callback_container->user_data = value;
   return value;
 }
+
+static VALUE
+ruby_whisper_params_get_encoder_begin_callback(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  return rwp->encoder_begin_callback_container->callback;
+}
+
+/*
+ * Sets encoder begin callback, called when the encoder starts.
+ *
+ *   params.encoder_begin_callback = ->(context, _, user_data) {
+ *     # ...
+ *   }
+ *
+ * call-seq:
+ *   encoder_begin_callback = callback -> callback
+ */
+static VALUE
+ruby_whisper_params_set_encoder_begin_callback(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  rwp->encoder_begin_callback_container->callback = value;
+  return value;
+}
+
+static VALUE
+ruby_whisper_params_get_encoder_begin_callback_user_data(VALUE self)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  return rwp->encoder_begin_callback_container->user_data;
+}
+
+/*
+ * Sets user data passed to the last argument of encoder begin callback.
+ *
+ * call-seq:
+ *   encoder_begin_callback_user_data = user_data -> use_data
+ */
+static VALUE
+ruby_whisper_params_set_encoder_begin_callback_user_data(VALUE self, VALUE value)
+{
+  ruby_whisper_params *rwp;
+  Data_Get_Struct(self, ruby_whisper_params, rwp);
+  rwp->encoder_begin_callback_container->user_data = value;
+  return value;
+}
+
 static VALUE
 ruby_whisper_params_get_abort_callback(VALUE self)
 {
@@ -958,6 +1046,8 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
       SET_PARAM_IF_SAME(new_segment_callback_user_data)
       SET_PARAM_IF_SAME(progress_callback)
       SET_PARAM_IF_SAME(progress_callback_user_data)
+      SET_PARAM_IF_SAME(encoder_begin_callback)
+      SET_PARAM_IF_SAME(encoder_begin_callback_user_data)
       SET_PARAM_IF_SAME(abort_callback)
       SET_PARAM_IF_SAME(abort_callback_user_data)
     }
@@ -1008,6 +1098,26 @@ ruby_whisper_params_on_progress(VALUE self)
   return Qnil;
 }
 
+/*
+ * Hook called when the encoder starts.
+ *
+ *   whisper.on_encoder_begin do
+ *     # ...
+ *   end
+ *
+ * call-seq:
+ *   on_encoder_begin { ... }
+ */
+static VALUE
+ruby_whisper_params_on_encoder_begin(VALUE self)
+{
+  ruby_whisper_params *rws;
+  Data_Get_Struct(self, ruby_whisper_params, rws);
+  const VALUE blk = rb_block_proc();
+  rb_ary_push(rws->encoder_begin_callback_container->callbacks, blk);
+  return Qnil;
+}
+
 /*
  * Call block to determine whether abort or not. Return +true+ when you want to abort.
  *
@@ -1068,10 +1178,13 @@ init_ruby_whisper_params(VALUE *mWhisper)
   DEFINE_PARAM(new_segment_callback_user_data, 25)
   DEFINE_PARAM(progress_callback, 26)
   DEFINE_PARAM(progress_callback_user_data, 27)
-  DEFINE_PARAM(abort_callback, 28)
-  DEFINE_PARAM(abort_callback_user_data, 29)
+  DEFINE_PARAM(encoder_begin_callback, 28)
+  DEFINE_PARAM(encoder_begin_callback_user_data, 29)
+  DEFINE_PARAM(abort_callback, 30)
+  DEFINE_PARAM(abort_callback_user_data, 31)
 
   rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
   rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
+  rb_define_method(cParams, "on_encoder_begin", ruby_whisper_params_on_encoder_begin, 0);
   rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
 }
index 00b9d2e1a42e184db35301eb2bb7359248a20adf..ef3c0780f45c8a5dbb4d3297ef63bf7a9a4c1068 100644 (file)
@@ -50,15 +50,16 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
     fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
     return self;
   }
-  {
-    static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
+  // Commented out because it is work in progress
+  // {
+  //   static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
 
-    rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
-      bool is_aborted = *(bool*)user_data;
-      return !is_aborted;
-    };
-    rwp->params.encoder_begin_callback_user_data = &is_aborted;
-  }
+  //   rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
+  //     bool is_aborted = *(bool*)user_data;
+  //     return !is_aborted;
+  //   };
+  //   rwp->params.encoder_begin_callback_user_data = &is_aborted;
+  // }
 
   register_callbacks(rwp, &self);
 
index b2bc9c4b38b4f2d7fcd7c3a3854e3cbc43fe7f91..47c23c52721aea93185f91a48a6eef21571241d4 100644 (file)
@@ -53,7 +53,7 @@ module Whisper
           http.request request do |response|
             case response
             when Net::HTTPNotModified
-            # noop
+              # noop
             when Net::HTTPOK
               download response
             when Net::HTTPRedirection
@@ -68,7 +68,7 @@ module Whisper
       rescue => err
         if cache_path.exist?
           warn err
-        # Use cache file
+          # Use cache file
         else
           raise
         end
index 0f3d74e0a9418972dd60589078d98a00af40a587..a3ce94b8fde0d9acfdf900aff74e8b8584d538cb 100644 (file)
@@ -7,6 +7,7 @@ module Whisper
   type log_callback = ^(Integer level, String message, Object user_data) -> void
   type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void
   type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void
+  type encoder_begin_callback = ^(Whisper::Context, void, Object user_data) -> void
   type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish
 
   LOG_LEVEL_NONE: Integer
@@ -146,6 +147,8 @@ module Whisper
       ?new_segment_callback_user_data: Object,
       ?progress_callback: progress_callback,
       ?progress_callback_user_data: Object,
+      ?encoder_begin_callback: encoder_begin_callback,
+      ?encoder_begin_callback_user_data: Object,
       ?abort_callback: abort_callback,
       ?abort_callback_user_data: Object
     ) -> instance
@@ -306,6 +309,18 @@ module Whisper
 
     def progress_callback_user_data: () -> Object
 
+    # Sets encoder begin callback, called when the encoder starts.
+    #
+    def encoder_begin_callback=: (encoder_begin_callback) -> encoder_begin_callback
+
+    def encoder_begin_callback: () -> (encoder_begin_callback | nil)
+
+    # Sets user data passed to the last argument of encoder begin callback.
+    #
+    def encoder_begin_callback_user_data=: (Object) -> Object
+
+    def encoder_begin_callback_user_data: () -> Object
+
     # Sets abort callback, called to check if the process should be aborted.
     #
     #   params.abort_callback = ->(user_data) {
@@ -335,6 +350,10 @@ module Whisper
     #
     def on_progress: { (Integer progress) -> void } -> void
 
+    # Hook called on encoder starts.
+    #
+    def on_encoder_begin: { () -> void } -> void
+
     # Call block to determine whether abort or not. Return +true+ when you want to abort.
     #
     #   params.abort_on do
index a69a2b7e2c248d413b098c3314586bfa9b683d90..bc5e472456500456e1bcb010f2a299ba8aa59461 100644 (file)
@@ -6,9 +6,9 @@ class TestBase < Test::Unit::TestCase
   AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
 
   class << self
-    attr_reader :whisper
+    def whisper
+      return @whisper if @whisper
 
-    def startup
       @whisper = Whisper::Context.new("base.en")
       params = Whisper::Params.new
       params.print_timestamps = false
index 61ef366c36e6df423ec883c8eaa6390436fff7e7..a7f49245ade573d41033406bce49fa127cef4e0f 100644 (file)
@@ -111,6 +111,48 @@ class TestCallback < TestBase
     assert_equal 100, last
   end
 
+  def test_encoder_begin_callback
+    i = 0
+    @params.encoder_begin_callback = ->(context, state, user_data) {
+      i += 1
+    }
+    @whisper.transcribe(@audio, @params)
+    assert i > 0
+  end
+
+  def test_encoder_begin_callback_abort
+    logs = []
+    Whisper.log_set -> (level, buffer, user_data) {
+      logs << buffer if level == Whisper::LOG_LEVEL_ERROR
+    }, logs
+    @params.encoder_begin_callback = ->(context, state, user_data) {
+      return false
+    }
+    @whisper.transcribe(@audio, @params)
+    assert_match(/encoder_begin_callback returned false - aborting/, logs.join)
+    Whisper.log_set ->(level, buffer, user_data) {}, nil
+  end
+
+  def test_encoder_begin_callback_user_data
+    udata = Object.new
+    @params.encoder_begin_callback_user_data = udata
+    yielded = nil
+    @params.encoder_begin_callback = ->(context, state, user_data) {
+      yielded = user_data
+    }
+    @whisper.transcribe(@audio, @params)
+    assert_same udata, yielded
+  end
+
+  def test_on_encoder_begin
+    i = 0
+    @params.on_encoder_begin do
+      i += 1
+    end
+    @whisper.transcribe(@audio, @params)
+    assert i > 0
+  end
+
   def test_abort_callback
     i = 0
     @params.abort_callback = ->(user_data) {
index 329e670bfdf4e5e5c05db4006a878deee116a824..97cf4e27a12178317b105db8d765b0690dcff135 100644 (file)
@@ -4,7 +4,7 @@ Gem::Specification.new do |s|
   s.name    = "whispercpp"
   s.authors = ["Georgi Gerganov", "Todd A. Fisher"]
   s.version = '1.3.2'
-  s.date    = '2025-04-17'
+  s.date    = '2025-04-25'
   s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby}
   s.email   = 'todd.fisher@gmail.com'
   s.extra_rdoc_files = ['LICENSE', 'README.md']