]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
bindings : initial import of golang bindings (#287)
authorDavid Thorpe <redacted>
Tue, 20 Dec 2022 06:54:33 +0000 (07:54 +0100)
committerGitHub <redacted>
Tue, 20 Dec 2022 06:54:33 +0000 (08:54 +0200)
* Initial import of golang bindings

* Updated makefile rules

* Updated bindings

* Makefile update to add in more tests

21 files changed:
bindings/go/.gitignore [new file with mode: 0755]
bindings/go/LICENSE [new file with mode: 0755]
bindings/go/Makefile [new file with mode: 0755]
bindings/go/README.md [new file with mode: 0755]
bindings/go/doc.go [new file with mode: 0644]
bindings/go/examples/go-model-download/context.go [new file with mode: 0755]
bindings/go/examples/go-model-download/main.go [new file with mode: 0755]
bindings/go/examples/go-whisper/flags.go [new file with mode: 0755]
bindings/go/examples/go-whisper/main.go [new file with mode: 0755]
bindings/go/examples/go-whisper/process.go [new file with mode: 0755]
bindings/go/go.mod [new file with mode: 0755]
bindings/go/params.go [new file with mode: 0644]
bindings/go/pkg/whisper/consts.go [new file with mode: 0755]
bindings/go/pkg/whisper/context.go [new file with mode: 0755]
bindings/go/pkg/whisper/context_test.go [new file with mode: 0755]
bindings/go/pkg/whisper/doc.go [new file with mode: 0755]
bindings/go/pkg/whisper/interface.go [new file with mode: 0755]
bindings/go/pkg/whisper/model.go [new file with mode: 0755]
bindings/go/samples/jfk.wav [new file with mode: 0755]
bindings/go/whisper.go [new file with mode: 0644]
bindings/go/whisper_test.go [new file with mode: 0644]

diff --git a/bindings/go/.gitignore b/bindings/go/.gitignore
new file mode 100755 (executable)
index 0000000..b4e1084
--- /dev/null
@@ -0,0 +1,3 @@
+build
+models
+go.sum
diff --git a/bindings/go/LICENSE b/bindings/go/LICENSE
new file mode 100755 (executable)
index 0000000..a8f0d7b
--- /dev/null
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 David Thorpe
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/bindings/go/Makefile b/bindings/go/Makefile
new file mode 100755 (executable)
index 0000000..3374212
--- /dev/null
@@ -0,0 +1,38 @@
+CMAKE := $(shell which cmake)
+BUILD_DIR := "build"
+MODELS_DIR := "models"
+EXAMPLES_DIR := $(wildcard examples/*)
+C_INCLUDE_PATH := "../.."
+
+all: clean whisper examples
+
+whisper: mkdir
+       @echo Build whisper
+       @${CMAKE} -S ../.. -B ${BUILD_DIR} -D BUILD_SHARED_LIBS=off -D WHISPER_NO_AVX2=on
+       @${CMAKE} --build ${BUILD_DIR} --target whisper
+
+test: model-small whisper
+       @go mod tidy
+       @go test -v .
+       @go test -v ./pkg/whisper/...
+
+examples: $(EXAMPLES_DIR)
+
+model-small: mkdir examples/go-model-download
+       @${BUILD_DIR}/go-model-download -out models small.en
+
+$(EXAMPLES_DIR): mkdir whisper 
+       @echo Build example $(notdir $@)
+       @go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
+
+mkdir:
+       @echo Mkdir ${BUILD_DIR}
+       @install -d ${BUILD_DIR}
+       @echo Mkdir ${MODELS_DIR}
+       @install -d ${MODELS_DIR}
+
+clean: 
+       @echo Clean
+       @rm -fr $(BUILD_DIR)
+       @go mod tidy
+       @go clean
diff --git a/bindings/go/README.md b/bindings/go/README.md
new file mode 100755 (executable)
index 0000000..8ae89c7
--- /dev/null
@@ -0,0 +1,77 @@
+# Go bindings for Whisper
+
+This package provides Go bindings for whisper.cpp. They have been tested on:
+
+  * Darwin (OS X) 12.6 on x64_64
+  * Debian Linux on arm64
+  * Fedora Linux on x86_64
+
+The "low level" bindings are in the `bindings/go` directory and there is a more
+Go-style package in the `bindings/go/pkg/whisper` directory. The most simple usage
+is as follows:
+
+```go
+import (
+       "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
+)
+
+func main() {
+       var modelpath string // Path to the model
+       var samples []float32 // Samples to process
+
+       // Load the model
+       model, err := whisper.New(modelpath)
+       if err != nil {
+               panic(err)
+       }
+       defer model.Close()
+
+       // Process samples
+       context, err := model.NewContext()
+       if err != nil {
+               panic(err)
+       }
+       if err := context.Process(samples, nil); err != nil {
+               return err
+       }
+
+       // Print out the results
+       for {
+               segment, err := context.NextSegment()
+               if err != nil {
+                       break
+               }
+               fmt.Printf("[%6s->%6s] %s\n", segment.Start, segment.End, segment.Text)
+       }
+}
+```
+
+## Building & Testing
+
+In order to build, you need to have the Go compiler installed. You can get it from [here](https://golang.org/dl/). Run the tests with:
+
+```bash
+git clone https://github.com/ggerganov/whisper.cpp.git
+cd whisper.cpp/bindings/go
+make test
+```
+
+This will compile a static `libwhisper.a` in a `build` folder, download a model file, then run the tests. To build the examples:
+
+```bash
+make examples
+```
+
+The examples are placed in the `build` directory. Once built, you can download all the models with the following command:
+
+```bash
+./build/go-model-download -out models
+```
+
+And you can then test a model against samples with the following command:
+
+```bash
+./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav 
+```
+
+
diff --git a/bindings/go/doc.go b/bindings/go/doc.go
new file mode 100644 (file)
index 0000000..dcc351f
--- /dev/null
@@ -0,0 +1,5 @@
+/*
+github.com/ggerganov/whisper.cpp/bindings/go
+provides a speech-to-text service bindings for the Go programming language.
+*/
+package whisper
diff --git a/bindings/go/examples/go-model-download/context.go b/bindings/go/examples/go-model-download/context.go
new file mode 100755 (executable)
index 0000000..639d8f5
--- /dev/null
@@ -0,0 +1,30 @@
+package main
+
+import (
+       "context"
+       "os"
+       "os/signal"
+)
+
+// ContextForSignal returns a context object which is cancelled when a signal
+// is received. It returns nil if no signal parameter is provided
+func ContextForSignal(signals ...os.Signal) context.Context {
+       if len(signals) == 0 {
+               return nil
+       }
+
+       ch := make(chan os.Signal)
+       ctx, cancel := context.WithCancel(context.Background())
+
+       // Send message on channel when signal received
+       signal.Notify(ch, signals...)
+
+       // When any signal received, call cancel
+       go func() {
+               <-ch
+               cancel()
+       }()
+
+       // Return success
+       return ctx
+}
diff --git a/bindings/go/examples/go-model-download/main.go b/bindings/go/examples/go-model-download/main.go
new file mode 100755 (executable)
index 0000000..841a2c6
--- /dev/null
@@ -0,0 +1,206 @@
+package main
+
+import (
+       "context"
+       "flag"
+       "fmt"
+       "io"
+       "net/http"
+       "net/url"
+       "os"
+       "path/filepath"
+       "syscall"
+       "time"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// CONSTANTS
+
+const (
+       srcUrl        = "https://huggingface.co/"                           // The location of the models
+       srcPathPrefix = "/datasets/ggerganov/whisper.cpp/resolve/main/ggml" // Filename prefix
+       srcExt        = ".bin"                                              // Filename extension
+       bufSize       = 1024 * 64                                           // Size of the buffer used for downloading the model
+)
+
+var (
+       // The models which will be downloaded, if no model is specified as an argument
+       modelNames = []string{"tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium", "large-v1", "large"}
+)
+
+var (
+       // The output folder. When not set, use current working directory.
+       flagOut = flag.String("out", "", "Output folder")
+
+       // HTTP timeout parameter - will timeout if takes longer than this to download a model
+       flagTimeout = flag.Duration("timeout", 30*time.Minute, "HTTP timeout")
+
+       // Quiet parameter - will not print progress if set
+       flagQuiet = flag.Bool("quiet", false, "Quiet mode")
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// MAIN
+
+func main() {
+       flag.Usage = func() {
+               name := filepath.Base(flag.CommandLine.Name())
+               fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] <model>\n\n", name)
+               flag.PrintDefaults()
+       }
+       flag.Parse()
+
+       // Get output path
+       out, err := GetOut()
+       if err != nil {
+               fmt.Fprintln(os.Stderr, "Error:", err)
+               os.Exit(-1)
+       }
+
+       // Create context which quits on SIGINT or SIGQUIT
+       ctx := ContextForSignal(os.Interrupt, syscall.SIGQUIT)
+
+       // Progress filehandle
+       progress := os.Stdout
+       if *flagQuiet {
+               progress, err = os.Open(os.DevNull)
+               if err != nil {
+                       fmt.Fprintln(os.Stderr, "Error:", err)
+                       os.Exit(-1)
+               }
+               defer progress.Close()
+       }
+
+       // Download models - exit on error or interrupt
+       for _, model := range GetModels() {
+               url, err := URLForModel(model)
+               if err != nil {
+                       fmt.Fprintln(os.Stderr, "Error:", err)
+                       continue
+               } else if path, err := Download(ctx, progress, url, out); err == nil || err == io.EOF {
+                       continue
+               } else if err == context.Canceled {
+                       os.Remove(path)
+                       fmt.Fprintln(progress, "\nInterrupted")
+                       break
+               } else if err == context.DeadlineExceeded {
+                       os.Remove(path)
+                       fmt.Fprintln(progress, "Timeout downloading model")
+                       continue
+               } else {
+                       os.Remove(path)
+                       fmt.Fprintln(os.Stderr, "Error:", err)
+                       break
+               }
+       }
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS
+
+// GetOut returns the path to the output directory
+func GetOut() (string, error) {
+       if *flagOut == "" {
+               return os.Getwd()
+       }
+       if info, err := os.Stat(*flagOut); err != nil {
+               return "", err
+       } else if !info.IsDir() {
+               return "", fmt.Errorf("not a directory: %s", info.Name())
+       } else {
+               return *flagOut, nil
+       }
+}
+
+// GetModels returns the list of models to download
+func GetModels() []string {
+       if flag.NArg() == 0 {
+               return modelNames
+       } else {
+               return flag.Args()
+       }
+}
+
+// URLForModel returns the URL for the given model on huggingface.co
+func URLForModel(model string) (string, error) {
+       url, err := url.Parse(srcUrl)
+       if err != nil {
+               return "", err
+       } else {
+               url.Path = srcPathPrefix + "-" + model + srcExt
+       }
+       return url.String(), nil
+}
+
+// Download downloads the model from the given URL to the given output directory
+func Download(ctx context.Context, p io.Writer, model, out string) (string, error) {
+       // Create HTTP client
+       client := http.Client{
+               Timeout: *flagTimeout,
+       }
+
+       // Initiate the download
+       req, err := http.NewRequest("GET", model, nil)
+       if err != nil {
+               return "", err
+       }
+       resp, err := client.Do(req)
+       if err != nil {
+               return "", err
+       }
+       defer resp.Body.Close()
+       if resp.StatusCode != http.StatusOK {
+               return "", fmt.Errorf("%s: %s", model, resp.Status)
+       }
+
+       // If output file exists and is the same size as the model, skip
+       path := filepath.Join(out, filepath.Base(model))
+       if info, err := os.Stat(path); err == nil && info.Size() == resp.ContentLength {
+               fmt.Fprintln(p, "Skipping", model, "as it already exists")
+               return "", nil
+       }
+
+       // Create file
+       w, err := os.Create(path)
+       if err != nil {
+               return "", err
+       }
+       defer w.Close()
+
+       // Report
+       fmt.Fprintln(p, "Downloading", model, "to", out)
+
+       // Progressively download the model
+       data := make([]byte, bufSize)
+       count, pct := int64(0), int64(0)
+       ticker := time.NewTicker(5 * time.Second)
+       for {
+               select {
+               case <-ctx.Done():
+                       // Cancelled, return error
+                       return path, ctx.Err()
+               case <-ticker.C:
+                       pct = DownloadReport(p, pct, count, resp.ContentLength)
+               default:
+                       // Read body
+                       n, err := resp.Body.Read(data)
+                       if err != nil {
+                               DownloadReport(p, pct, count, resp.ContentLength)
+                               return path, err
+                       } else if m, err := w.Write(data[:n]); err != nil {
+                               return path, err
+                       } else {
+                               count += int64(m)
+                       }
+               }
+       }
+}
+
+// Report periodically reports the download progress when percentage changes
+func DownloadReport(w io.Writer, pct, count, total int64) int64 {
+       pct_ := count * 100 / total
+       if pct_ > pct {
+               fmt.Fprintf(w, "  ...%d MB written (%d%%)\n", count/1e6, pct_)
+       }
+       return pct_
+}
diff --git a/bindings/go/examples/go-whisper/flags.go b/bindings/go/examples/go-whisper/flags.go
new file mode 100755 (executable)
index 0000000..a5353d1
--- /dev/null
@@ -0,0 +1,61 @@
+package main
+
+import (
+       "flag"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// TYPES
+
+type Flags struct {
+       *flag.FlagSet
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// LIFECYCLE
+
+func NewFlags(name string, args []string) (*Flags, error) {
+       flags := &Flags{
+               FlagSet: flag.NewFlagSet(name, flag.ContinueOnError),
+       }
+
+       // Register the command line arguments
+       registerFlags(flags)
+
+       // Parse command line
+       if err := flags.Parse(args); err != nil {
+               return nil, err
+       }
+
+       // Return success
+       return flags, nil
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS
+
+func (flags *Flags) GetModel() string {
+       return flags.Lookup("model").Value.String()
+}
+
+func (flags *Flags) GetLanguage() string {
+       return flags.Lookup("language").Value.String()
+}
+
+func (flags *Flags) IsSpeedup() bool {
+       return flags.Lookup("speedup").Value.String() == "true"
+}
+
+func (flags *Flags) IsTokens() bool {
+       return flags.Lookup("tokens").Value.String() == "true"
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PRIVATE METHODS
+
+func registerFlags(flag *Flags) {
+       flag.String("model", "", "Path to the model file")
+       flag.String("language", "", "Language")
+       flag.Bool("speedup", false, "Enable speedup")
+       flag.Bool("tokens", false, "Display tokens")
+}
diff --git a/bindings/go/examples/go-whisper/main.go b/bindings/go/examples/go-whisper/main.go
new file mode 100755 (executable)
index 0000000..b3a89db
--- /dev/null
@@ -0,0 +1,44 @@
+package main
+
+import (
+       "flag"
+       "fmt"
+       "os"
+       "path/filepath"
+
+       // Packages
+       whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
+)
+
+func main() {
+       flags, err := NewFlags(filepath.Base(os.Args[0]), os.Args[1:])
+       if err == flag.ErrHelp {
+               os.Exit(0)
+       } else if err != nil {
+               fmt.Fprintln(os.Stderr, err)
+               os.Exit(1)
+       } else if flags.GetModel() == "" {
+               fmt.Fprintln(os.Stderr, "Use -model flag to specify which model file to use")
+               os.Exit(1)
+       } else if flags.NArg() == 0 {
+               fmt.Fprintln(os.Stderr, "No input files specified")
+               os.Exit(1)
+       }
+
+       // Load model
+       model, err := whisper.New(flags.GetModel())
+       if err != nil {
+               fmt.Fprintln(os.Stderr, err)
+               os.Exit(1)
+       }
+       defer model.Close()
+
+       // Process files
+       for _, filename := range flags.Args() {
+               fmt.Println("Processing", filename)
+               if err := Process(model, filename, flags.GetLanguage(), flags.IsSpeedup(), flags.IsTokens()); err != nil {
+                       fmt.Fprintln(os.Stderr, err)
+                       continue
+               }
+       }
+}
diff --git a/bindings/go/examples/go-whisper/process.go b/bindings/go/examples/go-whisper/process.go
new file mode 100755 (executable)
index 0000000..a0e2be8
--- /dev/null
@@ -0,0 +1,80 @@
+package main
+
+import (
+       "fmt"
+       "io"
+       "os"
+       "time"
+
+       // Package imports
+       whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
+       wav "github.com/go-audio/wav"
+)
+
+func Process(model whisper.Model, path string, lang string, speedup, tokens bool) error {
+       var data []float32
+
+       // Create processing context
+       context, err := model.NewContext()
+       if err != nil {
+               return err
+       }
+
+       // Open the file
+       fh, err := os.Open(path)
+       if err != nil {
+               return err
+       }
+       defer fh.Close()
+
+       // Decode the WAV file
+       dec := wav.NewDecoder(fh)
+       if buf, err := dec.FullPCMBuffer(); err != nil {
+               return err
+       } else if dec.SampleRate != whisper.SampleRate {
+               return fmt.Errorf("unsupported sample rate: %d", dec.SampleRate)
+       } else if dec.NumChans != 1 {
+               return fmt.Errorf("unsupported number of channels: %d", dec.NumChans)
+       } else {
+               data = buf.AsFloat32Buffer().Data
+       }
+
+       // Set the parameters
+       var cb whisper.SegmentCallback
+       if lang != "" {
+               if err := context.SetLanguage(lang); err != nil {
+                       return err
+               }
+       }
+       if speedup {
+               context.SetSpeedup(true)
+       }
+       if tokens {
+               cb = func(segment whisper.Segment) {
+                       fmt.Printf("%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
+                       for _, token := range segment.Tokens {
+                               fmt.Printf("%q ", token.Text)
+                       }
+                       fmt.Println("")
+               }
+       }
+
+       // Process the data
+       if err := context.Process(data, cb); err != nil {
+               return err
+       }
+
+       // Print out the results
+       for {
+               segment, err := context.NextSegment()
+               if err == io.EOF {
+                       break
+               } else if err != nil {
+                       return err
+               }
+               fmt.Printf("[%6s->%6s] %s\n", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond), segment.Text)
+       }
+
+       // Return success
+       return nil
+}
diff --git a/bindings/go/go.mod b/bindings/go/go.mod
new file mode 100755 (executable)
index 0000000..594f184
--- /dev/null
@@ -0,0 +1,16 @@
+module github.com/ggerganov/whisper.cpp/bindings/go
+
+go 1.19
+
+require (
+       github.com/go-audio/wav v1.1.0
+       github.com/stretchr/testify v1.8.1
+)
+
+require (
+       github.com/davecgh/go-spew v1.1.1 // indirect
+       github.com/go-audio/audio v1.0.0 // indirect
+       github.com/go-audio/riff v1.0.0 // indirect
+       github.com/pmezard/go-difflib v1.0.0 // indirect
+       gopkg.in/yaml.v3 v3.0.1 // indirect
+)
diff --git a/bindings/go/params.go b/bindings/go/params.go
new file mode 100644 (file)
index 0000000..7f4c509
--- /dev/null
@@ -0,0 +1,134 @@
+package whisper
+
+// This file defines the whisper_token, whisper_token_data and whisper_full_params
+// structures, which are used by the whisper_full() function.
+
+import (
+       "fmt"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// CGO
+
+/*
+#include <whisper.h>
+*/
+import "C"
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS
+
+func (p *Params) SetTranslate(v bool) {
+       p.translate = toBool(v)
+}
+
+func (p *Params) SetNoContext(v bool) {
+       p.no_context = toBool(v)
+}
+
+func (p *Params) SetSingleSegment(v bool) {
+       p.single_segment = toBool(v)
+}
+
+func (p *Params) SetPrintSpecial(v bool) {
+       p.print_special = toBool(v)
+}
+
+func (p *Params) SetPrintProgress(v bool) {
+       p.print_progress = toBool(v)
+}
+
+func (p *Params) SetPrintRealtime(v bool) {
+       p.print_realtime = toBool(v)
+}
+
+func (p *Params) SetPrintTimestamps(v bool) {
+       p.print_timestamps = toBool(v)
+}
+
+func (p *Params) SetSpeedup(v bool) {
+       p.speed_up = toBool(v)
+}
+
+func (p *Params) SetLanguage(lang int) error {
+       str := C.whisper_lang_str(C.int(lang))
+       if str == nil {
+               return ErrInvalidLanguage
+       } else {
+               p.language = str
+       }
+       return nil
+}
+
+func (p *Params) Language() int {
+       if p.language == nil {
+               return -1
+       }
+       return int(C.whisper_lang_id(p.language))
+}
+
+func (p *Params) SetThreads(threads int) {
+       p.n_threads = C.int(threads)
+}
+
+func (p *Params) SetOffset(offset_ms int) {
+       p.offset_ms = C.int(offset_ms)
+}
+
+func (p *Params) SetDuration(duration_ms int) {
+       p.duration_ms = C.int(duration_ms)
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PRIVATE METHODS
+
+func toBool(v bool) C.bool {
+       if v {
+               return C.bool(true)
+       }
+       return C.bool(false)
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// STRINGIFY
+
+func (p *Params) String() string {
+       str := "<whisper.params"
+       str += fmt.Sprintf(" strategy=%v", p.strategy)
+       str += fmt.Sprintf(" n_threads=%d", p.n_threads)
+       if p.language != nil {
+               str += fmt.Sprintf(" language=%s", C.GoString(p.language))
+       }
+       str += fmt.Sprintf(" n_max_text_ctx=%d", p.n_max_text_ctx)
+       str += fmt.Sprintf(" offset_ms=%d", p.offset_ms)
+       str += fmt.Sprintf(" duration_ms=%d", p.duration_ms)
+       if p.translate {
+               str += " translate"
+       }
+       if p.no_context {
+               str += " no_context"
+       }
+       if p.single_segment {
+               str += " single_segment"
+       }
+       if p.print_special {
+               str += " print_special"
+       }
+       if p.print_progress {
+               str += " print_progress"
+       }
+       if p.print_realtime {
+               str += " print_realtime"
+       }
+       if p.print_timestamps {
+               str += " print_timestamps"
+       }
+       if p.token_timestamps {
+               str += " token_timestamps"
+       }
+       if p.speed_up {
+               str += " speed_up"
+       }
+
+       return str + ">"
+}
diff --git a/bindings/go/pkg/whisper/consts.go b/bindings/go/pkg/whisper/consts.go
new file mode 100755 (executable)
index 0000000..710073f
--- /dev/null
@@ -0,0 +1,27 @@
+package whisper
+
+import (
+       "errors"
+
+       // Bindings
+       whisper "github.com/ggerganov/whisper.cpp/bindings/go"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// ERRORS
+
+var (
+       ErrUnableToLoadModel   = errors.New("unable to load model")
+       ErrInternalAppError    = errors.New("internal application error")
+       ErrProcessingFailed    = errors.New("processing failed")
+       ErrUnsupportedLanguage = errors.New("unsupported language")
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// CONSTANTS
+
+// SampleRate is the sample rate of the audio data.
+const SampleRate = whisper.SampleRate
+
+// SampleBits is the number of bytes per sample.
+const SampleBits = whisper.SampleBits
diff --git a/bindings/go/pkg/whisper/context.go b/bindings/go/pkg/whisper/context.go
new file mode 100755 (executable)
index 0000000..baff611
--- /dev/null
@@ -0,0 +1,145 @@
+package whisper
+
+import (
+       "io"
+       "strings"
+       "time"
+
+       // Bindings
+       whisper "github.com/ggerganov/whisper.cpp/bindings/go"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// TYPES
+
+type context struct {
+       n      int
+       model  *model
+       params whisper.Params
+}
+
+// Make sure context adheres to the interface
+var _ Context = (*context)(nil)
+
+///////////////////////////////////////////////////////////////////////////////
+// LIFECYCLE
+
+func NewContext(model *model, params whisper.Params) (Context, error) {
+       context := new(context)
+       context.model = model
+       context.params = params
+
+       // Return success
+       return context, nil
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS
+
+// Set the language to use for speech recognition.
+func (context *context) SetLanguage(lang string) error {
+       if context.model.ctx == nil {
+               return ErrInternalAppError
+       }
+       if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
+               return ErrUnsupportedLanguage
+       } else if err := context.params.SetLanguage(id); err != nil {
+               return err
+       }
+       // Return success
+       return nil
+}
+
+// Get language
+func (context *context) Language() string {
+       return whisper.Whisper_lang_str(context.params.Language())
+}
+
+// Set speedup flag
+func (context *context) SetSpeedup(v bool) {
+       context.params.SetSpeedup(v)
+}
+
+// Process new sample data and return any errors
+func (context *context) Process(data []float32, cb SegmentCallback) error {
+       if context.model.ctx == nil {
+               return ErrInternalAppError
+       }
+       // If the callback is defined then we force on single_segment mode
+       if cb != nil {
+               context.params.SetSingleSegment(true)
+       }
+
+       // We don't do parallel processing at the moment
+       processors := 0
+       if processors > 1 {
+               if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
+                       if cb != nil {
+                               num_segments := context.model.ctx.Whisper_full_n_segments()
+                               s0 := num_segments - new
+                               for i := s0; i < num_segments; i++ {
+                                       cb(toSegment(context.model.ctx, i))
+                               }
+                       }
+               }); err != nil {
+                       return err
+               }
+       } else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
+               if cb != nil {
+                       num_segments := context.model.ctx.Whisper_full_n_segments()
+                       s0 := num_segments - new
+                       for i := s0; i < num_segments; i++ {
+                               cb(toSegment(context.model.ctx, i))
+                       }
+               }
+       }); err != nil {
+               return err
+       }
+
+       // Return success
+       return nil
+}
+
+// Return the next segment of tokens
+func (context *context) NextSegment() (Segment, error) {
+       if context.model.ctx == nil {
+               return Segment{}, ErrInternalAppError
+       }
+       if context.n >= context.model.ctx.Whisper_full_n_segments() {
+               return Segment{}, io.EOF
+       }
+
+       // Populate result
+       result := toSegment(context.model.ctx, context.n)
+
+       // Increment the cursor
+       context.n++
+
+       // Return success
+       return result, nil
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PRIVATE METHODS
+
+func toSegment(ctx *whisper.Context, n int) Segment {
+       return Segment{
+               Num:    n,
+               Text:   strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)),
+               Start:  time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
+               End:    time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
+               Tokens: toTokens(ctx, n),
+       }
+}
+
+func toTokens(ctx *whisper.Context, n int) []Token {
+       result := make([]Token, ctx.Whisper_full_n_tokens(n))
+       for i := 0; i < len(result); i++ {
+               result[i] = Token{
+                       Id:   int(ctx.Whisper_full_get_token_id(n, i)),
+                       Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)),
+                       P:    ctx.Whisper_full_get_token_p(n, i),
+               }
+       }
+       return result
+}
diff --git a/bindings/go/pkg/whisper/context_test.go b/bindings/go/pkg/whisper/context_test.go
new file mode 100755 (executable)
index 0000000..c8c6016
--- /dev/null
@@ -0,0 +1,55 @@
+package whisper_test
+
+import (
+       "os"
+       "testing"
+
+       // Packages
+       whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
+       assert "github.com/stretchr/testify/assert"
+)
+
+const (
+       ModelPath  = "../../models/ggml-tiny.bin"
+       SamplePath = "../../samples/jfk.wav"
+)
+
+func Test_Whisper_000(t *testing.T) {
+       assert := assert.New(t)
+       if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
+               t.Skip("Skipping test, model not found:", ModelPath)
+       }
+       if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
+               t.Skip("Skipping test, sample not found:", SamplePath)
+       }
+
+       // Load model
+       model, err := whisper.New(ModelPath)
+       assert.NoError(err)
+       assert.NotNil(model)
+       assert.NoError(model.Close())
+
+       t.Log("languages=", model.Languages())
+}
+
+func Test_Whisper_001(t *testing.T) {
+       assert := assert.New(t)
+       if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
+               t.Skip("Skipping test, model not found:", ModelPath)
+       }
+       if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
+               t.Skip("Skipping test, sample not found:", SamplePath)
+       }
+
+       // Load model
+       model, err := whisper.New(ModelPath)
+       assert.NoError(err)
+       assert.NotNil(model)
+       defer model.Close()
+
+       // Get context for decoding
+       ctx, err := model.NewContext()
+       assert.NoError(err)
+       assert.NotNil(ctx)
+
+}
diff --git a/bindings/go/pkg/whisper/doc.go b/bindings/go/pkg/whisper/doc.go
new file mode 100755 (executable)
index 0000000..fd4f1b9
--- /dev/null
@@ -0,0 +1,4 @@
+/*
+This is the higher-level speech-to-text whisper.cpp API for go
+*/
+package whisper
diff --git a/bindings/go/pkg/whisper/interface.go b/bindings/go/pkg/whisper/interface.go
new file mode 100755 (executable)
index 0000000..53e4f3f
--- /dev/null
@@ -0,0 +1,63 @@
+package whisper
+
+import (
+       "io"
+       "time"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// TYPES
+
+// SegmentCallback is the callback function for processing segments in real
+// time. It is called during the Process function
+type SegmentCallback func(Segment)
+
+// Model is the interface to a whisper model. Create a new model with the
+// function whisper.New(string)
+type Model interface {
+       io.Closer
+
+       // Return a new speech-to-text context.
+       NewContext() (Context, error)
+
+       // Return all languages supported.
+       Languages() []string
+}
+
+// Context is the speach recognition context.
+type Context interface {
+       SetLanguage(string) error // Set the language to use for speech recognition.
+       Language() string         // Get language
+       SetSpeedup(bool)          // Set speedup flag
+
+       // Process mono audio data and return any errors.
+       // If defined, newly generated segments are passed to the
+       // callback function during processing.
+       Process([]float32, SegmentCallback) error
+
+       // After process is called, return segments until the end of the stream
+       // is reached, when io.EOF is returned.
+       NextSegment() (Segment, error)
+}
+
+// Segment is the text result of a speech recognition.
+type Segment struct {
+       // Segment Number
+       Num int
+
+       // Time beginning and end timestamps for the segment.
+       Start, End time.Duration
+
+       // The text of the segment.
+       Text string
+
+       // The tokens of the segment.
+       Tokens []Token
+}
+
+// Token is a text or special token
+type Token struct {
+       Id   int
+       Text string
+       P    float32
+}
diff --git a/bindings/go/pkg/whisper/model.go b/bindings/go/pkg/whisper/model.go
new file mode 100755 (executable)
index 0000000..13cb52c
--- /dev/null
@@ -0,0 +1,95 @@
+package whisper
+
+import (
+       "fmt"
+       "os"
+       "runtime"
+
+       // Bindings
+       whisper "github.com/ggerganov/whisper.cpp/bindings/go"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// TYPES
+
+type model struct {
+       path string
+       ctx  *whisper.Context
+}
+
+// Make sure model adheres to the interface
+var _ Model = (*model)(nil)
+
+///////////////////////////////////////////////////////////////////////////////
+// LIFECYCLE
+
+func New(path string) (*model, error) {
+       model := new(model)
+       if _, err := os.Stat(path); err != nil {
+               return nil, err
+       } else if ctx := whisper.Whisper_init(path); ctx == nil {
+               return nil, ErrUnableToLoadModel
+       } else {
+               model.ctx = ctx
+               model.path = path
+       }
+
+       // Return success
+       return model, nil
+}
+
+func (model *model) Close() error {
+       if model.ctx != nil {
+               model.ctx.Whisper_free()
+       }
+
+       // Release resources
+       model.ctx = nil
+
+       // Return success
+       return nil
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// STRINGIFY
+
+func (model *model) String() string {
+       str := "<whisper.model"
+       if model.ctx != nil {
+               str += fmt.Sprintf(" model=%q", model.path)
+       }
+       return str + ">"
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS
+
+// Return all recognized languages. Initially it is set to auto-detect
+func (model *model) Languages() []string {
+       result := make([]string, 0, whisper.Whisper_lang_max_id())
+       for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
+               str := whisper.Whisper_lang_str(i)
+               if model.ctx.Whisper_lang_id(str) >= 0 {
+                       result = append(result, str)
+               }
+       }
+       return result
+}
+
+func (model *model) NewContext() (Context, error) {
+       if model.ctx == nil {
+               return nil, ErrInternalAppError
+       }
+
+       // Create new context
+       params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
+       params.SetTranslate(false)
+       params.SetPrintSpecial(false)
+       params.SetPrintProgress(false)
+       params.SetPrintRealtime(false)
+       params.SetPrintTimestamps(false)
+       params.SetThreads(runtime.NumCPU())
+
+       // Return new context
+       return NewContext(model, params)
+}
diff --git a/bindings/go/samples/jfk.wav b/bindings/go/samples/jfk.wav
new file mode 100755 (executable)
index 0000000..3184d37
Binary files /dev/null and b/bindings/go/samples/jfk.wav differ
diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go
new file mode 100644 (file)
index 0000000..2584f7b
--- /dev/null
@@ -0,0 +1,412 @@
+package whisper
+
+import (
+       "errors"
+       "unsafe"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// CGO
+
+/*
+#cgo CFLAGS: -I${SRCDIR}/../..
+#cgo LDFLAGS: -L${SRCDIR}/build -lwhisper -lm -lstdc++
+#cgo darwin LDFLAGS: -framework Accelerate
+#include <whisper.h>
+#include <stdlib.h>
+
+extern void callNewSegment(void* user_data, int new);
+extern bool callEncoderBegin(void* user_data);
+
+// Text segment callback
+// Called on every newly generated text segment
+// Use the whisper_full_...() functions to obtain the text segments
+static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
+    if(user_data != NULL && ctx != NULL) {
+        callNewSegment(user_data, n_new);
+    }
+}
+
+// Encoder begin callback
+// If not NULL, called before the encoder starts
+// If it returns false, the computation is aborted
+static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
+    if(user_data != NULL && ctx != NULL) {
+        return callEncoderBegin(user_data);
+    }
+    return false;
+}
+
+// Get default parameters and set callbacks
+static struct whisper_full_params whisper_full_default_params_cb(struct whisper_context* ctx, enum whisper_sampling_strategy strategy) {
+       struct whisper_full_params params = whisper_full_default_params(strategy);
+       params.new_segment_callback = whisper_new_segment_cb;
+       params.new_segment_callback_user_data = (void*)(ctx);
+       params.encoder_begin_callback = whisper_encoder_begin_cb;
+       params.encoder_begin_callback_user_data = (void*)(ctx);
+       return params;
+}
+*/
+import "C"
+
+///////////////////////////////////////////////////////////////////////////////
+// TYPES
+
+type (
+       Context          C.struct_whisper_context
+       Token            C.whisper_token
+       TokenData        C.struct_whisper_token_data
+       SamplingStrategy C.enum_whisper_sampling_strategy
+       Params           C.struct_whisper_full_params
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// GLOBALS
+
+const (
+       SAMPLING_GREEDY      SamplingStrategy = C.WHISPER_SAMPLING_GREEDY
+       SAMPLING_BEAM_SEARCH SamplingStrategy = C.WHISPER_SAMPLING_BEAM_SEARCH
+)
+
+const (
+       SampleRate = C.WHISPER_SAMPLE_RATE                 // Expected sample rate, samples per second
+       SampleBits = uint16(unsafe.Sizeof(C.float(0))) * 8 // Sample size in bits
+       NumFFT     = C.WHISPER_N_FFT
+       NumMEL     = C.WHISPER_N_MEL
+       HopLength  = C.WHISPER_HOP_LENGTH
+       ChunkSize  = C.WHISPER_CHUNK_SIZE
+)
+
+var (
+       ErrTokenizerFailed  = errors.New("whisper_tokenize failed")
+       ErrAutoDetectFailed = errors.New("whisper_lang_auto_detect failed")
+       ErrConversionFailed = errors.New("whisper_convert failed")
+       ErrInvalidLanguage  = errors.New("invalid language")
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS
+
+// Allocates all memory needed for the model and loads the model from the given file.
+// Returns NULL on failure.
+func Whisper_init(path string) *Context {
+       cPath := C.CString(path)
+       defer C.free(unsafe.Pointer(cPath))
+       if ctx := C.whisper_init(cPath); ctx != nil {
+               return (*Context)(ctx)
+       } else {
+               return nil
+       }
+}
+
+// Frees all memory allocated by the model.
+func (ctx *Context) Whisper_free() {
+       C.whisper_free((*C.struct_whisper_context)(ctx))
+}
+
+// Convert RAW PCM audio to log mel spectrogram.
+// The resulting spectrogram is stored inside the provided whisper context.
+func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error {
+       if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 {
+               return nil
+       } else {
+               return ErrConversionFailed
+       }
+}
+
+// This can be used to set a custom log mel spectrogram inside the provided whisper context.
+// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
+// n_mel must be 80
+func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error {
+       if C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 {
+               return nil
+       } else {
+               return ErrConversionFailed
+       }
+}
+
+// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
+// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
+// offset can be used to specify the offset of the first frame in the spectrogram.
+func (ctx *Context) Whisper_encode(offset, threads int) error {
+       if C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads)) == 0 {
+               return nil
+       } else {
+               return ErrConversionFailed
+       }
+}
+
+// Run the Whisper decoder to obtain the logits and probabilities for the next token.
+// Make sure to call whisper_encode() first.
+// tokens + n_tokens is the provided context for the decoder.
+// n_past is the number of tokens to use from previous decoder calls.
+func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error {
+       if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 {
+               return nil
+       } else {
+               return ErrConversionFailed
+       }
+}
+
+// whisper_sample_best() returns the token with the highest probability
+func (ctx *Context) Whisper_sample_best() TokenData {
+       return TokenData(C.whisper_sample_best((*C.struct_whisper_context)(ctx)))
+}
+
+// whisper_sample_timestamp() returns the most probable timestamp token
+func (ctx *Context) Whisper_sample_timestamp(is_initial bool) TokenData {
+       return TokenData(C.whisper_sample_timestamp((*C.struct_whisper_context)(ctx), C.bool(is_initial)))
+}
+
+// Convert the provided text into tokens. The tokens pointer must be large enough to hold the resulting tokens.
+// Returns the number of tokens on success
+func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
+       cText := C.CString(text)
+       defer C.free(unsafe.Pointer(cText))
+       if n := C.whisper_tokenize((*C.struct_whisper_context)(ctx), cText, (*C.whisper_token)(&tokens[0]), C.int(len(tokens))); n >= 0 {
+               return int(n), nil
+       } else {
+               return 0, ErrTokenizerFailed
+       }
+}
+
+// Return the id of the specified language, returns -1 if not found
+func (ctx *Context) Whisper_lang_id(lang string) int {
+       return int(C.whisper_lang_id(C.CString(lang)))
+}
+
+// Largest language id (i.e. number of available languages - 1)
+func Whisper_lang_max_id() int {
+       return int(C.whisper_lang_max_id())
+}
+
+// Return the short string of the specified language id (e.g. 2 -> "de"),
+// returns empty string if not found
+func Whisper_lang_str(id int) string {
+       return C.GoString(C.whisper_lang_str(C.int(id)))
+}
+
+// Use mel data at offset_ms to try and auto-detect the spoken language
+// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
+// Returns the probabilities of all languages.
+// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
+func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float32, error) {
+       probs := make([]float32, Whisper_lang_max_id()+1)
+       if n := int(C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 {
+               return nil, ErrAutoDetectFailed
+       } else {
+               return probs, nil
+       }
+}
+
+func (ctx *Context) Whisper_n_len() int {
+       return int(C.whisper_n_len((*C.struct_whisper_context)(ctx)))
+}
+
+func (ctx *Context) Whisper_n_vocab() int {
+       return int(C.whisper_n_vocab((*C.struct_whisper_context)(ctx)))
+}
+
+func (ctx *Context) Whisper_n_text_ctx() int {
+       return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
+}
+
+func (ctx *Context) Whisper_is_multilingual() int {
+       return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
+}
+
+// The probabilities for the next token
+//func (ctx *Whisper_context) Whisper_get_probs() []float32 {
+//     return (*[1 << 30]float32)(unsafe.Pointer(C.whisper_get_probs((*C.struct_whisper_context)(ctx))))[:ctx.Whisper_n_vocab()]
+//}
+
+// Token Id -> String. Uses the vocabulary in the provided context
+func (ctx *Context) Whisper_token_to_str(token Token) string {
+       return C.GoString(C.whisper_token_to_str((*C.struct_whisper_context)(ctx), C.whisper_token(token)))
+}
+
+// Special tokens
+func (ctx *Context) Whisper_token_eot() Token {
+       return Token(C.whisper_token_eot((*C.struct_whisper_context)(ctx)))
+}
+
+// Special tokens
+func (ctx *Context) Whisper_token_sot() Token {
+       return Token(C.whisper_token_sot((*C.struct_whisper_context)(ctx)))
+}
+
+// Special tokens
+func (ctx *Context) Whisper_token_prev() Token {
+       return Token(C.whisper_token_prev((*C.struct_whisper_context)(ctx)))
+}
+
+// Special tokens
+func (ctx *Context) Whisper_token_solm() Token {
+       return Token(C.whisper_token_solm((*C.struct_whisper_context)(ctx)))
+}
+
+// Special tokens
+func (ctx *Context) Whisper_token_not() Token {
+       return Token(C.whisper_token_not((*C.struct_whisper_context)(ctx)))
+}
+
+// Special tokens
+func (ctx *Context) Whisper_token_beg() Token {
+       return Token(C.whisper_token_beg((*C.struct_whisper_context)(ctx)))
+}
+
+// Special tokens
+func (ctx *Context) Whisper_token_lang(lang_id int) Token {
+       return Token(C.whisper_token_lang((*C.struct_whisper_context)(ctx), C.int(lang_id)))
+}
+
+// Task tokens
+func Whisper_token_translate() Token {
+       return Token(C.whisper_token_translate())
+}
+
+// Task tokens
+func Whisper_token_transcribe() Token {
+       return Token(C.whisper_token_transcribe())
+}
+
+// Performance information
+func (ctx *Context) Whisper_print_timings() {
+       C.whisper_print_timings((*C.struct_whisper_context)(ctx))
+}
+
+// Performance information
+func (ctx *Context) Whisper_reset_timings() {
+       C.whisper_reset_timings((*C.struct_whisper_context)(ctx))
+}
+
+// Print system information
+func Whisper_print_system_info() string {
+       return C.GoString(C.whisper_print_system_info())
+}
+
+// Return default parameters for a strategy
+func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Params {
+       // Get default parameters
+       return Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy)))
+}
+
+// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
+// Uses the specified decoding strategy to obtain the text.
+func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
+       registerEncoderBeginCallback(ctx, encoderBeginCallback)
+       registerNewSegmentCallback(ctx, newSegmentCallback)
+       defer registerEncoderBeginCallback(ctx, nil)
+       defer registerNewSegmentCallback(ctx, nil)
+       if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
+               return nil
+       } else {
+               return ErrConversionFailed
+       }
+}
+
+// Split the input audio in chunks and process each chunk separately using whisper_full()
+// It seems this approach can offer some speedup in some cases.
+// However, the transcription accuracy can be worse at the beginning and end of each chunk.
+func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
+       registerEncoderBeginCallback(ctx, encoderBeginCallback)
+       registerNewSegmentCallback(ctx, newSegmentCallback)
+       defer registerEncoderBeginCallback(ctx, nil)
+       defer registerNewSegmentCallback(ctx, nil)
+
+       if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 {
+               return nil
+       } else {
+               return ErrConversionFailed
+       }
+}
+
+// Number of generated text segments.
+// A segment can be a few words, a sentence, or even a paragraph.
+func (ctx *Context) Whisper_full_n_segments() int {
+       return int(C.whisper_full_n_segments((*C.struct_whisper_context)(ctx)))
+}
+
+// Get the start and end time of the specified segment.
+func (ctx *Context) Whisper_full_get_segment_t0(segment int) int64 {
+       return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_context)(ctx), C.int(segment)))
+}
+
+// Get the start and end time of the specified segment.
+func (ctx *Context) Whisper_full_get_segment_t1(segment int) int64 {
+       return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_context)(ctx), C.int(segment)))
+}
+
+// Get the text of the specified segment.
+func (ctx *Context) Whisper_full_get_segment_text(segment int) string {
+       return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_context)(ctx), C.int(segment)))
+}
+
+// Get number of tokens in the specified segment.
+func (ctx *Context) Whisper_full_n_tokens(segment int) int {
+       return int(C.whisper_full_n_tokens((*C.struct_whisper_context)(ctx), C.int(segment)))
+}
+
+// Get the token text of the specified token index in the specified segment.
+func (ctx *Context) Whisper_full_get_token_text(segment int, token int) string {
+       return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
+}
+
+// Get the token of the specified token index in the specified segment.
+func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
+       return Token(C.whisper_full_get_token_id((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
+}
+
+// Get token data for the specified token in the specified segment.
+// This contains probabilities, timestamps, etc.
+func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData {
+       return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
+}
+
+// Get the probability of the specified token in the specified segment.
+func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
+       return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// CALLBACKS
+
+var (
+       cbNewSegment   = make(map[unsafe.Pointer]func(int))
+       cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
+)
+
+func registerNewSegmentCallback(ctx *Context, fn func(int)) {
+       if fn == nil {
+               delete(cbNewSegment, unsafe.Pointer(ctx))
+       } else {
+               cbNewSegment[unsafe.Pointer(ctx)] = fn
+       }
+}
+
+func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
+       if fn == nil {
+               delete(cbEncoderBegin, unsafe.Pointer(ctx))
+       } else {
+               cbEncoderBegin[unsafe.Pointer(ctx)] = fn
+       }
+}
+
+//export callNewSegment
+func callNewSegment(user_data unsafe.Pointer, new C.int) {
+       if fn, ok := cbNewSegment[user_data]; ok {
+               fn(int(new))
+       }
+}
+
+//export callEncoderBegin
+func callEncoderBegin(user_data unsafe.Pointer) C.bool {
+       if fn, ok := cbEncoderBegin[user_data]; ok {
+               if fn() {
+                       return C.bool(true)
+               } else {
+                       return C.bool(false)
+               }
+       }
+       return true
+}
diff --git a/bindings/go/whisper_test.go b/bindings/go/whisper_test.go
new file mode 100644 (file)
index 0000000..d7b8cae
--- /dev/null
@@ -0,0 +1,110 @@
+package whisper_test
+
+import (
+       "os"
+       "runtime"
+       "testing"
+       "time"
+
+       // Packages
+       whisper "github.com/ggerganov/whisper.cpp/bindings/go"
+       wav "github.com/go-audio/wav"
+       assert "github.com/stretchr/testify/assert"
+)
+
+const (
+       ModelPath  = "models/ggml-small.en.bin"
+       SamplePath = "samples/jfk.wav"
+)
+
+func Test_Whisper_000(t *testing.T) {
+       assert := assert.New(t)
+       if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
+               t.Skip("Skipping test, model not found:", ModelPath)
+       }
+       ctx := whisper.Whisper_init(ModelPath)
+       assert.NotNil(ctx)
+       ctx.Whisper_free()
+}
+
+func Test_Whisper_001(t *testing.T) {
+       assert := assert.New(t)
+       if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
+               t.Skip("Skipping test, model not found:", ModelPath)
+       }
+       if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
+               t.Skip("Skipping test, sample not found:", SamplePath)
+       }
+
+       // Open samples
+       fh, err := os.Open(SamplePath)
+       assert.NoError(err)
+       defer fh.Close()
+
+       // Read samples
+       d := wav.NewDecoder(fh)
+       buf, err := d.FullPCMBuffer()
+       assert.NoError(err)
+
+       // Run whisper
+       ctx := whisper.Whisper_init(ModelPath)
+       assert.NotNil(ctx)
+       defer ctx.Whisper_free()
+       assert.NoError(ctx.Whisper_full(ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY), buf.AsFloat32Buffer().Data, nil, nil))
+
+       // Print out tokens
+       num_segments := ctx.Whisper_full_n_segments()
+       assert.GreaterOrEqual(num_segments, 1)
+       for i := 0; i < num_segments; i++ {
+               str := ctx.Whisper_full_get_segment_text(i)
+               assert.NotEmpty(str)
+               t0 := time.Duration(ctx.Whisper_full_get_segment_t0(i)) * time.Millisecond
+               t1 := time.Duration(ctx.Whisper_full_get_segment_t1(i)) * time.Millisecond
+               t.Logf("[%6s->%-6s] %q", t0, t1, str)
+       }
+}
+
+func Test_Whisper_002(t *testing.T) {
+       assert := assert.New(t)
+       for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
+               str := whisper.Whisper_lang_str(i)
+               assert.NotEmpty(str)
+               t.Log(str)
+       }
+}
+
+func Test_Whisper_003(t *testing.T) {
+       threads := runtime.NumCPU()
+       assert := assert.New(t)
+       if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
+               t.Skip("Skipping test, model not found:", ModelPath)
+       }
+       if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
+               t.Skip("Skipping test, sample not found:", SamplePath)
+       }
+
+       // Open samples
+       fh, err := os.Open(SamplePath)
+       assert.NoError(err)
+       defer fh.Close()
+
+       // Read samples
+       d := wav.NewDecoder(fh)
+       buf, err := d.FullPCMBuffer()
+       assert.NoError(err)
+
+       // Make the model
+       ctx := whisper.Whisper_init(ModelPath)
+       assert.NotNil(ctx)
+       defer ctx.Whisper_free()
+
+       // Get MEL
+       assert.NoError(ctx.Whisper_pcm_to_mel(buf.AsFloat32Buffer().Data, threads))
+
+       // Get Languages
+       languages, err := ctx.Whisper_lang_auto_detect(0, threads)
+       assert.NoError(err)
+       for i, p := range languages {
+               t.Logf("%s: %f", whisper.Whisper_lang_str(i), p)
+       }
+}