]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
go : adding features to the go-whisper example, go ci, etc (#384)
authorDavid Thorpe <redacted>
Sat, 7 Jan 2023 19:21:43 +0000 (19:21 +0000)
committerGitHub <redacted>
Sat, 7 Jan 2023 19:21:43 +0000 (21:21 +0200)
* Updated bindings so they can be used in third pary packages.

* Updated makefiles to set FMA flag on optionally, for xeon E5 on Darwin

* Added test script

* Changes for examples

* Reverted

* Made the NewContext method private

.github/workflows/bindings.yml [new file with mode: 0644]
bindings/go/examples/go-whisper/color.go [new file with mode: 0644]
bindings/go/examples/go-whisper/flags.go
bindings/go/examples/go-whisper/main.go
bindings/go/examples/go-whisper/process.go
bindings/go/params.go
bindings/go/pkg/whisper/consts.go
bindings/go/pkg/whisper/context.go
bindings/go/pkg/whisper/interface.go
bindings/go/pkg/whisper/model.go

diff --git a/.github/workflows/bindings.yml b/.github/workflows/bindings.yml
new file mode 100644 (file)
index 0000000..1bccf59
--- /dev/null
@@ -0,0 +1,17 @@
+name: Bindings Tests
+on:
+  push:
+    paths:
+      - bindings/go/**
+
+jobs:
+    ubuntu-latest:
+      runs-on: ubuntu-latest
+      steps:
+      - uses: actions/setup-go@v3
+        with:
+          go-version: '^1.19'
+      - uses: actions/checkout@v1
+      - run: |
+          cd bindings/go
+          make test
diff --git a/bindings/go/examples/go-whisper/color.go b/bindings/go/examples/go-whisper/color.go
new file mode 100644 (file)
index 0000000..fa5ac2f
--- /dev/null
@@ -0,0 +1,22 @@
+package main
+
+import "fmt"
+
+///////////////////////////////////////////////////////////////////////////////
+// CONSTANTS
+
+const (
+       Reset     = "\033[0m"
+       RGBPrefix = "\033[38;5;" // followed by RGB values in decimal format separated by colons
+       RGBSuffix = "m"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS
+
+// Colorize text with RGB values, from 0 to 23
+func Colorize(text string, v int) string {
+       // https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit
+       // Grayscale colors are in the range 232-255
+       return RGBPrefix + fmt.Sprint(v%24+232) + RGBSuffix + text + Reset
+}
index a5353d1c83a356f81f115af116fca810a7349bf0..ea204455c80a35d4d6f7fafc2f2acfed8486fdff 100644 (file)
@@ -2,6 +2,12 @@ package main
 
 import (
        "flag"
+       "fmt"
+       "strings"
+       "time"
+
+       // Packages
+       whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
 )
 
 ///////////////////////////////////////////////////////////////////////////////
@@ -42,6 +48,26 @@ func (flags *Flags) GetLanguage() string {
        return flags.Lookup("language").Value.String()
 }
 
+func (flags *Flags) IsTranslate() bool {
+       return flags.Lookup("translate").Value.(flag.Getter).Get().(bool)
+}
+
+func (flags *Flags) GetOffset() time.Duration {
+       return flags.Lookup("offset").Value.(flag.Getter).Get().(time.Duration)
+}
+
+func (flags *Flags) GetDuration() time.Duration {
+       return flags.Lookup("duration").Value.(flag.Getter).Get().(time.Duration)
+}
+
+func (flags *Flags) GetThreads() uint {
+       return flags.Lookup("threads").Value.(flag.Getter).Get().(uint)
+}
+
+func (flags *Flags) GetOut() string {
+       return strings.ToLower(flags.Lookup("out").Value.String())
+}
+
 func (flags *Flags) IsSpeedup() bool {
        return flags.Lookup("speedup").Value.String() == "true"
 }
@@ -50,12 +76,81 @@ func (flags *Flags) IsTokens() bool {
        return flags.Lookup("tokens").Value.String() == "true"
 }
 
+func (flags *Flags) IsColorize() bool {
+       return flags.Lookup("colorize").Value.String() == "true"
+}
+
+func (flags *Flags) GetMaxLen() uint {
+       return flags.Lookup("max-len").Value.(flag.Getter).Get().(uint)
+}
+
+func (flags *Flags) GetMaxTokens() uint {
+       return flags.Lookup("max-tokens").Value.(flag.Getter).Get().(uint)
+}
+
+func (flags *Flags) GetWordThreshold() float32 {
+       return float32(flags.Lookup("word-thold").Value.(flag.Getter).Get().(float64))
+}
+
+func (flags *Flags) SetParams(context whisper.Context) error {
+       if lang := flags.GetLanguage(); lang != "" && lang != "auto" {
+               fmt.Fprintf(flags.Output(), "Setting language to %q\n", lang)
+               if err := context.SetLanguage(lang); err != nil {
+                       return err
+               }
+       }
+       if flags.IsTranslate() && context.IsMultilingual() {
+               fmt.Fprintf(flags.Output(), "Setting translate to true\n")
+               context.SetTranslate(true)
+       }
+       if offset := flags.GetOffset(); offset != 0 {
+               fmt.Fprintf(flags.Output(), "Setting offset to %v\n", offset)
+               context.SetOffset(offset)
+       }
+       if duration := flags.GetDuration(); duration != 0 {
+               fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration)
+               context.SetDuration(duration)
+       }
+       if flags.IsSpeedup() {
+               fmt.Fprintf(flags.Output(), "Setting speedup to true\n")
+               context.SetSpeedup(true)
+       }
+       if threads := flags.GetThreads(); threads != 0 {
+               fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads)
+               context.SetThreads(threads)
+       }
+       if max_len := flags.GetMaxLen(); max_len != 0 {
+               fmt.Fprintf(flags.Output(), "Setting max_segment_length to %d\n", max_len)
+               context.SetMaxSegmentLength(max_len)
+       }
+       if max_tokens := flags.GetMaxTokens(); max_tokens != 0 {
+               fmt.Fprintf(flags.Output(), "Setting max_tokens to %d\n", max_tokens)
+               context.SetMaxTokensPerSegment(max_tokens)
+       }
+       if word_threshold := flags.GetWordThreshold(); word_threshold != 0 {
+               fmt.Fprintf(flags.Output(), "Setting word_threshold to %f\n", word_threshold)
+               context.SetTokenThreshold(word_threshold)
+       }
+
+       // Return success
+       return nil
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // PRIVATE METHODS
 
 func registerFlags(flag *Flags) {
        flag.String("model", "", "Path to the model file")
-       flag.String("language", "", "Language")
+       flag.String("language", "", "Spoken language")
+       flag.Bool("translate", false, "Translate from source language to english")
+       flag.Duration("offset", 0, "Time offset")
+       flag.Duration("duration", 0, "Duration of audio to process")
+       flag.Uint("threads", 0, "Number of threads to use")
        flag.Bool("speedup", false, "Enable speedup")
+       flag.Uint("max-len", 0, "Maximum segment length in characters")
+       flag.Uint("max-tokens", 0, "Maximum tokens per segment")
+       flag.Float64("word-thold", 0, "Maximum segment score")
        flag.Bool("tokens", false, "Display tokens")
+       flag.Bool("colorize", false, "Colorize tokens")
+       flag.String("out", "", "Output format (srt, none or leave as empty string)")
 }
index b3a89db755276842a21b892b812616cfad9c06a7..1bff7f5d50ae7dbe0dbfdffb40c0d301973bc328 100644 (file)
@@ -35,8 +35,7 @@ func main() {
 
        // Process files
        for _, filename := range flags.Args() {
-               fmt.Println("Processing", filename)
-               if err := Process(model, filename, flags.GetLanguage(), flags.IsSpeedup(), flags.IsTokens()); err != nil {
+               if err := Process(model, filename, flags); err != nil {
                        fmt.Fprintln(os.Stderr, err)
                        continue
                }
index a0e2be86c9b01bb2ac9b1e557ab7097de790211d..aacdc6965be9617c97c3ec346f9c3f53d7494daf 100644 (file)
@@ -11,7 +11,7 @@ import (
        wav "github.com/go-audio/wav"
 )
 
-func Process(model whisper.Model, path string, lang string, speedup, tokens bool) error {
+func Process(model whisper.Model, path string, flags *Flags) error {
        var data []float32
 
        // Create processing context
@@ -20,14 +20,20 @@ func Process(model whisper.Model, path string, lang string, speedup, tokens bool
                return err
        }
 
+       // Set the parameters
+       if err := flags.SetParams(context); err != nil {
+               return err
+       }
+
        // Open the file
+       fmt.Fprintf(flags.Output(), "Loading %q\n", path)
        fh, err := os.Open(path)
        if err != nil {
                return err
        }
        defer fh.Close()
 
-       // Decode the WAV file
+       // Decode the WAV file - load the full buffer
        dec := wav.NewDecoder(fh)
        if buf, err := dec.FullPCMBuffer(); err != nil {
                return err
@@ -39,42 +45,83 @@ func Process(model whisper.Model, path string, lang string, speedup, tokens bool
                data = buf.AsFloat32Buffer().Data
        }
 
-       // Set the parameters
+       // Segment callback when -tokens is specified
        var cb whisper.SegmentCallback
-       if lang != "" {
-               if err := context.SetLanguage(lang); err != nil {
-                       return err
-               }
-       }
-       if speedup {
-               context.SetSpeedup(true)
-       }
-       if tokens {
+       if flags.IsTokens() {
                cb = func(segment whisper.Segment) {
-                       fmt.Printf("%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
+                       fmt.Fprintf(flags.Output(), "%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
                        for _, token := range segment.Tokens {
-                               fmt.Printf("%q ", token.Text)
+                               if flags.IsColorize() && context.IsText(token) {
+                                       fmt.Fprint(flags.Output(), Colorize(token.Text, int(token.P*24.0)), " ")
+                               } else {
+                                       fmt.Fprint(flags.Output(), token.Text, " ")
+                               }
                        }
-                       fmt.Println("")
+                       fmt.Fprintln(flags.Output(), "")
+                       fmt.Fprintln(flags.Output(), "")
                }
        }
 
        // Process the data
+       fmt.Fprintf(flags.Output(), "  ...processing %q\n", path)
        if err := context.Process(data, cb); err != nil {
                return err
        }
 
        // Print out the results
+       switch {
+       case flags.GetOut() == "srt":
+               return OutputSRT(os.Stdout, context)
+       case flags.GetOut() == "none":
+               return nil
+       default:
+               return Output(os.Stdout, context, flags.IsColorize())
+       }
+}
+
+// Output text as SRT file
+func OutputSRT(w io.Writer, context whisper.Context) error {
+       n := 1
        for {
                segment, err := context.NextSegment()
                if err == io.EOF {
-                       break
+                       return nil
                } else if err != nil {
                        return err
                }
-               fmt.Printf("[%6s->%6s] %s\n", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond), segment.Text)
+               fmt.Fprintln(w, n)
+               fmt.Fprintln(w, srtTimestamp(segment.Start), " --> ", srtTimestamp(segment.End))
+               fmt.Fprintln(w, segment.Text)
+               fmt.Fprintln(w, "")
+               n++
        }
+}
+
+// Output text to terminal
+func Output(w io.Writer, context whisper.Context, colorize bool) error {
+       for {
+               segment, err := context.NextSegment()
+               if err == io.EOF {
+                       return nil
+               } else if err != nil {
+                       return err
+               }
+               fmt.Fprintf(w, "[%6s->%6s]", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
+               if colorize {
+                       for _, token := range segment.Tokens {
+                               if !context.IsText(token) {
+                                       continue
+                               }
+                               fmt.Fprint(w, " ", Colorize(token.Text, int(token.P*24.0)))
+                       }
+                       fmt.Fprint(w, "\n")
+               } else {
+                       fmt.Fprintln(w, " ", segment.Text)
+               }
+       }
+}
 
-       // Return success
-       return nil
+// Return srtTimestamp
+func srtTimestamp(t time.Duration) string {
+       return fmt.Sprintf("%02d:%02d:%02d,%03d", t/time.Hour, (t%time.Hour)/time.Minute, (t%time.Minute)/time.Second, (t%time.Second)/time.Millisecond)
 }
index c67a7299b8557bb57fdd980abd4fecee796840c3..d7dc238f5ad03f2026ae639a1ed2b4917ed5eb47 100644 (file)
@@ -47,6 +47,7 @@ func (p *Params) SetSpeedup(v bool) {
        p.speed_up = toBool(v)
 }
 
+// Set language id
 func (p *Params) SetLanguage(lang int) error {
        str := C.whisper_lang_str(C.int(lang))
        if str == nil {
@@ -57,6 +58,7 @@ func (p *Params) SetLanguage(lang int) error {
        return nil
 }
 
+// Get language id
 func (p *Params) Language() int {
        if p.language == nil {
                return -1
@@ -64,18 +66,41 @@ func (p *Params) Language() int {
        return int(C.whisper_lang_id(p.language))
 }
 
+// Set number of threads to use
 func (p *Params) SetThreads(threads int) {
        p.n_threads = C.int(threads)
 }
 
+// Set start offset in ms
 func (p *Params) SetOffset(offset_ms int) {
        p.offset_ms = C.int(offset_ms)
 }
 
+// Set audio duration to process in ms
 func (p *Params) SetDuration(duration_ms int) {
        p.duration_ms = C.int(duration_ms)
 }
 
+// Set timestamp token probability threshold (~0.01)
+func (p *Params) SetTokenThreshold(t float32) {
+       p.thold_pt = C.float(t)
+}
+
+// Set timestamp token sum probability threshold (~0.01)
+func (p *Params) SetTokenSumThreshold(t float32) {
+       p.thold_ptsum = C.float(t)
+}
+
+// Set max segment length in characters
+func (p *Params) SetMaxSegmentLength(n int) {
+       p.max_len = C.int(n)
+}
+
+// Set max tokens per segment (0 = no limit)
+func (p *Params) SetMaxTokensPerSegment(n int) {
+       p.max_tokens = C.int(n)
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // PRIVATE METHODS
 
index 710073f08e2a51869a381f9ad9482c33fa15d4c5..5c22dc13a3108709d091290914e5e7bf2488e5fa 100644 (file)
@@ -11,10 +11,11 @@ import (
 // ERRORS
 
 var (
-       ErrUnableToLoadModel   = errors.New("unable to load model")
-       ErrInternalAppError    = errors.New("internal application error")
-       ErrProcessingFailed    = errors.New("processing failed")
-       ErrUnsupportedLanguage = errors.New("unsupported language")
+       ErrUnableToLoadModel    = errors.New("unable to load model")
+       ErrInternalAppError     = errors.New("internal application error")
+       ErrProcessingFailed     = errors.New("processing failed")
+       ErrUnsupportedLanguage  = errors.New("unsupported language")
+       ErrModelNotMultilingual = errors.New("model is not multilingual")
 )
 
 ///////////////////////////////////////////////////////////////////////////////
index baff611c81343e5d8477c029b8298bfc97f904c3..5dda57e97e63c91ccfdfd960c564f7bb0f4853a7 100644 (file)
@@ -24,7 +24,7 @@ var _ Context = (*context)(nil)
 ///////////////////////////////////////////////////////////////////////////////
 // LIFECYCLE
 
-func NewContext(model *model, params whisper.Params) (Context, error) {
+func newContext(model *model, params whisper.Params) (Context, error) {
        context := new(context)
        context.model = model
        context.params = params
@@ -41,6 +41,9 @@ func (context *context) SetLanguage(lang string) error {
        if context.model.ctx == nil {
                return ErrInternalAppError
        }
+       if !context.model.IsMultilingual() {
+               return ErrModelNotMultilingual
+       }
        if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
                return ErrUnsupportedLanguage
        } else if err := context.params.SetLanguage(id); err != nil {
@@ -50,16 +53,60 @@ func (context *context) SetLanguage(lang string) error {
        return nil
 }
 
+func (context *context) IsMultilingual() bool {
+       return context.model.IsMultilingual()
+}
+
 // Get language
 func (context *context) Language() string {
        return whisper.Whisper_lang_str(context.params.Language())
 }
 
+// Set translate flag
+func (context *context) SetTranslate(v bool) {
+       context.params.SetTranslate(v)
+}
+
 // Set speedup flag
 func (context *context) SetSpeedup(v bool) {
        context.params.SetSpeedup(v)
 }
 
+// Set number of threads to use
+func (context *context) SetThreads(v uint) {
+       context.params.SetThreads(int(v))
+}
+
+// Set time offset
+func (context *context) SetOffset(v time.Duration) {
+       context.params.SetOffset(int(v.Milliseconds()))
+}
+
+// Set duration of audio to process
+func (context *context) SetDuration(v time.Duration) {
+       context.params.SetOffset(int(v.Milliseconds()))
+}
+
+// Set timestamp token probability threshold (~0.01)
+func (context *context) SetTokenThreshold(t float32) {
+       context.params.SetTokenThreshold(t)
+}
+
+// Set timestamp token sum probability threshold (~0.01)
+func (context *context) SetTokenSumThreshold(t float32) {
+       context.params.SetTokenSumThreshold(t)
+}
+
+// Set max segment length in characters
+func (context *context) SetMaxSegmentLength(n uint) {
+       context.params.SetMaxSegmentLength(int(n))
+}
+
+// Set max tokens per segment (0 = no limit)
+func (context *context) SetMaxTokensPerSegment(n uint) {
+       context.params.SetMaxTokensPerSegment(int(n))
+}
+
 // Process new sample data and return any errors
 func (context *context) Process(data []float32, cb SegmentCallback) error {
        if context.model.ctx == nil {
@@ -119,6 +166,65 @@ func (context *context) NextSegment() (Segment, error) {
        return result, nil
 }
 
+// Test for text tokens
+func (context *context) IsText(t Token) bool {
+       switch {
+       case context.IsBEG(t):
+               return false
+       case context.IsSOT(t):
+               return false
+       case whisper.Token(t.Id) >= context.model.ctx.Whisper_token_eot():
+               return false
+       case context.IsPREV(t):
+               return false
+       case context.IsSOLM(t):
+               return false
+       case context.IsNOT(t):
+               return false
+       default:
+               return true
+       }
+}
+
+// Test for "begin" token
+func (context *context) IsBEG(t Token) bool {
+       return whisper.Token(t.Id) == context.model.ctx.Whisper_token_beg()
+}
+
+// Test for "start of transcription" token
+func (context *context) IsSOT(t Token) bool {
+       return whisper.Token(t.Id) == context.model.ctx.Whisper_token_sot()
+}
+
+// Test for "end of transcription" token
+func (context *context) IsEOT(t Token) bool {
+       return whisper.Token(t.Id) == context.model.ctx.Whisper_token_eot()
+}
+
+// Test for "start of prev" token
+func (context *context) IsPREV(t Token) bool {
+       return whisper.Token(t.Id) == context.model.ctx.Whisper_token_prev()
+}
+
+// Test for "start of lm" token
+func (context *context) IsSOLM(t Token) bool {
+       return whisper.Token(t.Id) == context.model.ctx.Whisper_token_solm()
+}
+
+// Test for "No timestamps" token
+func (context *context) IsNOT(t Token) bool {
+       return whisper.Token(t.Id) == context.model.ctx.Whisper_token_not()
+}
+
+// Test for token associated with a specific language
+func (context *context) IsLANG(t Token, lang string) bool {
+       if id := context.model.ctx.Whisper_lang_id(lang); id >= 0 {
+               return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id)
+       } else {
+               return false
+       }
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // PRIVATE METHODS
 
index 53e4f3f0257cb207911a31901bee2f3c2944731d..5ca913a8f721defbc984c0d977917a52d8900c65 100644 (file)
@@ -20,6 +20,9 @@ type Model interface {
        // Return a new speech-to-text context.
        NewContext() (Context, error)
 
+       // Return true if the model is multilingual.
+       IsMultilingual() bool
+
        // Return all languages supported.
        Languages() []string
 }
@@ -27,8 +30,18 @@ type Model interface {
 // Context is the speach recognition context.
 type Context interface {
        SetLanguage(string) error // Set the language to use for speech recognition.
+       SetTranslate(bool)        // Set translate flag
+       IsMultilingual() bool     // Return true if the model is multilingual.
        Language() string         // Get language
-       SetSpeedup(bool)          // Set speedup flag
+
+       SetOffset(time.Duration)      // Set offset
+       SetDuration(time.Duration)    // Set duration
+       SetThreads(uint)              // Set number of threads to use
+       SetSpeedup(bool)              // Set speedup flag
+       SetTokenThreshold(float32)    // Set timestamp token probability threshold
+       SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
+       SetMaxSegmentLength(uint)     // Set max segment length in characters
+       SetMaxTokensPerSegment(uint)  // Set max tokens per segment (0 = no limit)
 
        // Process mono audio data and return any errors.
        // If defined, newly generated segments are passed to the
@@ -38,6 +51,15 @@ type Context interface {
        // After process is called, return segments until the end of the stream
        // is reached, when io.EOF is returned.
        NextSegment() (Segment, error)
+
+       IsBEG(Token) bool          // Test for "begin" token
+       IsSOT(Token) bool          // Test for "start of transcription" token
+       IsEOT(Token) bool          // Test for "end of transcription" token
+       IsPREV(Token) bool         // Test for "start of prev" token
+       IsSOLM(Token) bool         // Test for "start of lm" token
+       IsNOT(Token) bool          // Test for "No timestamps" token
+       IsLANG(Token, string) bool // Test for token associated with a specific language
+       IsText(Token) bool         // Test for text token
 }
 
 // Segment is the text result of a speech recognition.
index 13cb52ca7ecb433643e5757fd336f139a28e5d12..94c2197db739a3ee0e797faaf9085d84b5cbe4a5 100644 (file)
@@ -23,7 +23,7 @@ var _ Model = (*model)(nil)
 ///////////////////////////////////////////////////////////////////////////////
 // LIFECYCLE
 
-func New(path string) (*model, error) {
+func New(path string) (Model, error) {
        model := new(model)
        if _, err := os.Stat(path); err != nil {
                return nil, err
@@ -64,6 +64,11 @@ func (model *model) String() string {
 ///////////////////////////////////////////////////////////////////////////////
 // PUBLIC METHODS
 
+// Return true if model is multilingual (language and translation options are supported)
+func (model *model) IsMultilingual() bool {
+       return model.ctx.Whisper_is_multilingual() != 0
+}
+
 // Return all recognized languages. Initially it is set to auto-detect
 func (model *model) Languages() []string {
        result := make([]string, 0, whisper.Whisper_lang_max_id())
@@ -91,5 +96,5 @@ func (model *model) NewContext() (Context, error) {
        params.SetThreads(runtime.NumCPU())
 
        // Return new context
-       return NewContext(model, params)
+       return newContext(model, params)
 }