git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ruby : Fix of C++ header guard name, model URI support, type signature and more ...
author: KITAITI Makoto <redacted>
Mon, 30 Dec 2024 12:26:35 +0000 (21:26 +0900)
committer: GitHub <redacted>
Mon, 30 Dec 2024 12:26:35 +0000 (14:26 +0200)
* Add test to make Whisper::Context.new accept URI string

* Add test to make Whisper::Context.new accept URI

* Make Whisper::Context.new accept URI string and URI

* Update README

Revert "Fix argument of rb_undefine_finalizer"

* Fix typos

* Add type signature file

* Assign literal to const variable

* Load Whisper::Model::URI from Init_whisper

* Simplify .gitignore

* Don't load whisper.so from whisper/model/uri.rb

* Use each_with_object instead of each

* Add Development section to README

* Rename header guard to conform to C++ naming convention

bindings/ruby/.gitignore
bindings/ruby/README.md
bindings/ruby/ext/ruby_whisper.cpp
bindings/ruby/ext/ruby_whisper.h
bindings/ruby/lib/whisper.rb [deleted file]
bindings/ruby/lib/whisper/model/uri.rb
bindings/ruby/sig/whisper.rbs [new file with mode: 0644]
bindings/ruby/tests/test_model.rb

index 6e3b3be0e24f0beff3c51d25d21233a1cfaeef60..e04a90a9c69b97a7b52f065ef1d563722591ae2a 100644 (file)
@@ -1,5 +1,3 @@
 LICENSE
 pkg/
-lib/whisper.so
-lib/whisper.bundle
-lib/whisper.dll
+lib/whisper.*
index 8492e4ed91b02ba05a7e8381ec9b7604aa603733..13ff1f00ad16e7d13277242857c849d3785424b3 100644 (file)
@@ -60,10 +60,10 @@ You also can use shorthand for pre-converted models:
 whisper = Whisper::Context.new("base.en")
 ```
 
-You can see the list of prepared model names by `Whisper::Model.preconverted_models.keys`:
+You can see the list of prepared model names by `Whisper::Model.pre_converted_models.keys`:
 
 ```ruby
-puts Whisper::Model.preconverted_models.keys
+puts Whisper::Model.pre_converted_models.keys
 # tiny
 # tiny.en
 # tiny-q5_1
@@ -87,8 +87,9 @@ whisper = Whisper::Context.new("path/to/your/model.bin")
 Or, you can download model files:
 
 ```ruby
-model_uri = Whisper::Model::URI.new("http://example.net/uri/of/your/model.bin")
-whisper = Whisper::Context.new(model_uri)
+whisper = Whisper::Context.new("https://example.net/uri/of/your/model.bin")
+# Or
+whisper = Whisper::Context.new(URI("https://example.net/uri/of/your/model.bin"))
 ```
 
 See [models][] page for details.
@@ -222,6 +223,17 @@ end
 
 The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
 
+Development
+-----------
+
+    % git clone https://github.com/ggerganov/whisper.cpp.git
+    % cd whisper.cpp/bindings/ruby
+    % rake test
+
+The first call of `rake test` builds the extension and downloads a model for testing. After that, you can add tests in the `tests` directory and modify `ext/ruby_whisper.cpp`.
+
+If something goes wrong during the build, running `rake clean` and building again may resolve it.
+
 License
 -------
 
index 88a4fd2c205509fcc685bfafa3e8342f5bc9d346..5979f208ec95cfdef2da3bd61cbc3e76307e25cc 100644 (file)
@@ -49,6 +49,7 @@ static ID id_length;
 static ID id_next;
 static ID id_new;
 static ID id_to_path;
+static ID id_URI;
 static ID id_pre_converted_models;
 
 static bool is_log_callback_finalized = false;
@@ -283,6 +284,17 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
   if (!NIL_P(pre_converted_model)) {
     whisper_model_file_path = pre_converted_model;
   }
