]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ruby : bug fix on callbacks and no_speech_prob (#2656)
author KITAITI Makoto <redacted>
Sat, 21 Dec 2024 19:52:06 +0000 (04:52 +0900)
committer GitHub <redacted>
Sat, 21 Dec 2024 19:52:06 +0000 (21:52 +0200)
* Don't generate documentation on test

* Move .startup to TestBase class

* Extract new_segment_callback as a function

* Extract progress_callback as a function

* Extract abort_callback as a function

* Extract register_callbacks as a function

* Call callbacks in Whisper::Context#full and #full_parallel

* Fix README

* Handle the cases where content-size is nil or a TTY is not available

* Add tests for no_speech_prob

* Add Whisper::Context#full_get_segment_no_speech_prob and Whisper::Segment#no_speech_prob

bindings/ruby/README.md
bindings/ruby/ext/ruby_whisper.cpp
bindings/ruby/lib/whisper/model/uri.rb
bindings/ruby/tests/helper.rb
bindings/ruby/tests/test_package.rb
bindings/ruby/tests/test_segment.rb
bindings/ruby/tests/test_whisper.rb

index 03a8b9e165c3b88d7a3df5a29ea77cdd467fe0dc..8492e4ed91b02ba05a7e8381ec9b7604aa603733 100644 (file)
@@ -63,7 +63,7 @@ whisper = Whisper::Context.new("base.en")
 You can see the list of prepared model names by `Whisper::Model.preconverted_models.keys`:
 
 ```ruby
-puts Whisper::Model.preconverted_model_names
+puts Whisper::Model.preconverted_models.keys
 # tiny
 # tiny.en
 # tiny-q5_1
@@ -220,7 +220,7 @@ whisper.each_segment do |segment|
 end
 ```
 
-The second argument `samples` may be an array, an object with `length` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
+The second argument `samples` may be an array, an object with `length` and `each` methods, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
 
 License
 -------
index aa526577fbed32add0398ab8dc982e161984b263..88a4fd2c205509fcc685bfafa3e8342f5bc9d346 100644 (file)
@@ -53,6 +53,9 @@ static ID id_pre_converted_models;
 
 static bool is_log_callback_finalized = false;
 
+// High level API
+static VALUE rb_whisper_segment_initialize(VALUE context, int index);
+
 /*
  * call-seq:
  *   lang_max_id -> Integer
@@ -187,6 +190,69 @@ static ruby_whisper_callback_container * rb_whisper_callback_container_allocate(
   return container;
 }
 
+static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) {
+  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
+
+  // Currently, doesn't support state because
+  // those require to resolve GC-related problems.
+  if (!NIL_P(container->callback)) {
+    rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
+  }
+  const long callbacks_len = RARRAY_LEN(container->callbacks);
+  if (0 == callbacks_len) {
+    return;
+  }
+  const int n_segments = whisper_full_n_segments_from_state(state);
+  for (int i = n_new; i > 0; i--) {
+    int i_segment = n_segments - i;
+    VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
+    for (int j = 0; j < callbacks_len; j++) {
+      VALUE cb = rb_ary_entry(container->callbacks, j);
+      rb_funcall(cb, id_call, 1, segment);
+    }
+  }
+}
+
+static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) {
+  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
+  const VALUE progress = INT2NUM(progress_cur);
+  // Currently, doesn't support state because
+  // those require to resolve GC-related problems.
+  if (!NIL_P(container->callback)) {
+    rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
+  }
+  const long callbacks_len = RARRAY_LEN(container->callbacks);
+  if (0 == callbacks_len) {
+    return;
+  }
+  for (int j = 0; j < callbacks_len; j++) {
+    VALUE cb = rb_ary_entry(container->callbacks, j);
+    rb_funcall(cb, id_call, 1, progress);
+  }
+}
+
+static bool abort_callback(void * user_data) {
+  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
+  if (!NIL_P(container->callback)) {
+    VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
+    if (!NIL_P(result) && Qfalse != result) {
+      return true;
+    }
+  }
+  const long callbacks_len = RARRAY_LEN(container->callbacks);
+  if (0 == callbacks_len) {
+    return false;
+  }
+  for (int j = 0; j < callbacks_len; j++) {
+    VALUE cb = rb_ary_entry(container->callbacks, j);
+    VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
+    if (!NIL_P(result) && Qfalse != result) {
+      return true;
+    }
+  }
+  return false;
+}
+
 static VALUE ruby_whisper_params_allocate(VALUE klass) {
   ruby_whisper_params *rwp;
   rwp = ALLOC(ruby_whisper_params);
@@ -230,8 +296,25 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
   return self;
 }
 
-// High level API
-static VALUE rb_whisper_segment_initialize(VALUE context, int index);
+static void register_callbacks(ruby_whisper_params * rwp, VALUE * self) {
+  if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
+    rwp->new_segment_callback_container->context = self;
+    rwp->params.new_segment_callback = new_segment_callback;
+    rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
+  }
+
+  if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
+    rwp->progress_callback_container->context = self;
+    rwp->params.progress_callback = progress_callback;
+    rwp->params.progress_callback_user_data = rwp->progress_callback_container;
+  }
+
+  if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
+    rwp->abort_callback_container->context = self;
+    rwp->params.abort_callback = abort_callback;
+    rwp->params.abort_callback_user_data = rwp->abort_callback_container;
+  }
+}
 
 /*
  * transcribe a single file
@@ -353,80 +436,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
     rwp->params.encoder_begin_callback_user_data = &is_aborted;
   }
 
-  if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
-    rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
-      const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
-
-      // Currently, doesn't support state because
-      // those require to resolve GC-related problems.
-      if (!NIL_P(container->callback)) {
-        rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
-      }
-      const long callbacks_len = RARRAY_LEN(container->callbacks);
-      if (0 == callbacks_len) {
-        return;
-      }
-      const int n_segments = whisper_full_n_segments_from_state(state);
-      for (int i = n_new; i > 0; i--) {
-        int i_segment = n_segments - i;
-        VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
-        for (int j = 0; j < callbacks_len; j++) {
-          VALUE cb = rb_ary_entry(container->callbacks, j);
-          rb_funcall(cb, id_call, 1, segment);
-        }
-      }
-    };
-    rwp->new_segment_callback_container->context = &self;
-    rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
-  }
-
-  if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
-    rwp->params.progress_callback = [](struct whisper_context *ctx, struct whisper_state * /*state*/, int progress_cur, void *user_data) {
-      const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
-      const VALUE progress = INT2NUM(progress_cur);
-      // Currently, doesn't support state because
-      // those require to resolve GC-related problems.
-      if (!NIL_P(container->callback)) {
-        rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
-      }
-      const long callbacks_len = RARRAY_LEN(container->callbacks);
-      if (0 == callbacks_len) {
-        return;
-      }
-      for (int j = 0; j < callbacks_len; j++) {
-        VALUE cb = rb_ary_entry(container->callbacks, j);
-        rb_funcall(cb, id_call, 1, progress);
-      }
-    };
-    rwp->progress_callback_container->context = &self;
-    rwp->params.progress_callback_user_data = rwp->progress_callback_container;
-  }
-
-  if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
-    rwp->params.abort_callback = [](void * user_data) {
-      const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
-      if (!NIL_P(container->callback)) {
-        VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
-        if (!NIL_P(result) && Qfalse != result) {
-          return true;
-        }
-      }
-      const long callbacks_len = RARRAY_LEN(container->callbacks);
-      if (0 == callbacks_len) {
-        return false;
-      }
-      for (int j = 0; j < callbacks_len; j++) {
-        VALUE cb = rb_ary_entry(container->callbacks, j);
-        VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
-        if (!NIL_P(result) && Qfalse != result) {
-          return true;
-        }
-      }
-      return false;
-    };
-    rwp->abort_callback_container->context = &self;
-    rwp->params.abort_callback_user_data = rwp->abort_callback_container;
-  }
+  register_callbacks(rwp, &self);
 
   if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
     fprintf(stderr, "failed to process audio\n");
