]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ruby : Sync whisper.cpp and model download feature (#2617)
authorKITAITI Makoto <redacted>
Mon, 9 Dec 2024 11:17:50 +0000 (20:17 +0900)
committerGitHub <redacted>
Mon, 9 Dec 2024 11:17:50 +0000 (13:17 +0200)
* Use C++17

* Add test for Pathname of model

* Make Whisper::Context#initialize accept Pathname

* Add shorthand for pre-converted models

* Update documents

* Add headings to API section in README [skip ci]

* Remove unused function

* Don't care about no longer included file

* Cosmetic fix

* Use conditional get when get model files

bindings/ruby/.gitignore
bindings/ruby/README.md
bindings/ruby/Rakefile
bindings/ruby/ext/.gitignore
bindings/ruby/ext/extconf.rb
bindings/ruby/ext/ruby_whisper.cpp
bindings/ruby/lib/whisper.rb [new file with mode: 0644]
bindings/ruby/lib/whisper/model.rb [new file with mode: 0644]
bindings/ruby/tests/jfk_reader/jfk_reader.c
bindings/ruby/tests/test_model.rb

index e04a90a9c69b97a7b52f065ef1d563722591ae2a..6e3b3be0e24f0beff3c51d25d21233a1cfaeef60 100644 (file)
@@ -1,3 +1,5 @@
 LICENSE
 pkg/
-lib/whisper.*
+lib/whisper.so
+lib/whisper.bundle
+lib/whisper.dll
index 05e19eb6e0938210cc7bc346e45f5568f1b36157..e7065bf9d70f62a6a67a49658b702b0bccdbf6f7 100644 (file)
@@ -22,7 +22,7 @@ Usage
 ```ruby
 require "whisper"
 
-whisper = Whisper::Context.new("path/to/model.bin")
+whisper = Whisper::Context.new(Whisper::Model["base"])
 
 params = Whisper::Params.new
 params.language = "en"
@@ -41,21 +41,60 @@ end
 
 ### Preparing model ###
 
-Use script to download model file(s):
+Some models are prepared up-front:
 
-```bash
-git clone https://github.com/ggerganov/whisper.cpp.git
-cd whisper.cpp
-sh ./models/download-ggml-model.sh base.en
+```ruby
+base_en = Whisper::Model["base.en"]
+whisper = Whisper::Context.new(base_en)
+```
+
+At first time you use a model, it is downloaded automatically. After that, downloaded cached file is used. To clear cache, call `#clear_cache`:
+
+```ruby
+Whisper::Model["base"].clear_cache
 ```
 
-There are some types of models. See [models][] page for details.
+You can see the list of prepared model names by `Whisper::Model.preconverted_model_names`:
+
+```ruby
+puts Whisper::Model.preconverted_model_names
+# tiny
+# tiny.en
+# tiny-q5_1
+# tiny.en-q5_1
+# tiny-q8_0
+# base
+# base.en
+# base-q5_1
+# base.en-q5_1
+# base-q8_0
+#   :
+#   :
+```
+
+You can also use local model files you prepared:
+
+```ruby
+whisper = Whisper::Context.new("path/to/your/model.bin")
+```
+
+Or, you can download model files:
+
+```ruby
+model_uri = Whisper::Model::URI.new("http://example.net/uri/of/your/model.bin")
+whisper = Whisper::Context.new(model_uri)
+```
+
+See [models][] page for details.
 
 ### Preparing audio file ###
 
 Currently, whisper.cpp accepts only 16-bit WAV files.
 
-### API ###
+API
+---
+
+### Segments ###
 
 Once `Whisper::Context#transcribe` called, you can retrieve segments by `#each_segment`:
 
@@ -107,10 +146,12 @@ whisper.transcribe("path/to/audio.wav", params)
 
 ```
 
+### Models ###
+
 You can see model information:
 
 ```ruby
-whisper = Whisper::Context.new("path/to/model.bin")
+whisper = Whisper::Context.new(Whisper::Model["base"])
 model = whisper.model
 
 model.n_vocab # => 51864
@@ -128,6 +169,8 @@ model.type # => "base"
 
 ```
 
+### Logging ###
+
 You can set log callback:
 
 ```ruby
@@ -160,6 +203,8 @@ Whisper.log_set ->(level, buffer, user_data) {
 Whisper::Context.new(MODEL)
 ```
 
+### Low-level API to transcribe ###
+
 You can also call `Whisper::Context#full` and `#full_parallel` with a Ruby array as samples. Although `#transcribe` with audio file path is recommended because it extracts PCM samples in C++ and is fast, `#full` and `#full_parallel` give you flexibility.
 
 ```ruby
@@ -169,7 +214,7 @@ require "wavefile"
 reader = WaveFile::Reader.new("path/to/audio.wav", WaveFile::Format.new(:mono, :float, 16000))
 samples = reader.enum_for(:each_buffer).map(&:samples).flatten
 
-whisper = Whisper::Context.new("path/to/model.bin")
+whisper = Whisper::Context.new(Whisper::Model["base"])
 whisper.full(Whisper::Params.new, samples)
 whisper.each_segment do |segment|
   puts segment.text
index 5f6303ba055df8830e5d0a65910dc5045d547341..f640dce94f2400b242b6b2fd26c17ec0627b3a7c 100644 (file)
@@ -18,19 +18,9 @@ EXTSOURCES.each do |src|
 end
 
 CLEAN.include SOURCES
-CLEAN.include FileList[
-                "ext/*.o",
-                "ext/*.metal",
-                "ext/whisper.{so,bundle,dll}",
-                "ext/depend"
-              ]
+CLEAN.include FileList["ext/*.o", "ext/*.metal", "ext/whisper.{so,bundle,dll}"]
 
-task build: FileList[
-       "ext/Makefile",
-       "ext/ruby_whisper.h",
-       "ext/ruby_whisper.cpp",
-       "whispercpp.gemspec",
-     ]
+task build: ["ext/Makefile", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp", "whispercpp.gemspec"]
 
 directory "pkg"
 CLOBBER.include "pkg"
index 3804ab7e3e4aafa4756e0c574bd0c7e0577564ca..e96a8584c94322569a9f37e30d1d42be4f20d9e7 100644 (file)
@@ -2,7 +2,6 @@ Makefile
 whisper.so
 whisper.bundle
 whisper.dll
-depend
 scripts/get-flags.mk
 *.o
 *.c
index 6d76a7cd9ac09eb06ed7d0c82bfc7f05b1060347..59388ffe0bccc2be860725330b2ea816be43e7d4 100644 (file)
@@ -1,7 +1,7 @@
 require 'mkmf'
 
 # need to use c++ compiler flags
-$CXXFLAGS << ' -std=c++11'
+$CXXFLAGS << ' -std=c++17'
 
 $LDFLAGS << ' -lstdc++'
 
@@ -35,10 +35,10 @@ if $GGML_METAL
   $GGML_METAL_EMBED_LIBRARY = true
 end
 
-$MK_CPPFLAGS = '-Iggml/include -Iggml/src -Iinclude -Isrc -Iexamples'
+$MK_CPPFLAGS = '-Iggml/include -Iggml/src -Iggml/src/ggml-cpu -Iinclude -Isrc -Iexamples'
 $MK_CFLAGS   = '-std=c11   -fPIC'
-$MK_CXXFLAGS = '-std=c++11 -fPIC'
-$MK_NVCCFLAGS = '-std=c++11'
+$MK_CXXFLAGS = '-std=c++17 -fPIC'
+$MK_NVCCFLAGS = '-std=c++17'
 $MK_LDFLAGS = ''
 
 $OBJ_GGML = []
index bb6bae8a859e22b034e50509c8881a30533494c4..83fc53fc058611677dd980a7a22d6ce92c284a9d 100644 (file)
@@ -45,6 +45,7 @@ static ID id_to_enum;
 static ID id_length;
 static ID id_next;
 static ID id_new;
+static ID id_to_path;
 
 static bool is_log_callback_finalized = false;
 
@@ -194,7 +195,9 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) {
 
 /*
  * call-seq:
+ *   new(Whisper::Model["base.en"]) -> Whisper::Context
  *   new("path/to/model.bin") -> Whisper::Context
+ *   new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context
  */
 static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
   ruby_whisper *rw;
@@ -204,6 +207,9 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
   rb_scan_args(argc, argv, "01", &whisper_model_file_path);
   Data_Get_Struct(self, ruby_whisper, rw);
 
+  if (rb_respond_to(whisper_model_file_path, id_to_path)) {
+    whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0);
+  }
   if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
     rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
   }
@@ -1733,6 +1739,7 @@ void Init_whisper() {
   id_length = rb_intern("length");
   id_next = rb_intern("next");
   id_new = rb_intern("new");
+  id_to_path = rb_intern("to_path");
 
   mWhisper = rb_define_module("Whisper");
   cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
diff --git a/bindings/ruby/lib/whisper.rb b/bindings/ruby/lib/whisper.rb
new file mode 100644 (file)
index 0000000..4c8e01e
--- /dev/null
@@ -0,0 +1,2 @@
+require "whisper.so"
+require "whisper/model"
diff --git a/bindings/ruby/lib/whisper/model.rb b/bindings/ruby/lib/whisper/model.rb
new file mode 100644 (file)
index 0000000..be67dff
--- /dev/null
@@ -0,0 +1,159 @@
+require "whisper.so"
+require "uri"
+require "net/http"
+require "pathname"
+require "io/console/size"
+
+class Whisper::Model
+  class URI
+    def initialize(uri)
+      @uri = URI(uri)
+    end
+
+    def to_path
+      cache
+      cache_path.to_path
+    end
+
+    def clear_cache
+      path = cache_path
+      path.delete if path.exist?
+    end
+
+    private
+
+    def cache_path
+      base_cache_dir/@uri.host/@uri.path[1..]
+    end
+
+    def base_cache_dir
+      base = case RUBY_PLATFORM
+             when /mswin|mingw/
+               ENV.key?("LOCALAPPDATA") ? Pathname(ENV["LOCALAPPDATA"]) : Pathname(Dir.home)/"AppData/Local"
+             when /darwin/
+               Pathname(Dir.home)/"Library/Caches"
+             else
+               ENV.key?("XDG_CACHE_HOME") ? ENV["XDG_CACHE_HOME"] : Pathname(Dir.home)/".cache"
+             end
+      base/"whisper.cpp"
+    end
+
+    def cache
+      path = cache_path
+      headers = {}
+      headers["if-modified-since"] = path.mtime.httpdate if path.exist?
+      request @uri, headers
+      path
+    end
+
+    def request(uri, headers)
+      Net::HTTP.start uri.host, uri.port, use_ssl: uri.scheme == "https" do |http|
+        request = Net::HTTP::Get.new(uri, headers)
+        http.request request do |response|
+          case response
+          when Net::HTTPNotModified
+            # noop
+          when Net::HTTPOK
+            download response
+          when Net::HTTPRedirection
+            request URI(response["location"])
+          else
+            raise response
+          end
+        end
+      end
+    end
+
+    def download(response)
+      path = cache_path
+      path.dirname.mkpath unless path.dirname.exist?
+      downloading_path = Pathname("#{path}.downloading")
+      size = response.content_length
+      downloading_path.open "wb" do |file|
+        downloaded = 0
+        response.read_body do |chunk|
+          file << chunk
+          downloaded += chunk.bytesize
+          show_progress downloaded, size
+        end
+      end
+      downloading_path.rename path
+    end
+
+    def show_progress(current, size)
+      return unless size
+
+      unless @prev
+        @prev = Time.now
+        $stderr.puts "Downloading #{@uri}"
+      end
+
+      now = Time.now
+      return if now - @prev < 1 && current < size
+
+      progress_width = 20
+      progress = current.to_f / size
+      arrow_length = progress * progress_width
+      arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length)
+      line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
+      padding = ' ' * ($stderr.winsize[1] - line.size)
+      $stderr.print "\r#{line}#{padding}"
+      $stderr.puts if current >= size
+      @prev = now
+    end
+
+    def format_bytesize(bytesize)
+      return "0.0 B" if bytesize.zero?
+
+      units = %w[B KiB MiB GiB TiB]
+      exp = (Math.log(bytesize) / Math.log(1024)).to_i
+      format("%.1f %s", bytesize.to_f / 1024 ** exp, units[exp])
+    end
+  end
+
+  @names = {}
+  %w[
+    tiny
+    tiny.en
+    tiny-q5_1
+    tiny.en-q5_1
+    tiny-q8_0
+    base
+    base.en
+    base-q5_1
+    base.en-q5_1
+    base-q8_0
+    small
+    small.en
+    small.en-tdrz
+    small-q5_1
+    small.en-q5_1
+    small-q8_0
+    medium
+    medium.en
+    medium-q5_0
+    medium.en-q5_0
+    medium-q8_0
+    large-v1
+    large-v2
+    large-v2-q5_0
+    large-v2-8_0
+    large-v3
+    large-v3-q5_0
+    large-v3-turbo
+    large-v3-turbo-q5_0
+    large-v3-turbo-q8_0
+  ].each do |name|
+    @names[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin")
+  end
+
+  class << self
+    def [](name)
+      @names[name]
+    end
+
+    def preconverted_model_names
+      @names.keys
+    end
+  end
+end
index a0688374d0603e8e3dd5fb289d4b077910a1569a..6657176e767578de72963b917342998156a8e438 100644 (file)
@@ -60,49 +60,9 @@ static const rb_memory_view_entry_t jfk_reader_view_entry = {
   jfk_reader_memory_view_available_p
 };
 
-static VALUE
-read_jfk(int argc, VALUE *argv, VALUE obj)
-{
-  const char *audio_path_str = StringValueCStr(argv[0]);
-  const int n_samples = 176000;
-
-  short samples[n_samples];
-  FILE *file = fopen(audio_path_str, "rb");
-
-  fseek(file, 78, SEEK_SET);
-  fread(samples, sizeof(short), n_samples, file);
-  fclose(file);
-
-  VALUE rb_samples = rb_ary_new2(n_samples);
-  for (int i = 0; i < n_samples; i++) {
-    rb_ary_push(rb_samples, INT2FIX(samples[i]));
-  }
-
-  VALUE rb_data = rb_ary_new2(n_samples);
-  for (int i = 0; i < n_samples; i++) {
-    rb_ary_push(rb_data, DBL2NUM(samples[i]/32768.0));
-  }
-
-  float data[n_samples];
-  for (int i = 0; i < n_samples; i++) {
-    data[i] = samples[i]/32768.0;
-  }
-  void *c_data = (void *)data;
-  VALUE rb_void = rb_enc_str_new((const char *)c_data, sizeof(data), rb_ascii8bit_encoding());
-
-  VALUE rb_result = rb_ary_new3(3, rb_samples, rb_data, rb_void);
-  return rb_result;
-}
-
 void Init_jfk_reader(void)
 {
   VALUE cJFKReader = rb_define_class("JFKReader", rb_cObject);
   rb_memory_view_register(cJFKReader, &jfk_reader_view_entry);
   rb_define_method(cJFKReader, "initialize", jfk_reader_initialize, 1);
-
-
-  rb_define_global_function("read_jfk", read_jfk, -1);
-
-
-
 }
index 2310522a644032a3bd2f52290751c9e5a4ccee59..598dbde9f1334353d9719128ec377aea6b4377cf 100644 (file)
@@ -1,4 +1,5 @@
 require_relative "helper"
+require "pathname"
 
 class TestModel < TestBase
   def test_model
@@ -41,4 +42,23 @@ class TestModel < TestBase
     assert_equal 1, model.ftype
     assert_equal "base", model.type
   end
+
+  def test_pathname
+    path = Pathname(MODEL)
+    whisper = Whisper::Context.new(path)
+    model = whisper.model
+
+    assert_equal 51864, model.n_vocab
+    assert_equal 1500, model.n_audio_ctx
+    assert_equal 512, model.n_audio_state
+    assert_equal 8, model.n_audio_head
+    assert_equal 6, model.n_audio_layer
+    assert_equal 448, model.n_text_ctx
+    assert_equal 512, model.n_text_state
+    assert_equal 8, model.n_text_head
+    assert_equal 6, model.n_text_layer
+    assert_equal 80, model.n_mels
+    assert_equal 1, model.ftype
+    assert_equal "base", model.type
+  end
 end