]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ruby : add `Whisper::Context::Params`, fix token memory management (#3647)
authorKITAITI Makoto <redacted>
Wed, 4 Feb 2026 11:33:09 +0000 (20:33 +0900)
committerGitHub <redacted>
Wed, 4 Feb 2026 11:33:09 +0000 (20:33 +0900)
* Don't convert to temporary VALUE

* Define Whisper::Context::Params

* Add test for Whisper::Context::Params

* Implement Whisper::Context::Params

* Add tests for Context::Params

* Fix Whisper::Token memory management

* Add test for token_timestamps

* Make Context accept Context::Params

* Make Context::Params.new accept keyword args

* Add test for Context::Params.new with keyword args

* Add signature of Context::Params

* Add example for Whisper::Token

* Fix typos

* Revert "Don't convert to temporary VALUE"

This reverts commit dee66e738491ae742fc981dc6e18ad92f1b05316.

* Hold Token#text as Ruby objectd

* Don't use pointer for ruby_whisper_context_params.params

* Use RUBY_DEFAULT_FREE instead of custom function

* Update bindings/ruby/README.md

Co-authored-by: Daniel Bevenius <redacted>
* Add document for Whisper::Context::Params

---------

Co-authored-by: Daniel Bevenius <redacted>
bindings/ruby/README.md
bindings/ruby/ext/ruby_whisper.c
bindings/ruby/ext/ruby_whisper.h
bindings/ruby/ext/ruby_whisper_context.c
bindings/ruby/ext/ruby_whisper_context_params.c [new file with mode: 0644]
bindings/ruby/ext/ruby_whisper_token.c
bindings/ruby/sig/whisper.rbs
bindings/ruby/test/test_context_params.rb [new file with mode: 0644]
bindings/ruby/test/test_token.rb

index 86774158355f4fca04e4d2c24853ac8025e2c6a8..c6280a6926a6f24e9538c39ade413fa6f2da26c2 100644 (file)
@@ -247,6 +247,58 @@ whisper.transcribe("path/to/audio.wav", params)
 
 ```
 
+### Tokens ###
+
+Each segment has tokens.
+
+To enable token timestamps, you need to set `Whisper::Params#token_timestamps = true`. Then, retrieve tokens from segments using `Whisper::Segment#each_token`.
+
+```ruby
+whisper = Whisper::Context.new("base.en")
+params = Whisper::Params.new(token_timestamps: true)
+whisper
+  .transcribe("path/to/audio.wav", params)
+  .each_segment do |segment|
+    segment.each_token do |token|
+      token => {start_time:, end_time:, text:, probability:}
+      st = "%05.2fs" % (start_time / 1000.0)
+      et = "%05.2fs" % (end_time / 1000.0)
+      prob = "%.1f%%" % (probability * 100)
+      puts "[#{st} --> #{et}] #{text} (#{prob})"
+    end
+  end
+```
+
+```
+[00.00s --> 00.00s] [_BEG_] (84.2%)
+[00.32s --> 00.37s]  And (71.2%)
+[00.37s --> 00.53s]  so (98.5%)
+[00.69s --> 00.85s]  my (70.7%)
+[00.85s --> 01.59s]  fellow (99.5%)
+[01.59s --> 02.10s]  Americans (90.1%)
+[02.85s --> 03.30s] , (28.4%)
+[03.30s --> 04.14s]  ask (79.8%)
+[04.14s --> 04.28s]  not (78.9%)
+[05.03s --> 05.35s]  what (93.3%)
+[05.41s --> 05.74s]  your (98.8%)
+[05.74s --> 06.41s]  country (99.6%)
+[06.41s --> 06.74s]  can (97.7%)
+[06.74s --> 06.92s]  do (99.0%)
+[07.00s --> 07.00s]  for (95.8%)
+[07.01s --> 07.52s]  you (98.5%)
+[07.81s --> 08.05s] , (49.3%)
+[08.19s --> 08.37s]  ask (65.6%)
+[08.37s --> 08.75s]  what (98.8%)
+[08.91s --> 09.04s]  you (98.2%)
+[09.04s --> 09.32s]  can (96.9%)
+[09.32s --> 09.38s]  do (90.3%)
+[09.44s --> 09.76s]  for (91.8%)
+[09.76s --> 09.99s]  your (98.2%)
+[10.02s --> 10.36s]  country (99.6%)
+[10.51s --> 10.99s] . (87.0%)
+[11.00s --> 11.00s] [_TT_550] (7.6%)
+```
+
 ### Models ###
 
 You can see model information:
@@ -342,6 +394,20 @@ whisper
   .full(Whisper::Params.new, samples)
 ```
 
+Custom context params
+---------------------
+
+You can use customize `Whisper::Context`'s behavior using `Whisper::Context::Params`.
+
+```ruby
+context_params = Whisper::Context::Params.new(
+  use_gpu: false,
+  flash_attn: false,
+  # etc
+)
+whisper = Whisper::Context.new("base", context_params)
+```
+
 Using VAD separately from ASR
 -----------------------------
 
index eb95829c03228815d26a7da8ac662eaa810fe964..ba71d4ba59483bc55f4db2190f052d1006dd566f 100644 (file)
@@ -33,7 +33,8 @@ static bool is_log_callback_finalized = false;
 // High level API
 extern VALUE ruby_whisper_segment_allocate(VALUE klass);
 
-extern void init_ruby_whisper_context(VALUE *mWhisper);
+extern VALUE init_ruby_whisper_context(VALUE *mWhisper);
+extern void init_ruby_whisper_context_params(VALUE *cContext);
 extern void init_ruby_whisper_params(VALUE *mWhisper);
 extern void init_ruby_whisper_error(VALUE *mWhisper);
 extern void init_ruby_whisper_segment(VALUE *mWhisper);
@@ -162,6 +163,22 @@ void Init_whisper() {
   rb_define_const(mWhisper, "LOG_LEVEL_DEBUG", INT2NUM(GGML_LOG_LEVEL_DEBUG));
   rb_define_const(mWhisper, "LOG_LEVEL_CONT", INT2NUM(GGML_LOG_LEVEL_CONT));
 
+  rb_define_const(mWhisper, "AHEADS_NONE", INT2NUM(WHISPER_AHEADS_NONE));
+  rb_define_const(mWhisper, "AHEADS_N_TOP_MOST", INT2NUM(WHISPER_AHEADS_N_TOP_MOST));
+  rb_define_const(mWhisper, "AHEADS_CUSTOM", INT2NUM(WHISPER_AHEADS_CUSTOM));
+  rb_define_const(mWhisper, "AHEADS_TINY_EN", INT2NUM(WHISPER_AHEADS_TINY_EN));
+  rb_define_const(mWhisper, "AHEADS_TINY", INT2NUM(WHISPER_AHEADS_TINY));
+  rb_define_const(mWhisper, "AHEADS_BASE_EN", INT2NUM(WHISPER_AHEADS_BASE_EN));
+  rb_define_const(mWhisper, "AHEADS_BASE", INT2NUM(WHISPER_AHEADS_BASE));
+  rb_define_const(mWhisper, "AHEADS_SMALL_EN", INT2NUM(WHISPER_AHEADS_SMALL_EN));
+  rb_define_const(mWhisper, "AHEADS_SMALL", INT2NUM(WHISPER_AHEADS_SMALL));
+  rb_define_const(mWhisper, "AHEADS_MEDIUM_EN", INT2NUM(WHISPER_AHEADS_MEDIUM_EN));
+  rb_define_const(mWhisper, "AHEADS_MEDIUM", INT2NUM(WHISPER_AHEADS_MEDIUM));
+  rb_define_const(mWhisper, "AHEADS_LARGE_V1", INT2NUM(WHISPER_AHEADS_LARGE_V1));
+  rb_define_const(mWhisper, "AHEADS_LARGE_V2", INT2NUM(WHISPER_AHEADS_LARGE_V2));
+  rb_define_const(mWhisper, "AHEADS_LARGE_V3", INT2NUM(WHISPER_AHEADS_LARGE_V3));
+  rb_define_const(mWhisper, "AHEADS_LARGE_V3_TURBO", INT2NUM(WHISPER_AHEADS_LARGE_V3_TURBO));
+
   rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0);
   rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1);
   rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
@@ -170,7 +187,8 @@ void Init_whisper() {
   rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2);
   rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1);
 
-  init_ruby_whisper_context(&mWhisper);
+  cContext = init_ruby_whisper_context(&mWhisper);
+  init_ruby_whisper_context_params(&cContext);
   init_ruby_whisper_params(&mWhisper);
   init_ruby_whisper_error(&mWhisper);
   init_ruby_whisper_segment(&mWhisper);
index c2c9866ae0de868dbc1165beffd3ef92b8fb3c5a..8dfd103c17af81984619fbe08bcda10b3b6031f6 100644 (file)
@@ -16,6 +16,10 @@ typedef struct {
   struct whisper_context *context;
 } ruby_whisper;
 
+typedef struct ruby_whisper_context_params {
+  struct whisper_context_params params;
+} ruby_whisper_context_params;
+
 typedef struct {
   struct whisper_full_params params;
   bool diarize;
@@ -37,7 +41,7 @@ typedef struct {
 
 typedef struct {
   whisper_token_data *token_data;
-  const char *text;
+  VALUE text;
 } ruby_whisper_token;
 
 typedef struct {
@@ -71,7 +75,11 @@ typedef struct parsed_samples_t {
   } \
 } while (0)
 
-#define GetToken(obj, rwt) do {                                             \
+#define GetContextParams(obj, rwcp) do { \
+  TypedData_Get_Struct((obj), ruby_whisper_context_params, &ruby_whisper_context_params_type, (rwcp)); \
+} while (0)
+
+#define GetToken(obj, rwt) do { \
   TypedData_Get_Struct((obj), ruby_whisper_token, &ruby_whisper_token_type, (rwt)); \
   if ((rwt)->token_data == NULL) { \
     rb_raise(rb_eRuntimeError, "Not initialized"); \
index 84790e3dedfa90d64d7511be57bdf15e5fbdaa6a..a8118d12773dff4d89b2d75be1ef83a6ee528a8e 100644 (file)
@@ -18,6 +18,7 @@ extern VALUE eError;
 extern VALUE cModel;
 
 extern const rb_data_type_t ruby_whisper_params_type;
+extern const rb_data_type_t ruby_whisper_context_params_type;
 extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self);
 extern VALUE rb_whisper_model_s_new(VALUE context);
 extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
@@ -143,16 +144,25 @@ ruby_whisper_initialize(int argc, VALUE *argv, VALUE self)
 {
   ruby_whisper *rw;
   VALUE whisper_model_file_path;
+  VALUE context_params;
+  struct whisper_context_params params;
 
   // TODO: we can support init from buffer here too maybe another ruby object to expose
-  rb_scan_args(argc, argv, "01", &whisper_model_file_path);
+  rb_scan_args(argc, argv, "11", &whisper_model_file_path, &context_params);
   TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
 
   whisper_model_file_path = ruby_whisper_normalize_model_path(whisper_model_file_path);
   if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
     rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
   }
-  rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
+  if (NIL_P(context_params)) {
+    params = whisper_context_default_params();
+  } else {
+    ruby_whisper_context_params *rwcp;
+    GetContextParams(context_params, rwcp);
+    params = rwcp->params;
+  }
+  rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), params);
   if (rw->context == NULL) {
     rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
   }
@@ -711,7 +721,7 @@ ruby_whisper_get_model(VALUE self)
   return rb_whisper_model_s_new(self);
 }
 
-void
+VALUE
 init_ruby_whisper_context(VALUE *mWhisper)
 {
   cContext = rb_define_class_under(*mWhisper, "Context", rb_cObject);
@@ -749,4 +759,6 @@ init_ruby_whisper_context(VALUE *mWhisper)
   rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);
 
   rb_define_method(cContext, "model", ruby_whisper_get_model, 0);
+
+  return cContext;
 }
diff --git a/bindings/ruby/ext/ruby_whisper_context_params.c b/bindings/ruby/ext/ruby_whisper_context_params.c
new file mode 100644 (file)
index 0000000..87df21d
--- /dev/null
@@ -0,0 +1,163 @@
+#include "ruby_whisper.h"
+
+#define NUM_PARAMS 6
+
+#define DEF_BOOLEAN_ATTR_METHOD(name) \
+static VALUE \
+ruby_whisper_context_params_get_ ## name(VALUE self) { \
+  ruby_whisper_context_params *rwcp; \
+  GetContextParams(self, rwcp); \
+  return rwcp->params.name ? Qtrue : Qfalse; \
+} \
+static VALUE \
+ruby_whisper_context_params_set_ ## name(VALUE self, VALUE value) { \
+  ruby_whisper_context_params *rwcp; \
+  GetContextParams(self, rwcp); \
+  rwcp->params.name = RTEST(value); \
+  return value; \
+}
+
+#define DEF_INT_ATTR_METHOD(name) \
+static VALUE \
+ruby_whisper_context_params_get_ ## name(VALUE self) { \
+  ruby_whisper_context_params *rwcp; \
+  GetContextParams(self, rwcp); \
+  return INT2NUM(rwcp->params.name); \
+} \
+static VALUE \
+ruby_whisper_context_params_set_ ## name(VALUE self, VALUE value) { \
+  ruby_whisper_context_params *rwcp; \
+  GetContextParams(self, rwcp); \
+  rwcp->params.name = NUM2INT(value); \
+  return value; \
+}
+
+#define DEFINE_PARAM(param_name, nth) \
+  id_ ## param_name = rb_intern(#param_name); \
+  param_names[nth] = id_ ## param_name; \
+  rb_define_method(cContextParams, #param_name, ruby_whisper_context_params_get_ ## param_name, 0); \
+  rb_define_method(cContextParams, #param_name "=", ruby_whisper_context_params_set_ ## param_name, 1);
+
+VALUE cContextParams;
+
+static ID param_names[NUM_PARAMS];
+static ID id_use_gpu;
+static ID id_flash_attn;
+static ID id_gpu_device;
+static ID id_dtw_token_timestamps;
+static ID id_dtw_aheads_preset;
+static ID id_dtw_n_top;
+
+static size_t
+ruby_whisper_context_params_memsize(const void *p)
+{
+  const ruby_whisper_context_params *rwcp = (ruby_whisper_context_params *)p;
+  if (!rwcp) {
+    return 0;
+  }
+  return sizeof(ruby_whisper_context_params);
+}
+
+const rb_data_type_t ruby_whisper_context_params_type = {
+  "ruby_whisper_context_params",
+  {0, RUBY_DEFAULT_FREE, ruby_whisper_context_params_memsize,},
+  0, 0,
+  0
+};
+
+static VALUE
+ruby_whisper_context_params_s_allocate(VALUE klass)
+{
+  ruby_whisper_context_params *rwcp;
+  return TypedData_Make_Struct(klass, ruby_whisper_context_params, &ruby_whisper_context_params_type, rwcp);
+}
+
+DEF_BOOLEAN_ATTR_METHOD(use_gpu);
+DEF_BOOLEAN_ATTR_METHOD(flash_attn);
+DEF_INT_ATTR_METHOD(gpu_device);
+DEF_BOOLEAN_ATTR_METHOD(dtw_token_timestamps);
+DEF_INT_ATTR_METHOD(dtw_aheads_preset);
+
+static VALUE
+ruby_whisper_context_params_get_dtw_n_top(VALUE self) {
+  ruby_whisper_context_params *rwcp;
+  GetContextParams(self, rwcp);
+
+  int dtw_n_top = rwcp->params.dtw_n_top;
+
+  return dtw_n_top == -1 ? Qnil : INT2NUM(dtw_n_top);
+}
+
+static VALUE
+ruby_whisper_context_params_set_dtw_n_top(VALUE self, VALUE value) {
+  ruby_whisper_context_params *rwcp;
+  GetContextParams(self, rwcp);
+
+  rwcp->params.dtw_n_top = NIL_P(value) ? -1 : NUM2INT(value);
+
+  return value;
+}
+
+#define SET_PARAM_IF_SAME(param_name) \
+  if (id == id_ ## param_name) { \
+    ruby_whisper_context_params_set_ ## param_name(self, value); \
+    continue; \
+  }
+
+static VALUE
+ruby_whisper_context_params_initialize(int argc, VALUE *argv, VALUE self)
+{
+  ruby_whisper_context_params *rwcp;
+  TypedData_Get_Struct(self, ruby_whisper_context_params, &ruby_whisper_context_params_type, rwcp);
+  rwcp->params = whisper_context_default_params();
+
+  VALUE kw_hash;
+  rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
+  if (NIL_P(kw_hash)) {
+    return Qnil;
+  }
+
+  VALUE values[NUM_PARAMS] = {Qundef};
+  rb_get_kwargs(kw_hash, param_names, 0, NUM_PARAMS, values);
+
+  ID id;
+  VALUE value;
+  for (int i = 0; i < NUM_PARAMS; i++) {
+    id = param_names[i];
+    value = values[i];
+    if (value == Qundef) {
+      continue;
+    }
+    SET_PARAM_IF_SAME(use_gpu)
+    SET_PARAM_IF_SAME(flash_attn)
+    SET_PARAM_IF_SAME(gpu_device)
+    SET_PARAM_IF_SAME(dtw_token_timestamps)
+    SET_PARAM_IF_SAME(dtw_aheads_preset)
+    SET_PARAM_IF_SAME(dtw_n_top)
+  }
+
+  return Qnil;
+}
+
+#undef SET_PARAM_IF_SAME
+
+void
+init_ruby_whisper_context_params(VALUE *cContext)
+{
+  cContextParams = rb_define_class_under(*cContext, "Params", rb_cObject);
+
+  rb_define_alloc_func(cContextParams, ruby_whisper_context_params_s_allocate);
+  rb_define_method(cContextParams, "initialize", ruby_whisper_context_params_initialize, -1);
+
+  DEFINE_PARAM(use_gpu, 0)
+  DEFINE_PARAM(flash_attn, 1)
+  DEFINE_PARAM(gpu_device, 2)
+  DEFINE_PARAM(dtw_token_timestamps, 3)
+  DEFINE_PARAM(dtw_aheads_preset, 4)
+  DEFINE_PARAM(dtw_n_top, 5)
+}
+
+#undef DEFINE_PARAM
+#undef DEF_INT_ATTR_METHOD
+#undef DEF_BOOLEAN_ATTR_METHOD
+#undef NUM_PARAMS
index 56a7eab2231c157345d1a6ca3b8469e0aed44105..73f5a547daf1000dbac53c19aa67986ca57cab7b 100644 (file)
@@ -24,12 +24,34 @@ ruby_whisper_token_memsize(const void *p)
   if (!rwt) {
     return 0;
   }
-  return sizeof(rwt);
+  size_t size = sizeof(*rwt);
+  if (rwt->token_data) {
+    size += sizeof(*rwt->token_data);
+  }
+  return size;
+}
+
+static void
+ruby_whisper_token_mark(void *p)
+{
+  ruby_whisper_token *rwt = (ruby_whisper_token *)p;
+  rb_gc_mark(rwt->text);
+}
+
+static void
+ruby_whisper_token_free(void *p)
+{
+  ruby_whisper_token *rwt = (ruby_whisper_token *)p;
+  if (rwt->token_data) {
+    xfree(rwt->token_data);
+    rwt->token_data = NULL;
+  }
+  xfree(rwt);
 }
 
 static const rb_data_type_t ruby_whisper_token_type = {
   "ruby_whisper_token",
-  {0, RUBY_DEFAULT_FREE, ruby_whisper_token_memsize,},
+  {ruby_whisper_token_mark, ruby_whisper_token_free, ruby_whisper_token_memsize,},
   0, 0,
   0
 };
@@ -40,19 +62,19 @@ ruby_whisper_token_allocate(VALUE klass)
   ruby_whisper_token *rwt;
   VALUE token = TypedData_Make_Struct(klass, ruby_whisper_token, &ruby_whisper_token_type, rwt);
   rwt->token_data = NULL;
-  rwt->text = NULL;
+  rwt->text = Qnil;
   return token;
 }
 
 VALUE
 ruby_whisper_token_s_init(struct whisper_context *context, int i_segment, int i_token)
 {
-  whisper_token_data token_data = whisper_full_get_token_data(context, i_segment, i_token);
   const VALUE token = ruby_whisper_token_allocate(cToken);
   ruby_whisper_token *rwt;
   TypedData_Get_Struct(token, ruby_whisper_token, &ruby_whisper_token_type, rwt);
-  rwt->token_data = &token_data;
-  rwt->text = whisper_full_get_token_text(context, i_segment, i_token);
+  rwt->token_data = ALLOC(whisper_token_data);
+  *(rwt->token_data) = whisper_full_get_token_data(context, i_segment, i_token);
+  rwt->text = rb_str_new2(whisper_full_get_token_text(context, i_segment, i_token));
   return token;
 }
 
@@ -182,10 +204,9 @@ ruby_whisper_token_get_text(VALUE self)
 {
   ruby_whisper_token *rwt;
   GetToken(self, rwt);
-  return rb_str_new2(rwt->text);
+  return rwt->text;
 }
 
-
 /*
  * Start time of the token.
  *
index 0e7b2c276e8f24599b167643f1e745460298b203..9ade451c6b26f15e1cf8eea8d1c6f1b61d116576 100644 (file)
@@ -17,6 +17,21 @@ module Whisper
   LOG_LEVEL_ERROR: Integer
   LOG_LEVEL_DEBUG: Integer
   LOG_LEVEL_CONT: Integer
+  AHEADS_NONE: Integer
+  AHEADS_N_TOP_MOST: Integer
+  AHEADS_CUSTOM: Integer
+  AHEADS_TINY_EN: Integer
+  AHEADS_TINY: Integer
+  AHEADS_BASE_EN: Integer
+  AHEADS_BASE: Integer
+  AHEADS_SMALL_EN: Integer
+  AHEADS_SMALL: Integer
+  AHEADS_MEDIUM_EN: Integer
+  AHEADS_MEDIUM: Integer
+  AHEADS_LARGE_V1: Integer
+  AHEADS_LARGE_V2: Integer
+  AHEADS_LARGE_V3: Integer
+  AHEADS_LARGE_V3_TURBO: Integer
 
   def self.lang_max_id: () -> Integer
   def self.lang_id: (string name) -> Integer
@@ -120,6 +135,30 @@ module Whisper
 
     def to_srt: () -> String
     def to_webvtt: () -> String
+
+    class Params
+      def self.new: (
+        use_gpu: boolish,
+        flash_attn: boolish,
+        gpu_device: Integer,
+        dtw_token_timestamps: boolish,
+        dtw_aheads_preset: Integer,
+        dtw_n_top: Integer | nil,
+      ) -> instance
+
+      def use_gpu=: (boolish) -> boolish
+      def use_gpu: () -> (true | false)
+      def flash_attn=: (boolish) -> boolish
+      def flash_attn: () -> (true | false)
+      def gpu_device=: (Integer) -> Integer
+      def gpu_device: () -> Integer
+      def dtw_token_timestamps=: (boolish) -> boolish
+      def dtw_token_timestamps: () -> (true | false)
+      def dtw_aheads_preset=: (Integer) -> Integer
+      def dtw_aheads_preset: () -> Integer
+      def dtw_n_top=: (Integer | nil) -> (Integer | nil)
+      def dtw_n_top: () -> (Integer | nil)
+    end
   end
 
   class Params
diff --git a/bindings/ruby/test/test_context_params.rb b/bindings/ruby/test/test_context_params.rb
new file mode 100644 (file)
index 0000000..8d19fdc
--- /dev/null
@@ -0,0 +1,82 @@
+require_relative "helper"
+
+class TestContextParams < TestBase
+  PARAM_NAMES = [
+    :use_gpu,
+    :flash_attn,
+    :gpu_device,
+    :dtw_token_timestamps,
+    :dtw_aheads_preset,
+    :dtw_n_top
+  ]
+
+  def test_new
+    params = Whisper::Context::Params.new
+    assert_instance_of Whisper::Context::Params, params
+  end
+
+  def test_attributes
+    params = Whisper::Context::Params.new
+
+    assert_true params.use_gpu
+    params.use_gpu = false
+    assert_false params.use_gpu
+
+    assert_true params.flash_attn
+    params.flash_attn = false
+    assert_false params.flash_attn
+
+    assert_equal 0, params.gpu_device
+    params.gpu_device = 1
+    assert_equal 1, params.gpu_device
+
+    assert_false params.dtw_token_timestamps
+    params.dtw_token_timestamps = true
+    assert_true params.dtw_token_timestamps
+
+    assert_equal Whisper::AHEADS_NONE, params.dtw_aheads_preset
+    params.dtw_aheads_preset =Whisper::AHEADS_BASE
+    assert_equal Whisper::AHEADS_BASE, params.dtw_aheads_preset
+
+    assert_nil params.dtw_n_top
+    params.dtw_n_top = 6
+    assert_equal 6, params.dtw_n_top
+    params.dtw_n_top = nil
+    assert_nil params.dtw_n_top
+  end
+
+  def test_new_with_kw_args
+    params = Whisper::Context::Params.new(use_gpu: false)
+    assert_false params.use_gpu
+  end
+
+  def test_new_with_kw_wargs_non_existent
+    assert_raise ArgumentError do
+      Whisper::Context::Params.new(non_existent: "value")
+    end
+  end
+
+  data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
+  def test_new_with_kw_args_default_values(param)
+    default_params = Whisper::Context::Params.new
+    default_value = default_params.send(param)
+    value = if param == :dtw_n_top
+              6
+            else
+              case default_value
+              in true | false
+                !default_value
+              in Integer
+                default_value + 1
+              end
+            end
+    params = Whisper::Context::Params.new(param => value)
+    assert_equal value, params.send(param)
+
+    PARAM_NAMES.reject {|name| name == param}.each do |name|
+      expected = default_params.send(name)
+      actual = params.send(name)
+      assert_equal expected, actual
+    end
+  end
+end
index e5834b1b4804fd4ef4df7b5390504a6c58e16652..a23f6813675db074913d39eb5df1169d1467509e 100644 (file)
@@ -56,6 +56,17 @@ class TestToken < TestBase
                  @segment.each_token.collect(&:text)
   end
 
+  def test_token_timestamps
+    params = Whisper::Params.new(token_timestamps: true)
+    whisper.transcribe(TestBase::AUDIO, params)
+    prev = -1
+    whisper.each_segment.first.each_token do |token|
+      assert token.start_time >= prev
+      assert token.end_time >= token.start_time
+      prev = token.end_time
+    end
+  end
+
   def test_deconstruct_keys_with_nil
     keys = %i[id tid probability log_probability pt ptsum t_dtw voice_length start_time end_time text]
     expected = keys.collect {|key| [key, @token.send(key)] }.to_h