@@ -631,6 +641,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) {
       }
     }
   }
+  register_callbacks(rwp, &self);
   const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
   if (0 == result) {
     return Qnil;
@@ -719,6 +730,7 @@ static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) {
       }
     }
   }
+  register_callbacks(rwp, &self);
   const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
   if (0 == result) {
     return Qnil;
@@ -823,6 +835,18 @@ static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) {
   return rb_str_new2(text);
 }
 
+/*
+ * call-seq:
+ *   full_get_segment_no_speech_prob -> Float
+ */
+static VALUE ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment) {
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
+  const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment);
+  return DBL2NUM(no_speech_prob);
+}
+
 /*
  * params.language = "auto" | "en", etc...
  *
@@ -1547,6 +1571,18 @@ static VALUE ruby_whisper_segment_get_text(VALUE self) {
   return rb_str_new2(text);
 }
 
+/*
+ * call-seq:
+ *   no_speech_prob -> Float
+ */
+static VALUE ruby_whisper_segment_get_no_speech_prob(VALUE self) {
+  ruby_whisper_segment *rws;
+  Data_Get_Struct(self, ruby_whisper_segment, rws);
+  ruby_whisper *rw;
+  Data_Get_Struct(rws->context, ruby_whisper, rw);
+  return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index));
+}
+
 static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
   rb_gc_mark(rwm->context);
 }
