]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ruby : extend API (#2551)
authorKITAITI Makoto <redacted>
Wed, 13 Nov 2024 19:52:56 +0000 (04:52 +0900)
committerGitHub <redacted>
Wed, 13 Nov 2024 19:52:56 +0000 (21:52 +0200)
* Handle objs in Ruby code

* Add task to make Makefile

* Share commont constance in test suites

* Add model-related APIs

* Add Whisper::Model class

* Add tests for Whisper::Model

* Add missing LDFLAG -lstdc++

* Add tests for Whisper.log_set

* Add Whisper.set_log

* Define log level

* Add document on logging

* Add license section to README

* Add document on Whisper::Model

* Fix examples in README

* Add test for Model with GC

* Make dependency on Makefile more accurate

* Fix bug about Whisper::Model and GC

bindings/ruby/README.md
bindings/ruby/Rakefile
bindings/ruby/ext/extconf.rb
bindings/ruby/ext/ruby_whisper.cpp
bindings/ruby/tests/helper.rb [new file with mode: 0644]
bindings/ruby/tests/test_model.rb [new file with mode: 0644]
bindings/ruby/tests/test_package.rb
bindings/ruby/tests/test_params.rb
bindings/ruby/tests/test_segment.rb
bindings/ruby/tests/test_whisper.rb

index a63e833313acfc65276b032bd8171b3c7bb2761a..05e2279f9516c0fceeb6e60b3aa169c783c4381d 100644 (file)
@@ -107,5 +107,63 @@ whisper.transcribe("path/to/audio.wav", params)
 
 ```
 
+You can see model information:
+
+```ruby
+whisper = Whisper::Context.new("path/to/model.bin")
+model = whisper.model
+
+model.n_vocab # => 51864
+model.n_audio_ctx # => 1500
+model.n_audio_state # => 512
+model.n_audio_head # => 8
+model.n_audio_layer # => 6
+model.n_text_ctx # => 448
+model.n_text_state # => 512
+model.n_text_head # => 8
+model.n_text_layer # => 6
+model.n_mels # => 80
+model.ftype # => 1
+model.type # => "base"
+
+```
+
+You can set log callback:
+
+```ruby
+prefix = "[MyApp] "
+log_callback = ->(level, buffer, user_data) {
+  case level
+  when Whisper::LOG_LEVEL_NONE
+    puts "#{user_data}none: #{buffer}"
+  when Whisper::LOG_LEVEL_INFO
+    puts "#{user_data}info: #{buffer}"
+  when Whisper::LOG_LEVEL_WARN
+    puts "#{user_data}warn: #{buffer}"
+  when Whisper::LOG_LEVEL_ERROR
+    puts "#{user_data}error: #{buffer}"
+  when Whisper::LOG_LEVEL_DEBUG
+    puts "#{user_data}debug: #{buffer}"
+  when Whisper::LOG_LEVEL_CONT
+    puts "#{user_data}same to previous: #{buffer}"
+  end
+}
+Whisper.log_set log_callback, prefix
+```
+
+Using this feature, you are also able to suppress log:
+
+```ruby
+Whisper.log_set ->(level, buffer, user_data) {
+  # do nothing
+}, nil
+Whisper::Context.new(MODEL)
+```
+
+License
+-------
+
+The same to [whisper.cpp][].
+
 [whisper.cpp]: https://github.com/ggerganov/whisper.cpp
 [models]: https://github.com/ggerganov/whisper.cpp/tree/master/models
index 5a6a9167a9f3b280460bc60228ec263bee51a5ea..d6fc49c8c09c12418b336beedd025ad589117e99 100644 (file)
@@ -23,30 +23,39 @@ CLEAN.include FileList[
                 "ext/depend"
               ]
 
-task build: SOURCES + FileList[
-                        "ext/extconf.rb",
-                        "ext/ruby_whisper.h",
-                        "ext/ruby_whisper.cpp",
-                        "whispercpp.gemspec",
-                      ]
+task build: FileList[
+       "ext/Makefile",
+       "ext/ruby_whisper.h",
+       "ext/ruby_whisper.cpp",
+       "whispercpp.gemspec",
+     ]
 
 directory "pkg"
 CLOBBER.include "pkg"
 
 TEST_MODEL = "../../models/ggml-base.en.bin"
 LIB_NAME = "whisper".ext(RbConfig::CONFIG["DLEXT"])
+SO_FILE = File.join("ext", LIB_NAME)
 LIB_FILE = File.join("lib", LIB_NAME)
 
-directory "lib"
-task LIB_FILE => SOURCES + ["lib"] do |t|
+file "ext/Makefile" => ["ext/extconf.rb", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp"] + SOURCES do |t|
+  Dir.chdir "ext" do
+    ruby "extconf.rb"
+  end
+end
+
+file SO_FILE => "ext/Makefile" do |t|
   Dir.chdir "ext" do
-    sh "ruby extconf.rb"
     sh "make"
   end
-  mv "ext/#{LIB_NAME}", t.name
 end
 CLEAN.include LIB_FILE
 
+directory "lib"
+file LIB_FILE => [SO_FILE, "lib"] do |t|
+  copy t.source, t.name
+end
+
 Rake::TestTask.new do |t|
   t.test_files = FileList["tests/test_*.rb"]
 end
index 3b54a4a1a52098ce40b6605d1b35dfe167efbe73..5e98b393b027972f5af9a227c568fe8758221211 100644 (file)
@@ -2,6 +2,9 @@ require 'mkmf'
 
 # need to use c++ compiler flags
 $CXXFLAGS << ' -std=c++11'
+
+$LDFLAGS << ' -lstdc++'
+
 # Set to true when building binary gems
 if enable_config('static-stdlib', false)
   $LDFLAGS << ' -static-libgcc -static-libstdc++'
@@ -12,34 +15,6 @@ if enable_config('march-tune-native', false)
   $CXXFLAGS << ' -march=native -mtune=native'
 end
 
-def with_disabling_unsupported_files
-  disabled_files = []
-
-  unless $GGML_METAL
-    disabled_files << 'ggml-metal.h' << 'ggml-metal.m'
-  end
-
-  unless $GGML_METAL_EMBED_LIBRARY
-    disabled_files << 'ggml-metal.metal'
-  end
-
-  unless $OBJ_ALL&.include? 'ggml-blas.o'
-    disabled_files << 'ggml-blas.h' << 'ggml-blas.cpp'
-  end
-
-  disabled_files.filter! {|file| File.exist? file}
-
-  disabled_files.each do |file|
-    File.rename file, "#{file}.disabled"
-  end
-
-  yield
-
-  disabled_files.each do |file|
-    File.rename "#{file}.disabled", file
-  end
-end
-
 if ENV['WHISPER_METAL']
   $GGML_METAL ||= true
   $DEPRECATE_WARNING ||= true
@@ -66,10 +41,10 @@ $MK_CXXFLAGS = '-std=c++11 -fPIC'
 $MK_NVCCFLAGS = '-std=c++11'
 $MK_LDFLAGS = ''
 
-$OBJ_GGML = ''
-$OBJ_WHISPER = ''
-$OBJ_COMMON = ''
-$OBJ_SDL = ''
+$OBJ_GGML = []
+$OBJ_WHISPER = []
+$OBJ_COMMON = []
+$OBJ_SDL = []
 
 $MK_CPPFLAGS << ' -D_XOPEN_SOURCE=600'
 
@@ -152,7 +127,7 @@ unless ENV['GGML_NO_ACCELERATE']
     $MK_CPPFLAGS << ' -DACCELERATE_NEW_LAPACK'
     $MK_CPPFLAGS << ' -DACCELERATE_LAPACK_ILP64'
     $MK_LDFLAGS  << ' -framework Accelerate'
-    $OBJ_GGML    << ' ggml-blas.o'
+    $OBJ_GGML    << 'ggml-blas.o'
   end
 end
 
@@ -160,20 +135,20 @@ if ENV['GGML_OPENBLAS']
   $MK_CPPFLAGS << " -DGGML_USE_BLAS #{`pkg-config --cflags-only-I openblas`.chomp}"
   $MK_CFLAGS   << " #{`pkg-config --cflags-only-other openblas)`.chomp}"
   $MK_LDFLAGS  << " #{`pkg-config --libs openblas`}"
-  $OBJ_GGML    << ' ggml-blas.o'
+  $OBJ_GGML    << 'ggml-blas.o'
 end
 
 if ENV['GGML_OPENBLAS64']
   $MK_CPPFLAGS << " -DGGML_USE_BLAS #{`pkg-config --cflags-only-I openblas64`.chomp}"
   $MK_CFLAGS   << " #{`pkg-config --cflags-only-other openblas64)`.chomp}"
   $MK_LDFLAGS  << " #{`pkg-config --libs openblas64`}"
-  $OBJ_GGML    << ' ggml-blas.o'
+  $OBJ_GGML    << 'ggml-blas.o'
 end
 
 if $GGML_METAL
   $MK_CPPFLAGS << ' -DGGML_USE_METAL'
   $MK_LDFLAGS  << ' -framework Foundation -framework Metal -framework MetalKit'
-  $OBJ_GGML    << ' ggml-metal.o'
+  $OBJ_GGML    << 'ggml-metal.o'
 
   if ENV['GGML_METAL_NDEBUG']
     $MK_CPPFLAGS << ' -DGGML_METAL_NDEBUG'
@@ -181,21 +156,22 @@ if $GGML_METAL
 
   if $GGML_METAL_EMBED_LIBRARY
     $MK_CPPFLAGS << ' -DGGML_METAL_EMBED_LIBRARY'
-    $OBJ_GGML    << ' ggml-metal-embed.o'
+    $OBJ_GGML    << 'ggml-metal-embed.o'
   end
 end
 
 $OBJ_GGML <<
-  ' ggml.o' <<
-  ' ggml-alloc.o' <<
-  ' ggml-backend.o' <<
-  ' ggml-quants.o' <<
-  ' ggml-aarch64.o'
+  'ggml.o' <<
+  'ggml-alloc.o' <<
+  'ggml-backend.o' <<
+  'ggml-quants.o' <<
+  'ggml-aarch64.o'
 
 $OBJ_WHISPER <<
-  ' whisper.o'
+  'whisper.o'
 
-$OBJ_ALL = "#{$OBJ_GGML} #{$OBJ_WHISPER} #{$OBJ_COMMON} #{$OBJ_SDL}"
+$objs = $OBJ_GGML + $OBJ_WHISPER + $OBJ_COMMON + $OBJ_SDL
+$objs << "ruby_whisper.o"
 
 $CPPFLAGS  = "#{$MK_CPPFLAGS} #{$CPPFLAGS}"
 $CFLAGS    = "#{$CPPFLAGS} #{$MK_CFLAGS} #{$GF_CFLAGS} #{$CFLAGS}"
@@ -204,26 +180,13 @@ $CXXFLAGS  = "#{$BASE_CXXFLAGS} #{$HOST_CXXFLAGS} #{$GF_CXXFLAGS} #{$CPPFLAGS}"
 $NVCCFLAGS = "#{$MK_NVCCFLAGS} #{$NVCCFLAGS}"
 $LDFLAGS   = "#{$MK_LDFLAGS} #{$LDFLAGS}"
 
-if $GGML_METAL_EMBED_LIBRARY
-  File.write 'depend', "$(OBJS): $(OBJS) ggml-metal-embed.o\n"
-end
-
-with_disabling_unsupported_files do
-
-  create_makefile('whisper')
-
-end
+create_makefile('whisper')
 
 File.open 'Makefile', 'a' do |file|
   file.puts 'include get-flags.mk'
 
   if $GGML_METAL
     if $GGML_METAL_EMBED_LIBRARY
-      # mkmf determines object files to compile dependent on existing *.{c,cpp,m} files
-      # but ggml-metal-embed.c doesn't exist on creating Makefile.
-      file.puts "objs := $(OBJS)"
-      file.puts "OBJS = $(objs) 'ggml-metal-embed.o'"
-
       file.puts 'include metal-embed.mk'
     end
   end
index 2c720e9814928d564cc693735bb7adad33223c01..3f528ee47901b0214a0bbfe78cfe6654efe13679 100644 (file)
@@ -41,6 +41,8 @@ static ID id_call;
 static ID id___method__;
 static ID id_to_enum;
 
+static bool is_log_callback_finalized = false;
+
 /*
  * call-seq:
  *   lang_max_id -> Integer
@@ -88,6 +90,39 @@ static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) {
   return rb_str_new2(str_full);
 }
 
+static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
+  is_log_callback_finalized = true;
+  return Qnil;
+}
+
+/*
+ * call-seq:
+ *   log_set ->(level, buffer, user_data) { ... }, user_data -> nil
+ */
+static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) {
+  VALUE old_callback = rb_iv_get(self, "@log_callback");
+  if (!NIL_P(old_callback)) {
+    rb_undefine_finalizer(old_callback);
+  }
+
+  rb_iv_set(self, "@log_callback", log_callback);
+  rb_iv_set(self, "@user_data", user_data);
+
+  VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
+  rb_define_finalizer(log_callback, finalize_log_callback);
+
+  whisper_log_set([](ggml_log_level level, const char * buffer, void * user_data) {
+    if (is_log_callback_finalized) {
+      return;
+    }
+    VALUE log_callback = rb_iv_get(mWhisper, "@log_callback");
+    VALUE udata = rb_iv_get(mWhisper, "@user_data");
+    rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata);
+  }, nullptr);
+
+  return Qnil;
+}
+
 static void ruby_whisper_free(ruby_whisper *rw) {
   if (rw->context) {
     whisper_free(rw->context);
@@ -389,6 +424,126 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
   return self;
 }
 
+/*
+ * call-seq:
+ *   model_n_vocab -> Integer
+ */
+VALUE ruby_whisper_model_n_vocab(VALUE self) {
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_vocab(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_n_audio_ctx -> Integer
+ */
+VALUE ruby_whisper_model_n_audio_ctx(VALUE self) {
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_audio_ctx(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_n_audio_state -> Integer
+ */
+VALUE ruby_whisper_model_n_audio_state(VALUE self) {
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_audio_state(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_n_audio_head -> Integer
+ */
+VALUE ruby_whisper_model_n_audio_head(VALUE self) {
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_audio_head(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_n_audio_layer -> Integer
+ */
+VALUE ruby_whisper_model_n_audio_layer(VALUE self) {
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_audio_layer(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_n_text_ctx -> Integer
+ */
+VALUE ruby_whisper_model_n_text_ctx(VALUE self) {
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_text_ctx(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_n_text_state -> Integer
+ */
+VALUE ruby_whisper_model_n_text_state(VALUE self) {
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_text_state(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_n_text_head -> Integer
+ */
+VALUE ruby_whisper_model_n_text_head(VALUE self) {
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_text_head(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_n_text_layer -> Integer
+ */
+VALUE ruby_whisper_model_n_text_layer(VALUE self) {
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_text_layer(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_n_mels -> Integer
+ */
+VALUE ruby_whisper_model_n_mels(VALUE self) {
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_mels(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_ftype -> Integer
+ */
+VALUE ruby_whisper_model_ftype(VALUE self) {
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return INT2NUM(whisper_model_ftype(rw->context));
+}
+
+/*
+ * call-seq:
+ *   model_type -> String
+ */
+VALUE ruby_whisper_model_type(VALUE self) {
+  ruby_whisper *rw;
+  Data_Get_Struct(self, ruby_whisper, rw);
+  return rb_str_new2(whisper_model_type_readable(rw->context));
+}
+
 /*
  * Number of segments.
  *
@@ -1015,7 +1170,12 @@ typedef struct {
   int index;
 } ruby_whisper_segment;
 
+typedef struct {
+  VALUE context;
+} ruby_whisper_model;
+
 VALUE cSegment;
+VALUE cModel;
 
 static void rb_whisper_segment_mark(ruby_whisper_segment *rws) {
   rb_gc_mark(rws->context);
@@ -1188,6 +1348,176 @@ static VALUE ruby_whisper_segment_get_text(VALUE self) {
   return rb_str_new2(text);
 }
 
+static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
+  rb_gc_mark(rwm->context);
+}
+
+static VALUE ruby_whisper_model_allocate(VALUE klass) {
+  ruby_whisper_model *rwm;
+  rwm = ALLOC(ruby_whisper_model);
+  return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
+}
+
+static VALUE rb_whisper_model_initialize(VALUE context) {
+  ruby_whisper_model *rwm;
+  const VALUE model = ruby_whisper_model_allocate(cModel);
+  Data_Get_Struct(model, ruby_whisper_model, rwm);
+  rwm->context = context;
+  return model;
+};
+
+/*
+ * call-seq:
+ *   model -> Whisper::Model
+ */
+static VALUE ruby_whisper_get_model(VALUE self) {
+  return rb_whisper_model_initialize(self);
+}
+
+/*
+ * call-seq:
+ *   n_vocab -> Integer
+ */
+static VALUE ruby_whisper_c_model_n_vocab(VALUE self) {
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_vocab(rw->context));
+}
+
+/*
+ * call-seq:
+ *   n_audio_ctx -> Integer
+ */
+static VALUE ruby_whisper_c_model_n_audio_ctx(VALUE self) {
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_audio_ctx(rw->context));
+}
+
+/*
+ * call-seq:
+ *   n_audio_state -> Integer
+ */
+static VALUE ruby_whisper_c_model_n_audio_state(VALUE self) {
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_audio_state(rw->context));
+}
+
+/*
+ * call-seq:
+ *   n_audio_head -> Integer
+ */
+static VALUE ruby_whisper_c_model_n_audio_head(VALUE self) {
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_audio_head(rw->context));
+}
+
+/*
+ * call-seq:
+ *   n_audio_layer -> Integer
+ */
+static VALUE ruby_whisper_c_model_n_audio_layer(VALUE self) {
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_audio_layer(rw->context));
+}
+
+/*
+ * call-seq:
+ *   n_text_ctx -> Integer
+ */
+static VALUE ruby_whisper_c_model_n_text_ctx(VALUE self) {
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_text_ctx(rw->context));
+}
+
+/*
+ * call-seq:
+ *   n_text_state -> Integer
+ */
+static VALUE ruby_whisper_c_model_n_text_state(VALUE self) {
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_text_state(rw->context));
+}
+
+/*
+ * call-seq:
+ *   n_text_head -> Integer
+ */
+static VALUE ruby_whisper_c_model_n_text_head(VALUE self) {
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_text_head(rw->context));
+}
+
+/*
+ * call-seq:
+ *   n_text_layer -> Integer
+ */
+static VALUE ruby_whisper_c_model_n_text_layer(VALUE self) {
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_text_layer(rw->context));
+}
+
+/*
+ * call-seq:
+ *   n_mels -> Integer
+ */
+static VALUE ruby_whisper_c_model_n_mels(VALUE self) {
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_n_mels(rw->context));
+}
+
+/*
+ * call-seq:
+ *   ftype -> Integer
+ */
+static VALUE ruby_whisper_c_model_ftype(VALUE self) {
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return INT2NUM(whisper_model_ftype(rw->context));
+}
+
+/*
+ * call-seq:
+ *   type -> String
+ */
+static VALUE ruby_whisper_c_model_type(VALUE self) {
+  ruby_whisper_model *rwm;
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
+  ruby_whisper *rw;
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
+  return rb_str_new2(whisper_model_type_readable(rw->context));
+}
+
 void Init_whisper() {
   id_to_s = rb_intern("to_s");
   id_call = rb_intern("call");
@@ -1198,15 +1528,36 @@ void Init_whisper() {
   cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
   cParams  = rb_define_class_under(mWhisper, "Params", rb_cObject);
 
+  rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
+  rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO));
+  rb_define_const(mWhisper, "LOG_LEVEL_WARN", INT2NUM(GGML_LOG_LEVEL_WARN));
+  rb_define_const(mWhisper, "LOG_LEVEL_ERROR", INT2NUM(GGML_LOG_LEVEL_ERROR));
+  rb_define_const(mWhisper, "LOG_LEVEL_DEBUG", INT2NUM(GGML_LOG_LEVEL_DEBUG));
+  rb_define_const(mWhisper, "LOG_LEVEL_CONT", INT2NUM(GGML_LOG_LEVEL_CONT));
+
   rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0);
   rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1);
   rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
   rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1);
+  rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2);
+  rb_define_singleton_method(mWhisper, "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1);
 
   rb_define_alloc_func(cContext, ruby_whisper_allocate);
   rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
 
   rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1);
+  rb_define_method(cContext, "model_n_vocab", ruby_whisper_model_n_vocab, 0);
+  rb_define_method(cContext, "model_n_audio_ctx", ruby_whisper_model_n_audio_ctx, 0);
+  rb_define_method(cContext, "model_n_audio_state", ruby_whisper_model_n_audio_state, 0);
+  rb_define_method(cContext, "model_n_audio_head", ruby_whisper_model_n_audio_head, 0);
+  rb_define_method(cContext, "model_n_audio_layer", ruby_whisper_model_n_audio_layer, 0);
+  rb_define_method(cContext, "model_n_text_ctx", ruby_whisper_model_n_text_ctx, 0);
+  rb_define_method(cContext, "model_n_text_state", ruby_whisper_model_n_text_state, 0);
+  rb_define_method(cContext, "model_n_text_head", ruby_whisper_model_n_text_head, 0);
+  rb_define_method(cContext, "model_n_text_layer", ruby_whisper_model_n_text_layer, 0);
+  rb_define_method(cContext, "model_n_mels", ruby_whisper_model_n_mels, 0);
+  rb_define_method(cContext, "model_ftype", ruby_whisper_model_ftype, 0);
+  rb_define_method(cContext, "model_type", ruby_whisper_model_type, 0);
   rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0);
   rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0);
   rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1);
@@ -1284,6 +1635,22 @@ void Init_whisper() {
   rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
   rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
   rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0);
+
+  cModel = rb_define_class_under(mWhisper, "Model", rb_cObject);
+  rb_define_alloc_func(cModel, ruby_whisper_model_allocate);
+  rb_define_method(cContext, "model", ruby_whisper_get_model, 0);
+  rb_define_method(cModel, "n_vocab", ruby_whisper_c_model_n_vocab, 0);
+  rb_define_method(cModel, "n_audio_ctx", ruby_whisper_c_model_n_audio_ctx, 0);
+  rb_define_method(cModel, "n_audio_state", ruby_whisper_c_model_n_audio_state, 0);
+  rb_define_method(cModel, "n_audio_head", ruby_whisper_c_model_n_audio_head, 0);
+  rb_define_method(cModel, "n_audio_layer", ruby_whisper_c_model_n_audio_layer, 0);
+  rb_define_method(cModel, "n_text_ctx", ruby_whisper_c_model_n_text_ctx, 0);
+  rb_define_method(cModel, "n_text_state", ruby_whisper_c_model_n_text_state, 0);
+  rb_define_method(cModel, "n_text_head", ruby_whisper_c_model_n_text_head, 0);
+  rb_define_method(cModel, "n_text_layer", ruby_whisper_c_model_n_text_layer, 0);
+  rb_define_method(cModel, "n_mels", ruby_whisper_c_model_n_mels, 0);
+  rb_define_method(cModel, "ftype", ruby_whisper_c_model_ftype, 0);
+  rb_define_method(cModel, "type", ruby_whisper_c_model_type, 0);
 }
 #ifdef __cplusplus
 }
diff --git a/bindings/ruby/tests/helper.rb b/bindings/ruby/tests/helper.rb
new file mode 100644 (file)
index 0000000..4172ebc
--- /dev/null
@@ -0,0 +1,7 @@
+require "test/unit"
+require "whisper"
+
+class TestBase < Test::Unit::TestCase
+  MODEL = File.join(__dir__, "..", "..", "..", "models", "ggml-base.en.bin")
+  AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
+end
diff --git a/bindings/ruby/tests/test_model.rb b/bindings/ruby/tests/test_model.rb
new file mode 100644 (file)
index 0000000..2310522
--- /dev/null
@@ -0,0 +1,44 @@
+require_relative "helper"
+
+class TestModel < TestBase
+  def test_model
+    whisper = Whisper::Context.new(MODEL)
+    assert_instance_of Whisper::Model, whisper.model
+  end
+
+  def test_attributes
+    whisper = Whisper::Context.new(MODEL)
+    model = whisper.model
+
+    assert_equal 51864, model.n_vocab
+    assert_equal 1500, model.n_audio_ctx
+    assert_equal 512, model.n_audio_state
+    assert_equal 8, model.n_audio_head
+    assert_equal 6, model.n_audio_layer
+    assert_equal 448, model.n_text_ctx
+    assert_equal 512, model.n_text_state
+    assert_equal 8, model.n_text_head
+    assert_equal 6, model.n_text_layer
+    assert_equal 80, model.n_mels
+    assert_equal 1, model.ftype
+    assert_equal "base", model.type
+  end
+
+  def test_gc
+    model = Whisper::Context.new(MODEL).model
+    GC.start
+
+    assert_equal 51864, model.n_vocab
+    assert_equal 1500, model.n_audio_ctx
+    assert_equal 512, model.n_audio_state
+    assert_equal 8, model.n_audio_head
+    assert_equal 6, model.n_audio_layer
+    assert_equal 448, model.n_text_ctx
+    assert_equal 512, model.n_text_state
+    assert_equal 8, model.n_text_head
+    assert_equal 6, model.n_text_layer
+    assert_equal 80, model.n_mels
+    assert_equal 1, model.ftype
+    assert_equal "base", model.type
+  end
+end
index f51eab575d6e0a5d1a21057ec2cc555f966a4471..9c47870ef28cce56954630e7d5ca70ab7532d873 100644 (file)
@@ -1,9 +1,9 @@
-require 'test/unit'
+require_relative "helper"
 require 'tempfile'
 require 'tmpdir'
 require 'shellwords'
 
-class TestPackage < Test::Unit::TestCase
+class TestPackage < TestBase
   def test_build
     Tempfile.create do |file|
       assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
index 6386049620a9447d4a0a1a0c1cf5ca053f089b76..bf73fd6b29b150d704935484e2f25b668064e3c9 100644 (file)
@@ -1,7 +1,6 @@
-require 'test/unit'
-require 'whisper'
+require_relative "helper"
 
-class TestParams < Test::Unit::TestCase
+class TestParams < TestBase
   def setup
     @params  = Whisper::Params.new
   end
index f3ebc0e9c787b654d32ff70528d9cdb4385c80d9..8129ae5db42faf1a01eaa129b9da896890a51287 100644 (file)
@@ -1,18 +1,14 @@
-require "test/unit"
-require "whisper"
-
-class TestSegment < Test::Unit::TestCase
-  TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
+require_relative "helper"
 
+class TestSegment < TestBase
   class << self
     attr_reader :whisper
 
     def startup
-      @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
+      @whisper = Whisper::Context.new(TestBase::MODEL)
       params = Whisper::Params.new
       params.print_timestamps = false
-      jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
-      @whisper.transcribe(jfk, params)
+      @whisper.transcribe(TestBase::AUDIO, params)
     end
   end
 
@@ -60,7 +56,7 @@ class TestSegment < Test::Unit::TestCase
       end
       index += 1
     end
-    whisper.transcribe(File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav'), params)
+    whisper.transcribe(AUDIO, params)
     assert_equal 0, seg.start_time
     assert_match /ask not what your country can do for you, ask what you can do for your country/, seg.text
   end
@@ -76,7 +72,7 @@ class TestSegment < Test::Unit::TestCase
       assert_same seg, segment
       return
     end
-    whisper.transcribe(File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav'), params)
+    whisper.transcribe(AUDIO, params)
   end
 
   private
index 5ebb8151c659062c96cdf5c27e2375a92f775dac..e37e24c64c4433168b19a82188e0193e1bfdc595 100644 (file)
@@ -1,20 +1,20 @@
-require 'whisper'
-require 'test/unit'
+require_relative "helper"
+require "stringio"
 
-class TestWhisper < Test::Unit::TestCase
-  TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
+# Exists to detect memory-related bug
+Whisper.log_set ->(level, buffer, user_data) {}, nil
 
+class TestWhisper < TestBase
   def setup
     @params  = Whisper::Params.new
   end
 
   def test_whisper
-    @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
+    @whisper = Whisper::Context.new(MODEL)
     params  = Whisper::Params.new
     params.print_timestamps = false
 
-    jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
-    @whisper.transcribe(jfk, params) {|text|
+    @whisper.transcribe(AUDIO, params) {|text|
       assert_match /ask not what your country can do for you, ask what you can do for your country/, text
     }
   end
@@ -24,11 +24,10 @@ class TestWhisper < Test::Unit::TestCase
       attr_reader :whisper
 
       def startup
-        @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
+        @whisper = Whisper::Context.new(TestBase::MODEL)
         params = Whisper::Params.new
         params.print_timestamps = false
-        jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
-        @whisper.transcribe(jfk, params)
+        @whisper.transcribe(TestBase::AUDIO, params)
       end
     end
 
@@ -96,4 +95,33 @@ class TestWhisper < Test::Unit::TestCase
       Whisper.lang_str_full(Whisper.lang_max_id + 1)
     end
   end
+
+  def test_log_set
+    user_data = Object.new
+    logs = []
+    log_callback = ->(level, buffer, udata) {
+      logs << [level, buffer, udata]
+    }
+    Whisper.log_set log_callback, user_data
+    Whisper::Context.new(MODEL)
+
+    assert logs.length > 30
+    logs.each do |log|
+      assert_equal Whisper::LOG_LEVEL_INFO, log[0]
+      assert_same user_data, log[2]
+    end
+  end
+
+  def test_log_suppress
+    stderr = $stderr
+    Whisper.log_set ->(level, buffer, user_data) {
+      # do nothing
+    }, nil
+    dev = StringIO.new("")
+    $stderr = dev
+    Whisper::Context.new(MODEL)
+    assert_empty dev.string
+  ensure
+    $stderr = stderr
+  end
 end