+  if (TYPE(whisper_model_file_path) == T_STRING) {
+    const char * whisper_model_file_path_str = StringValueCStr(whisper_model_file_path);
+    if (strncmp("http://", whisper_model_file_path_str, 7) == 0 || strncmp("https://", whisper_model_file_path_str, 8) == 0) {
+      VALUE uri_class = rb_const_get(cModel, id_URI);
+      whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
+    }
+  }
+  if (rb_obj_is_kind_of(whisper_model_file_path, rb_path2class("URI::HTTP"))) {
+    VALUE uri_class = rb_const_get(cModel, id_URI);
+    whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
+  }
   if (rb_respond_to(whisper_model_file_path, id_to_path)) {
     whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0);
   }
@@ -837,7 +849,7 @@ static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) {
 
 /*
  * call-seq:
- *   full_get_segment_no_speech_prob -> Float
+ *   full_get_segment_no_speech_prob(segment_index) -> Float
  */
 static VALUE ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment) {
   ruby_whisper *rw;
@@ -1755,7 +1767,7 @@ static VALUE ruby_whisper_c_model_type(VALUE self) {
 
 static VALUE ruby_whisper_error_initialize(VALUE self, VALUE code) {
   const int c_code = NUM2INT(code);
-  char *raw_message;
+  const char *raw_message;
   switch (c_code) {
   case -2:
     raw_message = "failed to compute log mel spectrogram";
@@ -1802,6 +1814,7 @@ void Init_whisper() {
   id_next = rb_intern("next");
   id_new = rb_intern("new");
   id_to_path = rb_intern("to_path");
+  id_URI = rb_intern("URI");
   id_pre_converted_models = rb_intern("pre_converted_models");
 
   mWhisper = rb_define_module("Whisper");
@@ -1941,6 +1954,8 @@ void Init_whisper() {
   rb_define_method(cModel, "n_mels", ruby_whisper_c_model_n_mels, 0);
   rb_define_method(cModel, "ftype", ruby_whisper_c_model_ftype, 0);
   rb_define_method(cModel, "type", ruby_whisper_c_model_type, 0);
+
+  rb_require("whisper/model/uri");
 }
 #ifdef __cplusplus
 }
index a6771038e6f59b6feee692de83411ada4b2ee28e..21e36c491cf4128c2fa54fe062ecff8b4a8115ce 100644 (file)
@@ -1,5 +1,5 @@
-#ifndef __RUBY_WHISPER_H
-#define __RUBY_WHISPER_H
+#ifndef RUBY_WHISPER_H
+#define RUBY_WHISPER_H
 
 #include "whisper.h"
 
diff --git a/bindings/ruby/lib/whisper.rb b/bindings/ruby/lib/whisper.rb
deleted file mode 100644 (file)
index 3a0b844..0000000
+++ /dev/null
@@ -1,2 +0,0 @@
-require "whisper.so"
-require "whisper/model/uri"
index fe5ed56b3fbde9f8d8273caee5745081e7808c21..b43d90dd48621c1bce170c34e195300f1c0470fa 100644 (file)
-require "whisper.so"
 require "uri"
 require "net/http"
 require "time"
 require "pathname"
 require "io/console/size"
 
-class Whisper::Model
-  class URI
-    def initialize(uri)
-      @uri = URI(uri)
-    end
+module Whisper
+  class Model
+    class URI
+      def initialize(uri)
+        @uri = URI(uri)
+      end
 
-    def to_path
-      cache
-      cache_path.to_path
-    end
+      def to_path
+        cache
+        cache_path.to_path
+      end
 
-    def clear_cache
-      path = cache_path
-      path.delete if path.exist?
-    end
+      def clear_cache
+        path = cache_path
+        path.delete if path.exist?
+      end
 
-    private
+      private
 
-    def cache_path
-      base_cache_dir/@uri.host/@uri.path[1..]
-    end
+      def cache_path
+        base_cache_dir/@uri.host/@uri.path[1..]
+      end
 
-    def base_cache_dir
-      base = case RUBY_PLATFORM
-             when /mswin|mingw/
-               ENV.key?("LOCALAPPDATA") ? Pathname(ENV["LOCALAPPDATA"]) : Pathname(Dir.home)/"AppData/Local"
-             when /darwin/
-               Pathname(Dir.home)/"Library/Caches"
-             else
-               ENV.key?("XDG_CACHE_HOME") ? ENV["XDG_CACHE_HOME"] : Pathname(Dir.home)/".cache"
-             end
-      base/"whisper.cpp"
-    end
+      def base_cache_dir
+        base = case RUBY_PLATFORM
+               when /mswin|mingw/
+                 ENV.key?("LOCALAPPDATA") ? Pathname(ENV["LOCALAPPDATA"]) : Pathname(Dir.home)/"AppData/Local"
+               when /darwin/
+                 Pathname(Dir.home)/"Library/Caches"
+               else
+                 ENV.key?("XDG_CACHE_HOME") ? ENV["XDG_CACHE_HOME"] : Pathname(Dir.home)/".cache"
+               end
+        base/"whisper.cpp"
+      end
 
-    def cache
-      path = cache_path
-      headers = {}
-      headers["if-modified-since"] = path.mtime.httpdate if path.exist?
-      request @uri, headers
-      path
-    end
+      def cache
+        path = cache_path
+        headers = {}
+        headers["if-modified-since"] = path.mtime.httpdate if path.exist?
+        request @uri, headers
+        path
+      end
 
-    def request(uri, headers)
-      Net::HTTP.start uri.host, uri.port, use_ssl: uri.scheme == "https" do |http|
-        request = Net::HTTP::Get.new(uri, headers)
-        http.request request do |response|
-          case response
-          when Net::HTTPNotModified
+      def request(uri, headers)
+        Net::HTTP.start uri.host, uri.port, use_ssl: uri.scheme == "https" do |http|
+          request = Net::HTTP::Get.new(uri, headers)
+          http.request request do |response|
+            case response
+            when Net::HTTPNotModified
             # noop
-          when Net::HTTPOK
-            download response
-          when Net::HTTPRedirection
-            request URI(response["location"]), headers
-          else
-            return if headers.key?("if-modified-since") # Use cache file
-
-            raise "#{response.code} #{response.message}\n#{response.body}"
+            when Net::HTTPOK
+              download response
+            when Net::HTTPRedirection
+              request URI(response["location"]), headers
+            else
+              return if headers.key?("if-modified-since") # Use cache file
+
+              raise "#{response.code} #{response.message}\n#{response.body}"
+            end
           end
         end
       end
-    end
 
-    def download(response)
-      path = cache_path
-      path.dirname.mkpath unless path.dirname.exist?
-      downloading_path = Pathname("#{path}.downloading")
-      size = response.content_length
-      downloading_path.open "wb" do |file|
-        downloaded = 0
-        response.read_body do |chunk|
-          file << chunk
-          downloaded += chunk.bytesize
-          show_progress downloaded, size
+      def download(response)
+        path = cache_path
+        path.dirname.mkpath unless path.dirname.exist?
+        downloading_path = Pathname("#{path}.downloading")
+        size = response.content_length
+        downloading_path.open "wb" do |file|
+          downloaded = 0
+          response.read_body do |chunk|
+            file << chunk
+            downloaded += chunk.bytesize
+            show_progress downloaded, size
+          end
+          $stderr.puts
         end
-        $stderr.puts
+        downloading_path.rename path
       end
-      downloading_path.rename path
-    end
 
-    def show_progress(current, size)
-      progress_rate_available = size && $stderr.tty?
+      def show_progress(current, size)
+        progress_rate_available = size && $stderr.tty?
 
-      unless @prev
-        @prev = Time.now
-        $stderr.puts "Downloading #{@uri} to #{cache_path}"
-      end
+        unless @prev
+          @prev = Time.now
+          $stderr.puts "Downloading #{@uri} to #{cache_path}"
+        end
 
-      now = Time.now
+        now = Time.now
 
-      if progress_rate_available
-        return if now - @prev < 1 && current < size
+        if progress_rate_available
+          return if now - @prev < 1 && current < size
 
-        progress_width = 20
-        progress = current.to_f / size
-        arrow_length = progress * progress_width
-        arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length)
-        line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
-        padding = ' ' * ($stderr.winsize[1] - line.size)
-        $stderr.print "\r#{line}#{padding}"
-      else
-        return if now - @prev < 1
+          progress_width = 20
+          progress = current.to_f / size
+          arrow_length = progress * progress_width
+          arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length)
+          line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
+          padding = ' ' * ($stderr.winsize[1] - line.size)
+          $stderr.print "\r#{line}#{padding}"
+        else
+          return if now - @prev < 1
 
-        $stderr.print "."
+          $stderr.print "."
+        end
+        @prev = now
       end
-      @prev = now
-    end
 
-    def format_bytesize(bytesize)
-      return "0.0 B" if bytesize.zero?
+      def format_bytesize(bytesize)
+        return "0.0 B" if bytesize.zero?
 
-      units = %w[B KiB MiB GiB TiB]
-      exp = (Math.log(bytesize) / Math.log(1024)).to_i
-      format("%.1f %s", bytesize.to_f / 1024 ** exp, units[exp])
+        units = %w[B KiB MiB GiB TiB]
+        exp = (Math.log(bytesize) / Math.log(1024)).to_i
+        format("%.1f %s", bytesize.to_f / 1024 ** exp, units[exp])
+      end
     end
-  end
 
-  @pre_converted_models = {}
-  %w[
-    tiny
-    tiny.en
-    tiny-q5_1
-    tiny.en-q5_1
-    tiny-q8_0
-    base
-    base.en
-    base-q5_1
-    base.en-q5_1
-    base-q8_0
-    small
-    small.en
-    small.en-tdrz
-    small-q5_1
-    small.en-q5_1
-    small-q8_0
-    medium
-    medium.en
-    medium-q5_0
-    medium.en-q5_0
-    medium-q8_0
-    large-v1
-    large-v2
-    large-v2-q5_0
-    large-v2-q8_0
-    large-v3
-    large-v3-q5_0
-    large-v3-turbo
-    large-v3-turbo-q5_0
-    large-v3-turbo-q8_0
-  ].each do |name|
-    @pre_converted_models[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin")
-  end
-
-  class << self
-    attr_reader :pre_converted_models
+    @pre_converted_models = %w[
+      tiny
+      tiny.en
+      tiny-q5_1
+      tiny.en-q5_1
+      tiny-q8_0
+      base
+      base.en
+      base-q5_1
+      base.en-q5_1
+      base-q8_0
+      small
+      small.en
+      small.en-tdrz
+      small-q5_1
+      small.en-q5_1
+      small-q8_0
+      medium
+      medium.en
+      medium-q5_0
+      medium.en-q5_0
+      medium-q8_0
+      large-v1
+      large-v2
+      large-v2-q5_0
+      large-v2-q8_0
+      large-v3
+      large-v3-q5_0
+      large-v3-turbo
+      large-v3-turbo-q5_0
+      large-v3-turbo-q8_0
+    ].each_with_object({}) {|name, models|
+      models[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin")
+    }
+
+    class << self
+      attr_reader :pre_converted_models
+    end
   end
 end
diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs
new file mode 100644 (file)
index 0000000..aff2ae7
--- /dev/null
@@ -0,0 +1,153 @@
+module Whisper
+  interface _Samples
+    def length: () -> Integer
+    def each: { (Float) -> void } -> void
+  end
+
+  type log_callback = ^(Integer level, String message, Object user_data) -> void
+  type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void
+  type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void
+  type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish
+
+  LOG_LEVEL_NONE: Integer
+  LOG_LEVEL_INFO: Integer
+  LOG_LEVEL_WARN: Integer
+  LOG_LEVEL_ERROR: Integer
+  LOG_LEVEL_DEBUG: Integer
+  LOG_LEVEL_CONT: Integer
+
+  def self.lang_max_id: () -> Integer
+  def self.lang_id: (string name) -> Integer
+  def self.lang_str: (Integer id) -> String
+  def self.lang_str_full: (Integer id) -> String
+  def self.log_set=: (log_callback) -> log_callback
+  def self.finalize_log_callback: (void) -> void # Second argument of ObjectSpace.define_finalizer
+
+  class Context
+    def initialize: (string | _ToPath | ::URI::HTTP ) -> void
+    def transcribe: (string, Params) -> void
+                  | (string, Params) { (String) -> void } -> void
+    def model_n_vocab: () -> Integer
+    def model_n_audio_ctx: () -> Integer
+    def model_n_audio_state: () -> Integer
+    def model_n_text_head: () -> Integer
+    def model_n_text_layer: () -> Integer
+    def model_n_mels: () -> Integer
+    def model_ftype: () -> Integer
+    def model_type: () -> String
+    def full_n_segments: () -> Integer
+    def full_lang_id: () -> Integer
+    def full_get_segment_t0: (Integer) -> Integer
+    def full_get_segment_t1: (Integer) -> Integer
+    def full_get_segment_speaker_turn_next: (Integer) -> (true | false)
+    def full_get_segment_text: (Integer) -> String
+    def full_get_segment_no_speech_prob: (Integer) -> Float
+    def full: (Params, Array[Float], ?Integer) -> void
+            | (Params, _Samples, ?Integer) -> void
+    def full_parallel: (Params, Array[Float], ?Integer) -> void
+                     | (Params, _Samples, ?Integer) -> void
+                     | (Params, _Samples, ?Integer?, Integer) -> void
+    def each_segment: { (Segment) -> void } -> void
+                    | () -> Enumerator[Segment]
+    def model: () -> Model
+  end
+
+  class Params
+    def initialize: () -> void
+    def language=: (String) -> String # TODO: Enumerate lang names
+    def language: () -> String
+    def translate=: (boolish) -> boolish
+    def translate: () -> (true | false)
+    def no_context=: (boolish) -> boolish
+    def no_context: () -> (true | false)
+    def single_segment=: (boolish) -> boolish
+    def single_segment: () -> (true | false)
+    def print_special=: (boolish) -> boolish
+    def print_special: () -> (true | false)
+    def print_progress=: (boolish) -> boolish
+    def print_progress: () -> (true | false)
+    def print_realtime=: (boolish) -> boolish
+    def print_realtime: () -> (true | false)
+    def print_timestamps=: (boolish) -> boolish
+    def print_timestamps: () -> (true | false)
+    def suppress_blank=: (boolish) -> boolish
+    def suppress_blank: () -> (true | false)
+    def suppress_nst=: (boolish) -> boolish
+    def suppress_nst: () -> (true | false)
+    def token_timestamps=: (boolish) -> boolish
+    def token_timestamps: () -> (true | false)
+    def split_on_word=: (boolish) -> boolish
+    def split_on_word: () -> (true | false)
+    def initial_prompt=: (_ToS) -> _ToS
+    def initial_prompt: () -> String
+    def diarize=: (boolish) -> boolish
+    def diarize: () -> (true | false)
+    def offset=: (Integer) -> Integer
+    def offset: () -> Integer
+    def duration=: (Integer) -> Integer
+    def duration: () -> Integer
+    def max_text_tokens=: (Integer) -> Integer
+    def max_text_tokens: () -> Integer
+    def temperature=: (Float) -> Float
+    def temperature: () -> Float
+    def max_initial_ts=: (Float) -> Float
+    def max_initial_ts: () -> Float
+    def length_penalty=: (Float) -> Float
+    def length_penalty: () -> Float
+    def temperature_inc=: (Float) -> Float
+    def temperature_inc: () -> Float
+    def entropy_thold=: (Float) -> Float
+    def entropy_thold: () -> Float
+    def logprob_thold=: (Float) -> Float
+    def logprob_thold: () -> Float
+    def no_speech_thold=: (Float) -> Float
+    def no_speech_thold: () -> Float
+    def new_segment_callback=: (new_segment_callback) -> new_segment_callback
+    def new_segment_callback_user_data=: (Object) -> Object
+    def progress_callback=: (progress_callback) -> progress_callback
+    def progress_callback_user_data=: (Object) -> Object
+    def abort_callback=: (abort_callback) -> abort_callback
+    def abort_callback_user_data=: (Object) -> Object
+    def on_new_segment: { (Segment) -> void } -> void
+    def on_progress: { (Integer) -> void } -> void
+    def abort_on: { (Object) -> boolish } -> void
+  end
+
+  class Model
+    def self.pre_converted_models: () -> Hash[String, Model::URI]
+    def initialize: () -> void
+    def n_vocab: () -> Integer
+    def n_audio_ctx: () -> Integer
+    def n_audio_state: () -> Integer
+    def n_audio_head: () -> Integer
+    def n_audio_layer: () -> Integer
+    def n_text_ctx: () -> Integer
+    def n_text_state: () -> Integer
+    def n_text_head: () -> Integer
+    def n_text_layer: () -> Integer
+    def n_mels: () -> Integer
+    def ftype: () -> Integer
+    def type: () -> String
+
+    class URI
+      def initialize: (string | ::URI::HTTP) -> void
+      def to_path: -> String
+      def clear_cache: -> void
+    end
+  end
+
+  class Segment
+    def initialize: () -> void
+    def start_time: () -> Integer
+    def end_time: () -> Integer
+    def speaker_next_turn?: () -> (true | false)
+    def text: () -> String
+    def no_speech_prob: () -> Float
+  end
+
+  class Error < StandardError
+    attr_reader code: Integer
+
+    def initialize: (Integer) -> void
+  end
+end
index 1362fc469bf3452b6eec63db716a3ccaf0e28827..df871e0e651fe3ffdbeab391b1e595f3c6115505 100644 (file)
@@ -68,4 +68,42 @@ class TestModel < TestBase
     assert_path_exist path
     assert_equal 147964211, File.size(path)
   end
+
+  def test_uri_string
+    path = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin"
+    whisper = Whisper::Context.new(path)
+    model = whisper.model
+
+    assert_equal 51864, model.n_vocab
+    assert_equal 1500, model.n_audio_ctx
+    assert_equal 512, model.n_audio_state
+    assert_equal 8, model.n_audio_head
+    assert_equal 6, model.n_audio_layer
+    assert_equal 448, model.n_text_ctx
+    assert_equal 512, model.n_text_state
+    assert_equal 8, model.n_text_head
+    assert_equal 6, model.n_text_layer
+    assert_equal 80, model.n_mels
+    assert_equal 1, model.ftype
+    assert_equal "base", model.type
+  end
+
+  def test_uri
+    path = URI("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin")
+    whisper = Whisper::Context.new(path)
+    model = whisper.model
+
+    assert_equal 51864, model.n_vocab
+    assert_equal 1500, model.n_audio_ctx
+    assert_equal 512, model.n_audio_state
+    assert_equal 8, model.n_audio_head
+    assert_equal 6, model.n_audio_layer
+    assert_equal 448, model.n_text_ctx
+    assert_equal 512, model.n_text_state
+    assert_equal 8, model.n_text_head
+    assert_equal 6, model.n_text_layer
+    assert_equal 80, model.n_mels
+    assert_equal 1, model.ftype
+    assert_equal "base", model.type
+  end
 end