@@ -1809,6 +1845,7 @@ void Init_whisper() {
   rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
   rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
   rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
+  rb_define_method(cContext, "full_get_segment_no_speech_prob", ruby_whisper_full_get_segment_no_speech_prob, 1);
   rb_define_method(cContext, "full", ruby_whisper_full, -1);
   rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
 
@@ -1887,6 +1924,7 @@ void Init_whisper() {
   rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
   rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
   rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0);
+  rb_define_method(cSegment, "no_speech_prob", ruby_whisper_segment_get_no_speech_prob, 0);
 
   cModel = rb_define_class_under(mWhisper, "Model", rb_cObject);
   rb_define_alloc_func(cModel, ruby_whisper_model_allocate);
index 5ca77ed4bb27c47047d884738752d930277e198e..fe5ed56b3fbde9f8d8273caee5745081e7808c21 100644 (file)
@@ -79,30 +79,36 @@ class Whisper::Model
           downloaded += chunk.bytesize
           show_progress downloaded, size
         end
+        $stderr.puts
       end
       downloading_path.rename path
     end
 
     def show_progress(current, size)
-      return unless $stderr.tty?
-      return unless size
+      progress_rate_available = size && $stderr.tty?
 
       unless @prev
         @prev = Time.now
-        $stderr.puts "Downloading #{@uri}"
+        $stderr.puts "Downloading #{@uri} to #{cache_path}"
       end
 
       now = Time.now
-      return if now - @prev < 1 && current < size
-
-      progress_width = 20
-      progress = current.to_f / size
-      arrow_length = progress * progress_width
-      arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length)
-      line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
-      padding = ' ' * ($stderr.winsize[1] - line.size)
-      $stderr.print "\r#{line}#{padding}"
-      $stderr.puts if current >= size
+
+      if progress_rate_available
+        return if now - @prev < 1 && current < size
+
+        progress_width = 20
+        progress = current.to_f / size
+        arrow_length = progress * progress_width
+        arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length)
+        line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
+        padding = ' ' * ($stderr.winsize[1] - line.size)
+        $stderr.print "\r#{line}#{padding}"
+      else
+        return if now - @prev < 1
+
+        $stderr.print "."
+      end
       @prev = now
     end
 
index da52f2687312b9d175878bc21e2c4a6b9cb98e04..a182319d95f21aa1c975adc28976d2b72ee7364d 100644 (file)
@@ -4,4 +4,21 @@ require_relative "jfk_reader/jfk_reader"
 
 class TestBase < Test::Unit::TestCase
   AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
+
+  class << self
+    attr_reader :whisper
+
+    def startup
+      @whisper = Whisper::Context.new("base.en")
+      params = Whisper::Params.new
+      params.print_timestamps = false
+      @whisper.transcribe(TestBase::AUDIO, params)
+    end
+  end
+
+  private
+
+  def whisper
+    self.class.whisper
+  end
 end
index 9c47870ef28cce56954630e7d5ca70ab7532d873..33c2b37e5321b5c59125c4e665a6a5d7c99f20c8 100644 (file)
@@ -23,7 +23,7 @@ class TestPackage < TestBase
       version = match_data[2]
       basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
       Dir.mktmpdir do |dir|
-        system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true
+        system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{filename.shellescape}", exception: true
         assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", basename)
       end
     end
index 559bcea71e6f654ab7371ef7bf5b67a1698812bb..44ab0a6b74de934f5902d2cc6e0c5b7783966c88 100644 (file)
@@ -1,17 +1,6 @@
 require_relative "helper"
 
 class TestSegment < TestBase
-  class << self
-    attr_reader :whisper
-
-    def startup
-      @whisper = Whisper::Context.new("base.en")
-      params = Whisper::Params.new
-      params.print_timestamps = false
-      @whisper.transcribe(TestBase::AUDIO, params)
-    end
-  end
-
   def test_iteration
     whisper.each_segment do |segment|
       assert_instance_of Whisper::Segment, segment
@@ -43,6 +32,14 @@ class TestSegment < TestBase
     end
   end
 
+  def test_no_speech_prob
+    no_speech_prob = nil
+    whisper.each_segment do |segment|
+      no_speech_prob = segment.no_speech_prob
+    end
+    assert no_speech_prob > 0.0
+  end
+
   def test_on_new_segment
     params = Whisper::Params.new
     seg = nil
@@ -74,10 +71,4 @@ class TestSegment < TestBase
     end
     whisper.transcribe(AUDIO, params)
   end
-
-  private
-
-  def whisper
-    self.class.whisper
-  end
 end
index 115569edb3d4df94e04de474aa79949400e27c55..5b0d189e85f44cc4e78a99b3fdd4d5a4e0fafbf4 100644 (file)
@@ -21,21 +21,6 @@ class TestWhisper < TestBase
   end
 
   sub_test_case "After transcription" do
-    class << self
-      attr_reader :whisper
-
-      def startup
-        @whisper = Whisper::Context.new("base.en")
-        params = Whisper::Params.new
-        params.print_timestamps = false
-        @whisper.transcribe(TestBase::AUDIO, params)
-      end
-    end
-
-    def whisper
-      self.class.whisper
-    end
-
     def test_full_n_segments
       assert_equal 1, whisper.full_n_segments
     end
@@ -70,6 +55,12 @@ class TestWhisper < TestBase
     def test_full_get_segment_text
       assert_match /ask not what your country can do for you, ask what you can do for your country/, whisper.full_get_segment_text(0)
     end
+
+    def test_full_get_segment_no_speech_prob
+      prob = whisper.full_get_segment_no_speech_prob(0)
+      assert prob > 0.0
+      assert prob < 1.0
+    end
   end
 
   def test_lang_max_id