* Updated bindings so they can be used in third party packages.
* Updated makefiles to optionally set the FMA flag (for Xeon E5 on Darwin)
* Added test script
* Changes for examples
* Reverted
* Made the NewContext method private
--- /dev/null
+name: Bindings Tests
+on:
+ push:
+ paths:
+ - bindings/go/**
+
+jobs:
+ ubuntu-latest:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/setup-go@v3
+ with:
+ go-version: '^1.19'
+ - uses: actions/checkout@v1
+ - run: |
+ cd bindings/go
+ make test
--- /dev/null
+package main
+
+import "fmt"
+
+///////////////////////////////////////////////////////////////////////////////
+// CONSTANTS
+
+const (
+ Reset = "\033[0m" // escape sequence appended after colorized text
+ RGBPrefix = "\033[38;5;" // followed by a single 8-bit palette index in decimal (not an RGB triple)
+ RGBSuffix = "m" // terminates the color escape sequence
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS
+
+// Colorize wraps text in an ANSI 8-bit grayscale escape sequence followed by
+// Reset. v selects the grey level from 0 (darkest) to 23 (lightest). Values
+// outside that range are clamped, so that a probability of exactly 1.0 scaled
+// to 24 maps to the lightest grey instead of wrapping round to the darkest,
+// and a negative value can never select an index below the grayscale ramp.
+func Colorize(text string, v int) string {
+	// https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit
+	// Grayscale colors are in the range 232-255
+	if v < 0 {
+		v = 0
+	} else if v > 23 {
+		v = 23
+	}
+	return RGBPrefix + fmt.Sprint(v+232) + RGBSuffix + text + Reset
+}
import (
"flag"
+ "fmt"
+ "strings"
+ "time"
+
+ // Packages
+ whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
)
///////////////////////////////////////////////////////////////////////////////
return flags.Lookup("language").Value.String()
}
+// IsTranslate returns the boolean value of the "translate" flag.
+func (flags *Flags) IsTranslate() bool {
+	getter := flags.Lookup("translate").Value.(flag.Getter)
+	return getter.Get().(bool)
+}
+
+// GetOffset returns the value of the "offset" flag as a time.Duration.
+func (flags *Flags) GetOffset() time.Duration {
+	getter := flags.Lookup("offset").Value.(flag.Getter)
+	return getter.Get().(time.Duration)
+}
+
+// GetDuration returns the value of the "duration" flag as a time.Duration.
+func (flags *Flags) GetDuration() time.Duration {
+	getter := flags.Lookup("duration").Value.(flag.Getter)
+	return getter.Get().(time.Duration)
+}
+
+// GetThreads returns the value of the "threads" flag.
+func (flags *Flags) GetThreads() uint {
+	getter := flags.Lookup("threads").Value.(flag.Getter)
+	return getter.Get().(uint)
+}
+
+// GetOut returns the value of the "out" flag converted to lower case.
+func (flags *Flags) GetOut() string {
+	format := flags.Lookup("out").Value.String()
+	return strings.ToLower(format)
+}
+
func (flags *Flags) IsSpeedup() bool {
	// The flag counts as enabled only when its textual value is exactly "true".
	value := flags.Lookup("speedup").Value.String()
	return value == "true"
}
return flags.Lookup("tokens").Value.String() == "true"
}
+// IsColorize reports whether the textual value of the "colorize" flag is "true".
+func (flags *Flags) IsColorize() bool {
+	value := flags.Lookup("colorize").Value.String()
+	return value == "true"
+}
+
+// GetMaxLen returns the value of the "max-len" flag.
+func (flags *Flags) GetMaxLen() uint {
+	getter := flags.Lookup("max-len").Value.(flag.Getter)
+	return getter.Get().(uint)
+}
+
+// GetMaxTokens returns the value of the "max-tokens" flag.
+func (flags *Flags) GetMaxTokens() uint {
+	getter := flags.Lookup("max-tokens").Value.(flag.Getter)
+	return getter.Get().(uint)
+}
+
+// GetWordThreshold returns the float64 value of the "word-thold" flag
+// narrowed to float32.
+func (flags *Flags) GetWordThreshold() float32 {
+	getter := flags.Lookup("word-thold").Value.(flag.Getter)
+	threshold := getter.Get().(float64)
+	return float32(threshold)
+}
+
+// SetParams copies the parsed command-line options onto context, writing a
+// line to flags.Output() for every setting it applies. Options left at their
+// zero value are not applied, a language of "" or "auto" is not applied, and
+// translate is applied only when the context reports a multilingual model.
+// The only error returned is the one produced by context.SetLanguage.
+func (flags *Flags) SetParams(context whisper.Context) error {
+	// Language
+	if language := flags.GetLanguage(); language != "" && language != "auto" {
+		fmt.Fprintf(flags.Output(), "Setting language to %q\n", language)
+		if err := context.SetLanguage(language); err != nil {
+			return err
+		}
+	}
+
+	// Translation
+	if flags.IsTranslate() && context.IsMultilingual() {
+		fmt.Fprintf(flags.Output(), "Setting translate to true\n")
+		context.SetTranslate(true)
+	}
+
+	// Portion of the audio to process
+	if start := flags.GetOffset(); start != 0 {
+		fmt.Fprintf(flags.Output(), "Setting offset to %v\n", start)
+		context.SetOffset(start)
+	}
+	if length := flags.GetDuration(); length != 0 {
+		fmt.Fprintf(flags.Output(), "Setting duration to %v\n", length)
+		context.SetDuration(length)
+	}
+
+	// Processing options
+	if flags.IsSpeedup() {
+		fmt.Fprintf(flags.Output(), "Setting speedup to true\n")
+		context.SetSpeedup(true)
+	}
+	if numThreads := flags.GetThreads(); numThreads != 0 {
+		fmt.Fprintf(flags.Output(), "Setting threads to %d\n", numThreads)
+		context.SetThreads(numThreads)
+	}
+
+	// Segmentation limits
+	if maxLen := flags.GetMaxLen(); maxLen != 0 {
+		fmt.Fprintf(flags.Output(), "Setting max_segment_length to %d\n", maxLen)
+		context.SetMaxSegmentLength(maxLen)
+	}
+	if maxTokens := flags.GetMaxTokens(); maxTokens != 0 {
+		fmt.Fprintf(flags.Output(), "Setting max_tokens to %d\n", maxTokens)
+		context.SetMaxTokensPerSegment(maxTokens)
+	}
+	if wordThreshold := flags.GetWordThreshold(); wordThreshold != 0 {
+		fmt.Fprintf(flags.Output(), "Setting word_threshold to %f\n", wordThreshold)
+		context.SetTokenThreshold(wordThreshold)
+	}
+
+	return nil
+}
+
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS
func registerFlags(flag *Flags) {
flag.String("model", "", "Path to the model file")
- flag.String("language", "", "Language")
+ flag.String("language", "", "Spoken language")
+ flag.Bool("translate", false, "Translate from source language to english")
+ flag.Duration("offset", 0, "Time offset")
+ flag.Duration("duration", 0, "Duration of audio to process")
+ flag.Uint("threads", 0, "Number of threads to use")
flag.Bool("speedup", false, "Enable speedup")
+ flag.Uint("max-len", 0, "Maximum segment length in characters")
+ flag.Uint("max-tokens", 0, "Maximum tokens per segment")
+ flag.Float64("word-thold", 0, "Maximum segment score")
flag.Bool("tokens", false, "Display tokens")
+ flag.Bool("colorize", false, "Colorize tokens")
+ flag.String("out", "", "Output format (srt, none or leave as empty string)")
}
// Process files
for _, filename := range flags.Args() {
- fmt.Println("Processing", filename)
- if err := Process(model, filename, flags.GetLanguage(), flags.IsSpeedup(), flags.IsTokens()); err != nil {
+ if err := Process(model, filename, flags); err != nil {
fmt.Fprintln(os.Stderr, err)
continue
}
wav "github.com/go-audio/wav"
)
-func Process(model whisper.Model, path string, lang string, speedup, tokens bool) error {
+func Process(model whisper.Model, path string, flags *Flags) error {
var data []float32
// Create processing context
return err
}
+ // Set the parameters
+ if err := flags.SetParams(context); err != nil {
+ return err
+ }
+
// Open the file
+ fmt.Fprintf(flags.Output(), "Loading %q\n", path)
fh, err := os.Open(path)
if err != nil {
return err
}
defer fh.Close()
- // Decode the WAV file
+ // Decode the WAV file - load the full buffer
dec := wav.NewDecoder(fh)
if buf, err := dec.FullPCMBuffer(); err != nil {
return err
data = buf.AsFloat32Buffer().Data
}
- // Set the parameters
+ // Segment callback when -tokens is specified
var cb whisper.SegmentCallback
- if lang != "" {
- if err := context.SetLanguage(lang); err != nil {
- return err
- }
- }
- if speedup {
- context.SetSpeedup(true)
- }
- if tokens {
+ if flags.IsTokens() {
cb = func(segment whisper.Segment) {
- fmt.Printf("%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
+ fmt.Fprintf(flags.Output(), "%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
for _, token := range segment.Tokens {
- fmt.Printf("%q ", token.Text)
+ if flags.IsColorize() && context.IsText(token) {
+ fmt.Fprint(flags.Output(), Colorize(token.Text, int(token.P*24.0)), " ")
+ } else {
+ fmt.Fprint(flags.Output(), token.Text, " ")
+ }
}
- fmt.Println("")
+ fmt.Fprintln(flags.Output(), "")
+ fmt.Fprintln(flags.Output(), "")
}
}
// Process the data
+ fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
if err := context.Process(data, cb); err != nil {
return err
}
// Print out the results
+ switch {
+ case flags.GetOut() == "srt":
+ return OutputSRT(os.Stdout, context)
+ case flags.GetOut() == "none":
+ return nil
+ default:
+ return Output(os.Stdout, context, flags.IsColorize())
+ }
+}
+
+// OutputSRT writes every remaining segment of context to w in SubRip (SRT)
+// form: a cue number starting at 1, a "start --> end" timing line, the
+// segment text and an empty separator line. It returns nil once NextSegment
+// reports io.EOF, or the first other error NextSegment returns.
+func OutputSRT(w io.Writer, context whisper.Context) error {
+ n := 1
for {
segment, err := context.NextSegment()
if err == io.EOF {
- break
+ return nil
} else if err != nil {
return err
}
- fmt.Printf("[%6s->%6s] %s\n", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond), segment.Text)
+ fmt.Fprintln(w, n)
+ // Fprintln inserts a space between operands itself, so the arrow is passed
+ // without padding to produce the single-spaced " --> " separator.
+ fmt.Fprintln(w, srtTimestamp(segment.Start), "-->", srtTimestamp(segment.End))
+ fmt.Fprintln(w, segment.Text)
+ fmt.Fprintln(w, "")
+ n++
}
+}
+
+// Output writes every remaining segment of context to w, one line per
+// segment, prefixed with its start and end times truncated to milliseconds.
+// With colorize set, only tokens for which context.IsText is true are
+// printed, each passed through Colorize with its probability scaled by 24;
+// otherwise the segment text is printed as is. It returns nil once
+// NextSegment reports io.EOF, or the first other error NextSegment returns.
+func Output(w io.Writer, context whisper.Context, colorize bool) error {
+	for {
+		segment, err := context.NextSegment()
+		switch {
+		case err == io.EOF:
+			return nil
+		case err != nil:
+			return err
+		}
+
+		start := segment.Start.Truncate(time.Millisecond)
+		end := segment.End.Truncate(time.Millisecond)
+		fmt.Fprintf(w, "[%6s->%6s]", start, end)
+
+		if !colorize {
+			fmt.Fprintln(w, " ", segment.Text)
+			continue
+		}
+		for _, token := range segment.Tokens {
+			if context.IsText(token) {
+				fmt.Fprint(w, " ", Colorize(token.Text, int(token.P*24.0)))
+			}
+		}
+		fmt.Fprint(w, "\n")
+	}
+}
- // Return success
- return nil
+// srtTimestamp formats t as hours:minutes:seconds,milliseconds, with hours,
+// minutes and seconds zero-padded to two digits and milliseconds to three.
+func srtTimestamp(t time.Duration) string {
+	hours := t / time.Hour
+	minutes := (t % time.Hour) / time.Minute
+	seconds := (t % time.Minute) / time.Second
+	millis := (t % time.Second) / time.Millisecond
+	return fmt.Sprintf("%02d:%02d:%02d,%03d", hours, minutes, seconds, millis)
}
p.speed_up = toBool(v)
}
+// Set language id
func (p *Params) SetLanguage(lang int) error {
str := C.whisper_lang_str(C.int(lang))
if str == nil {
return nil
}
+// Get language id
func (p *Params) Language() int {
if p.language == nil {
return -1
return int(C.whisper_lang_id(p.language))
}
+// SetThreads stores the number of threads to use in the n_threads parameter.
func (p *Params) SetThreads(threads int) {
p.n_threads = C.int(threads)
}
+// SetOffset stores the start offset, given in milliseconds, in the offset_ms parameter.
func (p *Params) SetOffset(offset_ms int) {
p.offset_ms = C.int(offset_ms)
}
+// SetDuration stores the audio duration to process, given in milliseconds, in the duration_ms parameter.
func (p *Params) SetDuration(duration_ms int) {
p.duration_ms = C.int(duration_ms)
}
+// SetTokenThreshold stores the timestamp token probability threshold (~0.01) in the thold_pt parameter.
+func (p *Params) SetTokenThreshold(t float32) {
+ p.thold_pt = C.float(t)
+}
+
+// SetTokenSumThreshold stores the timestamp token sum probability threshold (~0.01) in the thold_ptsum parameter.
+func (p *Params) SetTokenSumThreshold(t float32) {
+ p.thold_ptsum = C.float(t)
+}
+
+// SetMaxSegmentLength stores the maximum segment length, in characters, in the max_len parameter.
+func (p *Params) SetMaxSegmentLength(n int) {
+ p.max_len = C.int(n)
+}
+
+// SetMaxTokensPerSegment stores the maximum number of tokens per segment (0 = no limit) in the max_tokens parameter.
+func (p *Params) SetMaxTokensPerSegment(n int) {
+ p.max_tokens = C.int(n)
+}
+
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS
// ERRORS
var (
- ErrUnableToLoadModel = errors.New("unable to load model")
- ErrInternalAppError = errors.New("internal application error")
- ErrProcessingFailed = errors.New("processing failed")
- ErrUnsupportedLanguage = errors.New("unsupported language")
+ ErrUnableToLoadModel = errors.New("unable to load model")
+ ErrInternalAppError = errors.New("internal application error")
+ ErrProcessingFailed = errors.New("processing failed")
+ ErrUnsupportedLanguage = errors.New("unsupported language")
+ ErrModelNotMultilingual = errors.New("model is not multilingual")
)
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE
-func NewContext(model *model, params whisper.Params) (Context, error) {
+func newContext(model *model, params whisper.Params) (Context, error) {
context := new(context)
context.model = model
context.params = params
if context.model.ctx == nil {
return ErrInternalAppError
}
+ if !context.model.IsMultilingual() {
+ return ErrModelNotMultilingual
+ }
if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
return ErrUnsupportedLanguage
} else if err := context.params.SetLanguage(id); err != nil {
return nil
}
+// IsMultilingual returns whatever the context's model reports from its own
+// IsMultilingual method.
+func (context *context) IsMultilingual() bool {
+	multilingual := context.model.IsMultilingual()
+	return multilingual
+}
+
// Language returns the string that whisper.Whisper_lang_str gives for the
// language id currently held in the context's parameters.
func (context *context) Language() string {
return whisper.Whisper_lang_str(context.params.Language())
}
+// SetTranslate forwards the translate flag v to the context's parameters.
+func (context *context) SetTranslate(v bool) {
+ context.params.SetTranslate(v)
+}
+
// SetSpeedup forwards the speedup flag v to the context's parameters.
func (context *context) SetSpeedup(v bool) {
context.params.SetSpeedup(v)
}
+// SetThreads sets the number of threads to use; v is converted to int before being passed to the parameters.
+func (context *context) SetThreads(v uint) {
+ context.params.SetThreads(int(v))
+}
+
+// SetOffset sets the time offset; v is converted to whole milliseconds before being passed to the parameters.
+func (context *context) SetOffset(v time.Duration) {
+ context.params.SetOffset(int(v.Milliseconds()))
+}
+
+// SetDuration sets the duration of audio to process. v is converted to whole
+// milliseconds and stored as the duration parameter.
+func (context *context) SetDuration(v time.Duration) {
+	// This must target the duration parameter: calling SetOffset here would
+	// overwrite the start offset and leave the duration unset.
+	context.params.SetDuration(int(v.Milliseconds()))
+}
+
+// SetTokenThreshold forwards the timestamp token probability threshold (~0.01) to the context's parameters.
+func (context *context) SetTokenThreshold(t float32) {
+ context.params.SetTokenThreshold(t)
+}
+
+// SetTokenSumThreshold forwards the timestamp token sum probability threshold (~0.01) to the context's parameters.
+func (context *context) SetTokenSumThreshold(t float32) {
+ context.params.SetTokenSumThreshold(t)
+}
+
+// SetMaxSegmentLength sets the maximum segment length in characters; n is converted to int before being passed to the parameters.
+func (context *context) SetMaxSegmentLength(n uint) {
+ context.params.SetMaxSegmentLength(int(n))
+}
+
+// SetMaxTokensPerSegment sets the maximum tokens per segment (0 = no limit); n is converted to int before being passed to the parameters.
+func (context *context) SetMaxTokensPerSegment(n uint) {
+ context.params.SetMaxTokensPerSegment(int(n))
+}
+
// Process new sample data and return any errors
func (context *context) Process(data []float32, cb SegmentCallback) error {
if context.model.ctx == nil {
return result, nil
}
+// IsText reports whether t is a text token. It returns false when t is the
+// BEG, SOT, PREV, SOLM or NOT token, or when its id is greater than or equal
+// to the model's EOT token id; any other token counts as text.
+func (context *context) IsText(t Token) bool {
+	// The checks run in the same order as before and stop at the first match.
+	special := context.IsBEG(t) ||
+		context.IsSOT(t) ||
+		whisper.Token(t.Id) >= context.model.ctx.Whisper_token_eot() ||
+		context.IsPREV(t) ||
+		context.IsSOLM(t) ||
+		context.IsNOT(t)
+	return !special
+}
+
+// IsBEG reports whether t's id equals the model's "begin" token id.
+func (context *context) IsBEG(t Token) bool {
+ return whisper.Token(t.Id) == context.model.ctx.Whisper_token_beg()
+}
+
+// IsSOT reports whether t's id equals the model's "start of transcription" token id.
+func (context *context) IsSOT(t Token) bool {
+ return whisper.Token(t.Id) == context.model.ctx.Whisper_token_sot()
+}
+
+// IsEOT reports whether t's id equals the model's "end of transcription" token id.
+func (context *context) IsEOT(t Token) bool {
+ return whisper.Token(t.Id) == context.model.ctx.Whisper_token_eot()
+}
+
+// IsPREV reports whether t's id equals the model's "start of prev" token id.
+func (context *context) IsPREV(t Token) bool {
+ return whisper.Token(t.Id) == context.model.ctx.Whisper_token_prev()
+}
+
+// IsSOLM reports whether t's id equals the model's "start of lm" token id.
+func (context *context) IsSOLM(t Token) bool {
+ return whisper.Token(t.Id) == context.model.ctx.Whisper_token_solm()
+}
+
+// IsNOT reports whether t's id equals the model's "no timestamps" token id.
+func (context *context) IsNOT(t Token) bool {
+ return whisper.Token(t.Id) == context.model.ctx.Whisper_token_not()
+}
+
+// IsLANG reports whether t's id equals the token id the model associates with
+// the language lang. It returns false when Whisper_lang_id yields a negative
+// id for lang.
+func (context *context) IsLANG(t Token, lang string) bool {
+	id := context.model.ctx.Whisper_lang_id(lang)
+	if id < 0 {
+		return false
+	}
+	return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id)
+}
+
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS
// Return a new speech-to-text context.
NewContext() (Context, error)
+ // Return true if the model is multilingual.
+ IsMultilingual() bool
+
// Return all languages supported.
Languages() []string
}
// Context is the speech recognition context.
type Context interface {
SetLanguage(string) error // Set the language to use for speech recognition.
+ SetTranslate(bool) // Set translate flag
+ IsMultilingual() bool // Return true if the model is multilingual.
Language() string // Get language
- SetSpeedup(bool) // Set speedup flag
+
+ SetOffset(time.Duration) // Set offset
+ SetDuration(time.Duration) // Set duration
+ SetThreads(uint) // Set number of threads to use
+ SetSpeedup(bool) // Set speedup flag
+ SetTokenThreshold(float32) // Set timestamp token probability threshold
+ SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
+ SetMaxSegmentLength(uint) // Set max segment length in characters
+ SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
// Process mono audio data and return any errors.
// If defined, newly generated segments are passed to the
// After process is called, return segments until the end of the stream
// is reached, when io.EOF is returned.
NextSegment() (Segment, error)
+
+ IsBEG(Token) bool // Test for "begin" token
+ IsSOT(Token) bool // Test for "start of transcription" token
+ IsEOT(Token) bool // Test for "end of transcription" token
+ IsPREV(Token) bool // Test for "start of prev" token
+ IsSOLM(Token) bool // Test for "start of lm" token
+ IsNOT(Token) bool // Test for "No timestamps" token
+ IsLANG(Token, string) bool // Test for token associated with a specific language
+ IsText(Token) bool // Test for text token
}
// Segment is the text result of a speech recognition.
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE
-func New(path string) (*model, error) {
+func New(path string) (Model, error) {
model := new(model)
if _, err := os.Stat(path); err != nil {
return nil, err
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
+// IsMultilingual reports whether the model is multilingual (language and translation options are supported), i.e. Whisper_is_multilingual returns non-zero.
+func (model *model) IsMultilingual() bool {
+ return model.ctx.Whisper_is_multilingual() != 0
+}
+
// Return all recognized languages. Initially it is set to auto-detect
func (model *model) Languages() []string {
result := make([]string, 0, whisper.Whisper_lang_max_id())
params.SetThreads(runtime.NumCPU())
// Return new context
- return NewContext(model, params)
+ return newContext(model, params)
}