--- /dev/null
+build
+models
+go.sum
--- /dev/null
+MIT License
+
+Copyright (c) 2022 David Thorpe
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
--- /dev/null
+CMAKE := $(shell which cmake)
+BUILD_DIR := "build"
+MODELS_DIR := "models"
+EXAMPLES_DIR := $(wildcard examples/*)
+C_INCLUDE_PATH := "../.."
+
+all: clean whisper examples
+
+whisper: mkdir
+ @echo Build whisper
+ @${CMAKE} -S ../.. -B ${BUILD_DIR} -D BUILD_SHARED_LIBS=off -D WHISPER_NO_AVX2=on
+ @${CMAKE} --build ${BUILD_DIR} --target whisper
+
+test: model-small whisper
+ @go mod tidy
+ @go test -v .
+ @go test -v ./pkg/whisper/...
+
+examples: $(EXAMPLES_DIR)
+
+model-small: mkdir examples/go-model-download
+ @${BUILD_DIR}/go-model-download -out models small.en
+
+$(EXAMPLES_DIR): mkdir whisper
+ @echo Build example $(notdir $@)
+ @go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
+
+mkdir:
+ @echo Mkdir ${BUILD_DIR}
+ @install -d ${BUILD_DIR}
+ @echo Mkdir ${MODELS_DIR}
+ @install -d ${MODELS_DIR}
+
+clean:
+ @echo Clean
+ @rm -fr $(BUILD_DIR)
+ @go mod tidy
+ @go clean
--- /dev/null
+# Go bindings for Whisper
+
+This package provides Go bindings for whisper.cpp. They have been tested on:
+
+ * Darwin (OS X) 12.6 on x64_64
+ * Debian Linux on arm64
+ * Fedora Linux on x86_64
+
+The "low level" bindings are in the `bindings/go` directory and there is a more
+Go-style package in the `bindings/go/pkg/whisper` directory. The most simple usage
+is as follows:
+
+```go
+import (
+ "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
+)
+
+func main() {
+ var modelpath string // Path to the model
+ var samples []float32 // Samples to process
+
+ // Load the model
+ model, err := whisper.New(modelpath)
+ if err != nil {
+ panic(err)
+ }
+ defer model.Close()
+
+ // Process samples
+ context, err := model.NewContext()
+ if err != nil {
+ panic(err)
+ }
+ if err := context.Process(samples, nil); err != nil {
+ return err
+ }
+
+ // Print out the results
+ for {
+ segment, err := context.NextSegment()
+ if err != nil {
+ break
+ }
+ fmt.Printf("[%6s->%6s] %s\n", segment.Start, segment.End, segment.Text)
+ }
+}
+```
+
+## Building & Testing
+
+In order to build, you need to have the Go compiler installed. You can get it from [here](https://golang.org/dl/). Run the tests with:
+
+```bash
+git clone https://github.com/ggerganov/whisper.cpp.git
+cd whisper.cpp/bindings/go
+make test
+```
+
+This will compile a static `libwhisper.a` in a `build` folder, download a model file, then run the tests. To build the examples:
+
+```bash
+make examples
+```
+
+The examples are placed in the `build` directory. Once built, you can download all the models with the following command:
+
+```bash
+./build/go-model-download -out models
+```
+
+And you can then test a model against samples with the following command:
+
+```bash
+./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
+```
+
+
--- /dev/null
+/*
+github.com/ggerganov/whisper.cpp/bindings/go
+provides a speech-to-text service bindings for the Go programming language.
+*/
+package whisper
--- /dev/null
+package main
+
+import (
+ "context"
+ "os"
+ "os/signal"
+)
+
+// ContextForSignal returns a context object which is cancelled when a signal
+// is received. It returns nil if no signal parameter is provided
+func ContextForSignal(signals ...os.Signal) context.Context {
+ if len(signals) == 0 {
+ return nil
+ }
+
+ ch := make(chan os.Signal)
+ ctx, cancel := context.WithCancel(context.Background())
+
+ // Send message on channel when signal received
+ signal.Notify(ch, signals...)
+
+ // When any signal received, call cancel
+ go func() {
+ <-ch
+ cancel()
+ }()
+
+ // Return success
+ return ctx
+}
--- /dev/null
+package main
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "path/filepath"
+ "syscall"
+ "time"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// CONSTANTS
+
+const (
+ srcUrl = "https://huggingface.co/" // The location of the models
+ srcPathPrefix = "/datasets/ggerganov/whisper.cpp/resolve/main/ggml" // Filename prefix
+ srcExt = ".bin" // Filename extension
+ bufSize = 1024 * 64 // Size of the buffer used for downloading the model
+)
+
+var (
+ // The models which will be downloaded, if no model is specified as an argument
+ modelNames = []string{"tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium", "large-v1", "large"}
+)
+
+var (
+ // The output folder. When not set, use current working directory.
+ flagOut = flag.String("out", "", "Output folder")
+
+ // HTTP timeout parameter - will timeout if takes longer than this to download a model
+ flagTimeout = flag.Duration("timeout", 30*time.Minute, "HTTP timeout")
+
+ // Quiet parameter - will not print progress if set
+ flagQuiet = flag.Bool("quiet", false, "Quiet mode")
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// MAIN
+
+func main() {
+ flag.Usage = func() {
+ name := filepath.Base(flag.CommandLine.Name())
+ fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] <model>\n\n", name)
+ flag.PrintDefaults()
+ }
+ flag.Parse()
+
+ // Get output path
+ out, err := GetOut()
+ if err != nil {
+ fmt.Fprintln(os.Stderr, "Error:", err)
+ os.Exit(-1)
+ }
+
+ // Create context which quits on SIGINT or SIGQUIT
+ ctx := ContextForSignal(os.Interrupt, syscall.SIGQUIT)
+
+ // Progress filehandle
+ progress := os.Stdout
+ if *flagQuiet {
+ progress, err = os.Open(os.DevNull)
+ if err != nil {
+ fmt.Fprintln(os.Stderr, "Error:", err)
+ os.Exit(-1)
+ }
+ defer progress.Close()
+ }
+
+ // Download models - exit on error or interrupt
+ for _, model := range GetModels() {
+ url, err := URLForModel(model)
+ if err != nil {
+ fmt.Fprintln(os.Stderr, "Error:", err)
+ continue
+ } else if path, err := Download(ctx, progress, url, out); err == nil || err == io.EOF {
+ continue
+ } else if err == context.Canceled {
+ os.Remove(path)
+ fmt.Fprintln(progress, "\nInterrupted")
+ break
+ } else if err == context.DeadlineExceeded {
+ os.Remove(path)
+ fmt.Fprintln(progress, "Timeout downloading model")
+ continue
+ } else {
+ os.Remove(path)
+ fmt.Fprintln(os.Stderr, "Error:", err)
+ break
+ }
+ }
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS
+
+// GetOut returns the path to the output directory
+func GetOut() (string, error) {
+ if *flagOut == "" {
+ return os.Getwd()
+ }
+ if info, err := os.Stat(*flagOut); err != nil {
+ return "", err
+ } else if !info.IsDir() {
+ return "", fmt.Errorf("not a directory: %s", info.Name())
+ } else {
+ return *flagOut, nil
+ }
+}
+
+// GetModels returns the list of models to download
+func GetModels() []string {
+ if flag.NArg() == 0 {
+ return modelNames
+ } else {
+ return flag.Args()
+ }
+}
+
+// URLForModel returns the URL for the given model on huggingface.co
+func URLForModel(model string) (string, error) {
+ url, err := url.Parse(srcUrl)
+ if err != nil {
+ return "", err
+ } else {
+ url.Path = srcPathPrefix + "-" + model + srcExt
+ }
+ return url.String(), nil
+}
+
+// Download downloads the model from the given URL to the given output directory
+func Download(ctx context.Context, p io.Writer, model, out string) (string, error) {
+ // Create HTTP client
+ client := http.Client{
+ Timeout: *flagTimeout,
+ }
+
+ // Initiate the download
+ req, err := http.NewRequest("GET", model, nil)
+ if err != nil {
+ return "", err
+ }
+ resp, err := client.Do(req)
+ if err != nil {
+ return "", err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ return "", fmt.Errorf("%s: %s", model, resp.Status)
+ }
+
+ // If output file exists and is the same size as the model, skip
+ path := filepath.Join(out, filepath.Base(model))
+ if info, err := os.Stat(path); err == nil && info.Size() == resp.ContentLength {
+ fmt.Fprintln(p, "Skipping", model, "as it already exists")
+ return "", nil
+ }
+
+ // Create file
+ w, err := os.Create(path)
+ if err != nil {
+ return "", err
+ }
+ defer w.Close()
+
+ // Report
+ fmt.Fprintln(p, "Downloading", model, "to", out)
+
+ // Progressively download the model
+ data := make([]byte, bufSize)
+ count, pct := int64(0), int64(0)
+ ticker := time.NewTicker(5 * time.Second)
+ for {
+ select {
+ case <-ctx.Done():
+ // Cancelled, return error
+ return path, ctx.Err()
+ case <-ticker.C:
+ pct = DownloadReport(p, pct, count, resp.ContentLength)
+ default:
+ // Read body
+ n, err := resp.Body.Read(data)
+ if err != nil {
+ DownloadReport(p, pct, count, resp.ContentLength)
+ return path, err
+ } else if m, err := w.Write(data[:n]); err != nil {
+ return path, err
+ } else {
+ count += int64(m)
+ }
+ }
+ }
+}
+
+// Report periodically reports the download progress when percentage changes
+func DownloadReport(w io.Writer, pct, count, total int64) int64 {
+ pct_ := count * 100 / total
+ if pct_ > pct {
+ fmt.Fprintf(w, " ...%d MB written (%d%%)\n", count/1e6, pct_)
+ }
+ return pct_
+}
--- /dev/null
+package main
+
+import (
+ "flag"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// TYPES
+
+type Flags struct {
+ *flag.FlagSet
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// LIFECYCLE
+
+func NewFlags(name string, args []string) (*Flags, error) {
+ flags := &Flags{
+ FlagSet: flag.NewFlagSet(name, flag.ContinueOnError),
+ }
+
+ // Register the command line arguments
+ registerFlags(flags)
+
+ // Parse command line
+ if err := flags.Parse(args); err != nil {
+ return nil, err
+ }
+
+ // Return success
+ return flags, nil
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS
+
+func (flags *Flags) GetModel() string {
+ return flags.Lookup("model").Value.String()
+}
+
+func (flags *Flags) GetLanguage() string {
+ return flags.Lookup("language").Value.String()
+}
+
+func (flags *Flags) IsSpeedup() bool {
+ return flags.Lookup("speedup").Value.String() == "true"
+}
+
+func (flags *Flags) IsTokens() bool {
+ return flags.Lookup("tokens").Value.String() == "true"
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PRIVATE METHODS
+
+func registerFlags(flag *Flags) {
+ flag.String("model", "", "Path to the model file")
+ flag.String("language", "", "Language")
+ flag.Bool("speedup", false, "Enable speedup")
+ flag.Bool("tokens", false, "Display tokens")
+}
--- /dev/null
+package main
+
+import (
+ "flag"
+ "fmt"
+ "os"
+ "path/filepath"
+
+ // Packages
+ whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
+)
+
+func main() {
+ flags, err := NewFlags(filepath.Base(os.Args[0]), os.Args[1:])
+ if err == flag.ErrHelp {
+ os.Exit(0)
+ } else if err != nil {
+ fmt.Fprintln(os.Stderr, err)
+ os.Exit(1)
+ } else if flags.GetModel() == "" {
+ fmt.Fprintln(os.Stderr, "Use -model flag to specify which model file to use")
+ os.Exit(1)
+ } else if flags.NArg() == 0 {
+ fmt.Fprintln(os.Stderr, "No input files specified")
+ os.Exit(1)
+ }
+
+ // Load model
+ model, err := whisper.New(flags.GetModel())
+ if err != nil {
+ fmt.Fprintln(os.Stderr, err)
+ os.Exit(1)
+ }
+ defer model.Close()
+
+ // Process files
+ for _, filename := range flags.Args() {
+ fmt.Println("Processing", filename)
+ if err := Process(model, filename, flags.GetLanguage(), flags.IsSpeedup(), flags.IsTokens()); err != nil {
+ fmt.Fprintln(os.Stderr, err)
+ continue
+ }
+ }
+}
--- /dev/null
+package main
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "time"
+
+ // Package imports
+ whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
+ wav "github.com/go-audio/wav"
+)
+
+func Process(model whisper.Model, path string, lang string, speedup, tokens bool) error {
+ var data []float32
+
+ // Create processing context
+ context, err := model.NewContext()
+ if err != nil {
+ return err
+ }
+
+ // Open the file
+ fh, err := os.Open(path)
+ if err != nil {
+ return err
+ }
+ defer fh.Close()
+
+ // Decode the WAV file
+ dec := wav.NewDecoder(fh)
+ if buf, err := dec.FullPCMBuffer(); err != nil {
+ return err
+ } else if dec.SampleRate != whisper.SampleRate {
+ return fmt.Errorf("unsupported sample rate: %d", dec.SampleRate)
+ } else if dec.NumChans != 1 {
+ return fmt.Errorf("unsupported number of channels: %d", dec.NumChans)
+ } else {
+ data = buf.AsFloat32Buffer().Data
+ }
+
+ // Set the parameters
+ var cb whisper.SegmentCallback
+ if lang != "" {
+ if err := context.SetLanguage(lang); err != nil {
+ return err
+ }
+ }
+ if speedup {
+ context.SetSpeedup(true)
+ }
+ if tokens {
+ cb = func(segment whisper.Segment) {
+ fmt.Printf("%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
+ for _, token := range segment.Tokens {
+ fmt.Printf("%q ", token.Text)
+ }
+ fmt.Println("")
+ }
+ }
+
+ // Process the data
+ if err := context.Process(data, cb); err != nil {
+ return err
+ }
+
+ // Print out the results
+ for {
+ segment, err := context.NextSegment()
+ if err == io.EOF {
+ break
+ } else if err != nil {
+ return err
+ }
+ fmt.Printf("[%6s->%6s] %s\n", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond), segment.Text)
+ }
+
+ // Return success
+ return nil
+}
--- /dev/null
+module github.com/ggerganov/whisper.cpp/bindings/go
+
+go 1.19
+
+require (
+ github.com/go-audio/wav v1.1.0
+ github.com/stretchr/testify v1.8.1
+)
+
+require (
+ github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/go-audio/audio v1.0.0 // indirect
+ github.com/go-audio/riff v1.0.0 // indirect
+ github.com/pmezard/go-difflib v1.0.0 // indirect
+ gopkg.in/yaml.v3 v3.0.1 // indirect
+)
--- /dev/null
+package whisper
+
+// This file defines the whisper_token, whisper_token_data and whisper_full_params
+// structures, which are used by the whisper_full() function.
+
+import (
+ "fmt"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// CGO
+
+/*
+#include <whisper.h>
+*/
+import "C"
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS
+
+func (p *Params) SetTranslate(v bool) {
+ p.translate = toBool(v)
+}
+
+func (p *Params) SetNoContext(v bool) {
+ p.no_context = toBool(v)
+}
+
+func (p *Params) SetSingleSegment(v bool) {
+ p.single_segment = toBool(v)
+}
+
+func (p *Params) SetPrintSpecial(v bool) {
+ p.print_special = toBool(v)
+}
+
+func (p *Params) SetPrintProgress(v bool) {
+ p.print_progress = toBool(v)
+}
+
+func (p *Params) SetPrintRealtime(v bool) {
+ p.print_realtime = toBool(v)
+}
+
+func (p *Params) SetPrintTimestamps(v bool) {
+ p.print_timestamps = toBool(v)
+}
+
+func (p *Params) SetSpeedup(v bool) {
+ p.speed_up = toBool(v)
+}
+
+func (p *Params) SetLanguage(lang int) error {
+ str := C.whisper_lang_str(C.int(lang))
+ if str == nil {
+ return ErrInvalidLanguage
+ } else {
+ p.language = str
+ }
+ return nil
+}
+
+func (p *Params) Language() int {
+ if p.language == nil {
+ return -1
+ }
+ return int(C.whisper_lang_id(p.language))
+}
+
+func (p *Params) SetThreads(threads int) {
+ p.n_threads = C.int(threads)
+}
+
+func (p *Params) SetOffset(offset_ms int) {
+ p.offset_ms = C.int(offset_ms)
+}
+
+func (p *Params) SetDuration(duration_ms int) {
+ p.duration_ms = C.int(duration_ms)
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PRIVATE METHODS
+
+func toBool(v bool) C.bool {
+ if v {
+ return C.bool(true)
+ }
+ return C.bool(false)
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// STRINGIFY
+
+func (p *Params) String() string {
+ str := "<whisper.params"
+ str += fmt.Sprintf(" strategy=%v", p.strategy)
+ str += fmt.Sprintf(" n_threads=%d", p.n_threads)
+ if p.language != nil {
+ str += fmt.Sprintf(" language=%s", C.GoString(p.language))
+ }
+ str += fmt.Sprintf(" n_max_text_ctx=%d", p.n_max_text_ctx)
+ str += fmt.Sprintf(" offset_ms=%d", p.offset_ms)
+ str += fmt.Sprintf(" duration_ms=%d", p.duration_ms)
+ if p.translate {
+ str += " translate"
+ }
+ if p.no_context {
+ str += " no_context"
+ }
+ if p.single_segment {
+ str += " single_segment"
+ }
+ if p.print_special {
+ str += " print_special"
+ }
+ if p.print_progress {
+ str += " print_progress"
+ }
+ if p.print_realtime {
+ str += " print_realtime"
+ }
+ if p.print_timestamps {
+ str += " print_timestamps"
+ }
+ if p.token_timestamps {
+ str += " token_timestamps"
+ }
+ if p.speed_up {
+ str += " speed_up"
+ }
+
+ return str + ">"
+}
--- /dev/null
+package whisper
+
+import (
+ "errors"
+
+ // Bindings
+ whisper "github.com/ggerganov/whisper.cpp/bindings/go"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// ERRORS
+
+var (
+ ErrUnableToLoadModel = errors.New("unable to load model")
+ ErrInternalAppError = errors.New("internal application error")
+ ErrProcessingFailed = errors.New("processing failed")
+ ErrUnsupportedLanguage = errors.New("unsupported language")
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// CONSTANTS
+
+// SampleRate is the sample rate of the audio data.
+const SampleRate = whisper.SampleRate
+
+// SampleBits is the number of bytes per sample.
+const SampleBits = whisper.SampleBits
--- /dev/null
+package whisper
+
+import (
+ "io"
+ "strings"
+ "time"
+
+ // Bindings
+ whisper "github.com/ggerganov/whisper.cpp/bindings/go"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// TYPES
+
+type context struct {
+ n int
+ model *model
+ params whisper.Params
+}
+
+// Make sure context adheres to the interface
+var _ Context = (*context)(nil)
+
+///////////////////////////////////////////////////////////////////////////////
+// LIFECYCLE
+
+func NewContext(model *model, params whisper.Params) (Context, error) {
+ context := new(context)
+ context.model = model
+ context.params = params
+
+ // Return success
+ return context, nil
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS
+
+// Set the language to use for speech recognition.
+func (context *context) SetLanguage(lang string) error {
+ if context.model.ctx == nil {
+ return ErrInternalAppError
+ }
+ if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
+ return ErrUnsupportedLanguage
+ } else if err := context.params.SetLanguage(id); err != nil {
+ return err
+ }
+ // Return success
+ return nil
+}
+
+// Get language
+func (context *context) Language() string {
+ return whisper.Whisper_lang_str(context.params.Language())
+}
+
+// Set speedup flag
+func (context *context) SetSpeedup(v bool) {
+ context.params.SetSpeedup(v)
+}
+
+// Process new sample data and return any errors
+func (context *context) Process(data []float32, cb SegmentCallback) error {
+ if context.model.ctx == nil {
+ return ErrInternalAppError
+ }
+ // If the callback is defined then we force on single_segment mode
+ if cb != nil {
+ context.params.SetSingleSegment(true)
+ }
+
+ // We don't do parallel processing at the moment
+ processors := 0
+ if processors > 1 {
+ if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
+ if cb != nil {
+ num_segments := context.model.ctx.Whisper_full_n_segments()
+ s0 := num_segments - new
+ for i := s0; i < num_segments; i++ {
+ cb(toSegment(context.model.ctx, i))
+ }
+ }
+ }); err != nil {
+ return err
+ }
+ } else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
+ if cb != nil {
+ num_segments := context.model.ctx.Whisper_full_n_segments()
+ s0 := num_segments - new
+ for i := s0; i < num_segments; i++ {
+ cb(toSegment(context.model.ctx, i))
+ }
+ }
+ }); err != nil {
+ return err
+ }
+
+ // Return success
+ return nil
+}
+
+// Return the next segment of tokens
+func (context *context) NextSegment() (Segment, error) {
+ if context.model.ctx == nil {
+ return Segment{}, ErrInternalAppError
+ }
+ if context.n >= context.model.ctx.Whisper_full_n_segments() {
+ return Segment{}, io.EOF
+ }
+
+ // Populate result
+ result := toSegment(context.model.ctx, context.n)
+
+ // Increment the cursor
+ context.n++
+
+ // Return success
+ return result, nil
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PRIVATE METHODS
+
+func toSegment(ctx *whisper.Context, n int) Segment {
+ return Segment{
+ Num: n,
+ Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)),
+ Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
+ End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
+ Tokens: toTokens(ctx, n),
+ }
+}
+
+func toTokens(ctx *whisper.Context, n int) []Token {
+ result := make([]Token, ctx.Whisper_full_n_tokens(n))
+ for i := 0; i < len(result); i++ {
+ result[i] = Token{
+ Id: int(ctx.Whisper_full_get_token_id(n, i)),
+ Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)),
+ P: ctx.Whisper_full_get_token_p(n, i),
+ }
+ }
+ return result
+}
--- /dev/null
+package whisper_test
+
+import (
+ "os"
+ "testing"
+
+ // Packages
+ whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
+ assert "github.com/stretchr/testify/assert"
+)
+
+const (
+ ModelPath = "../../models/ggml-tiny.bin"
+ SamplePath = "../../samples/jfk.wav"
+)
+
+func Test_Whisper_000(t *testing.T) {
+ assert := assert.New(t)
+ if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
+ t.Skip("Skipping test, model not found:", ModelPath)
+ }
+ if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
+ t.Skip("Skipping test, sample not found:", SamplePath)
+ }
+
+ // Load model
+ model, err := whisper.New(ModelPath)
+ assert.NoError(err)
+ assert.NotNil(model)
+ assert.NoError(model.Close())
+
+ t.Log("languages=", model.Languages())
+}
+
+func Test_Whisper_001(t *testing.T) {
+ assert := assert.New(t)
+ if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
+ t.Skip("Skipping test, model not found:", ModelPath)
+ }
+ if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
+ t.Skip("Skipping test, sample not found:", SamplePath)
+ }
+
+ // Load model
+ model, err := whisper.New(ModelPath)
+ assert.NoError(err)
+ assert.NotNil(model)
+ defer model.Close()
+
+ // Get context for decoding
+ ctx, err := model.NewContext()
+ assert.NoError(err)
+ assert.NotNil(ctx)
+
+}
--- /dev/null
+/*
+This is the higher-level speech-to-text whisper.cpp API for go
+*/
+package whisper
--- /dev/null
+package whisper
+
+import (
+ "io"
+ "time"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// TYPES
+
+// SegmentCallback is the callback function for processing segments in real
+// time. It is called during the Process function
+type SegmentCallback func(Segment)
+
+// Model is the interface to a whisper model. Create a new model with the
+// function whisper.New(string)
+type Model interface {
+ io.Closer
+
+ // Return a new speech-to-text context.
+ NewContext() (Context, error)
+
+ // Return all languages supported.
+ Languages() []string
+}
+
+// Context is the speach recognition context.
+type Context interface {
+ SetLanguage(string) error // Set the language to use for speech recognition.
+ Language() string // Get language
+ SetSpeedup(bool) // Set speedup flag
+
+ // Process mono audio data and return any errors.
+ // If defined, newly generated segments are passed to the
+ // callback function during processing.
+ Process([]float32, SegmentCallback) error
+
+ // After process is called, return segments until the end of the stream
+ // is reached, when io.EOF is returned.
+ NextSegment() (Segment, error)
+}
+
+// Segment is the text result of a speech recognition.
+type Segment struct {
+ // Segment Number
+ Num int
+
+ // Time beginning and end timestamps for the segment.
+ Start, End time.Duration
+
+ // The text of the segment.
+ Text string
+
+ // The tokens of the segment.
+ Tokens []Token
+}
+
+// Token is a text or special token
+type Token struct {
+ Id int
+ Text string
+ P float32
+}
--- /dev/null
+package whisper
+
+import (
+ "fmt"
+ "os"
+ "runtime"
+
+ // Bindings
+ whisper "github.com/ggerganov/whisper.cpp/bindings/go"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// TYPES
+
+type model struct {
+ path string
+ ctx *whisper.Context
+}
+
+// Make sure model adheres to the interface
+var _ Model = (*model)(nil)
+
+///////////////////////////////////////////////////////////////////////////////
+// LIFECYCLE
+
+func New(path string) (*model, error) {
+ model := new(model)
+ if _, err := os.Stat(path); err != nil {
+ return nil, err
+ } else if ctx := whisper.Whisper_init(path); ctx == nil {
+ return nil, ErrUnableToLoadModel
+ } else {
+ model.ctx = ctx
+ model.path = path
+ }
+
+ // Return success
+ return model, nil
+}
+
+func (model *model) Close() error {
+ if model.ctx != nil {
+ model.ctx.Whisper_free()
+ }
+
+ // Release resources
+ model.ctx = nil
+
+ // Return success
+ return nil
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// STRINGIFY
+
+func (model *model) String() string {
+ str := "<whisper.model"
+ if model.ctx != nil {
+ str += fmt.Sprintf(" model=%q", model.path)
+ }
+ return str + ">"
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS
+
+// Return all recognized languages. Initially it is set to auto-detect
+func (model *model) Languages() []string {
+ result := make([]string, 0, whisper.Whisper_lang_max_id())
+ for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
+ str := whisper.Whisper_lang_str(i)
+ if model.ctx.Whisper_lang_id(str) >= 0 {
+ result = append(result, str)
+ }
+ }
+ return result
+}
+
+func (model *model) NewContext() (Context, error) {
+ if model.ctx == nil {
+ return nil, ErrInternalAppError
+ }
+
+ // Create new context
+ params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
+ params.SetTranslate(false)
+ params.SetPrintSpecial(false)
+ params.SetPrintProgress(false)
+ params.SetPrintRealtime(false)
+ params.SetPrintTimestamps(false)
+ params.SetThreads(runtime.NumCPU())
+
+ // Return new context
+ return NewContext(model, params)
+}
--- /dev/null
+package whisper
+
+import (
+ "errors"
+ "unsafe"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// CGO
+
+/*
+#cgo CFLAGS: -I${SRCDIR}/../..
+#cgo LDFLAGS: -L${SRCDIR}/build -lwhisper -lm -lstdc++
+#cgo darwin LDFLAGS: -framework Accelerate
+#include <whisper.h>
+#include <stdlib.h>
+
+extern void callNewSegment(void* user_data, int new);
+extern bool callEncoderBegin(void* user_data);
+
+// Text segment callback
+// Called on every newly generated text segment
+// Use the whisper_full_...() functions to obtain the text segments
+static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
+ if(user_data != NULL && ctx != NULL) {
+ callNewSegment(user_data, n_new);
+ }
+}
+
+// Encoder begin callback
+// If not NULL, called before the encoder starts
+// If it returns false, the computation is aborted
+static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
+ if(user_data != NULL && ctx != NULL) {
+ return callEncoderBegin(user_data);
+ }
+ return false;
+}
+
+// Get default parameters and set callbacks
+static struct whisper_full_params whisper_full_default_params_cb(struct whisper_context* ctx, enum whisper_sampling_strategy strategy) {
+ struct whisper_full_params params = whisper_full_default_params(strategy);
+ params.new_segment_callback = whisper_new_segment_cb;
+ params.new_segment_callback_user_data = (void*)(ctx);
+ params.encoder_begin_callback = whisper_encoder_begin_cb;
+ params.encoder_begin_callback_user_data = (void*)(ctx);
+ return params;
+}
+*/
+import "C"
+
+///////////////////////////////////////////////////////////////////////////////
+// TYPES
+
+type (
+ Context C.struct_whisper_context
+ Token C.whisper_token
+ TokenData C.struct_whisper_token_data
+ SamplingStrategy C.enum_whisper_sampling_strategy
+ Params C.struct_whisper_full_params
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// GLOBALS
+
+const (
+ SAMPLING_GREEDY SamplingStrategy = C.WHISPER_SAMPLING_GREEDY
+ SAMPLING_BEAM_SEARCH SamplingStrategy = C.WHISPER_SAMPLING_BEAM_SEARCH
+)
+
+const (
+ SampleRate = C.WHISPER_SAMPLE_RATE // Expected sample rate, samples per second
+ SampleBits = uint16(unsafe.Sizeof(C.float(0))) * 8 // Sample size in bits
+ NumFFT = C.WHISPER_N_FFT
+ NumMEL = C.WHISPER_N_MEL
+ HopLength = C.WHISPER_HOP_LENGTH
+ ChunkSize = C.WHISPER_CHUNK_SIZE
+)
+
+var (
+ ErrTokenizerFailed = errors.New("whisper_tokenize failed")
+ ErrAutoDetectFailed = errors.New("whisper_lang_auto_detect failed")
+ ErrConversionFailed = errors.New("whisper_convert failed")
+ ErrInvalidLanguage = errors.New("invalid language")
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS
+
+// Allocates all memory needed for the model and loads the model from the given file.
+// Returns NULL on failure.
+func Whisper_init(path string) *Context {
+ cPath := C.CString(path)
+ defer C.free(unsafe.Pointer(cPath))
+ if ctx := C.whisper_init(cPath); ctx != nil {
+ return (*Context)(ctx)
+ } else {
+ return nil
+ }
+}
+
+// Frees all memory allocated by the model.
+func (ctx *Context) Whisper_free() {
+ C.whisper_free((*C.struct_whisper_context)(ctx))
+}
+
+// Convert RAW PCM audio to log mel spectrogram.
+// The resulting spectrogram is stored inside the provided whisper context.
+func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error {
+ if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 {
+ return nil
+ } else {
+ return ErrConversionFailed
+ }
+}
+
+// This can be used to set a custom log mel spectrogram inside the provided whisper context.
+// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
+// n_mel must be 80
+func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error {
+ if C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 {
+ return nil
+ } else {
+ return ErrConversionFailed
+ }
+}
+
+// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
+// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
+// offset can be used to specify the offset of the first frame in the spectrogram.
+func (ctx *Context) Whisper_encode(offset, threads int) error {
+ if C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads)) == 0 {
+ return nil
+ } else {
+ return ErrConversionFailed
+ }
+}
+
+// Run the Whisper decoder to obtain the logits and probabilities for the next token.
+// Make sure to call whisper_encode() first.
+// tokens + n_tokens is the provided context for the decoder.
+// n_past is the number of tokens to use from previous decoder calls.
+func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error {
+ if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 {
+ return nil
+ } else {
+ return ErrConversionFailed
+ }
+}
+
+// whisper_sample_best() returns the token with the highest probability
+func (ctx *Context) Whisper_sample_best() TokenData {
+ return TokenData(C.whisper_sample_best((*C.struct_whisper_context)(ctx)))
+}
+
+// whisper_sample_timestamp() returns the most probable timestamp token
+func (ctx *Context) Whisper_sample_timestamp(is_initial bool) TokenData {
+ return TokenData(C.whisper_sample_timestamp((*C.struct_whisper_context)(ctx), C.bool(is_initial)))
+}
+
+// Convert the provided text into tokens. The tokens pointer must be large enough to hold the resulting tokens.
+// Returns the number of tokens on success
+func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
+ cText := C.CString(text)
+ defer C.free(unsafe.Pointer(cText))
+ if n := C.whisper_tokenize((*C.struct_whisper_context)(ctx), cText, (*C.whisper_token)(&tokens[0]), C.int(len(tokens))); n >= 0 {
+ return int(n), nil
+ } else {
+ return 0, ErrTokenizerFailed
+ }
+}
+
+// Return the id of the specified language, returns -1 if not found
+func (ctx *Context) Whisper_lang_id(lang string) int {
+ return int(C.whisper_lang_id(C.CString(lang)))
+}
+
+// Largest language id (i.e. number of available languages - 1)
+func Whisper_lang_max_id() int {
+ return int(C.whisper_lang_max_id())
+}
+
+// Return the short string of the specified language id (e.g. 2 -> "de"),
+// returns empty string if not found
+func Whisper_lang_str(id int) string {
+ return C.GoString(C.whisper_lang_str(C.int(id)))
+}
+
+// Use mel data at offset_ms to try and auto-detect the spoken language
+// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
+// Returns the probabilities of all languages.
+// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
+func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float32, error) {
+ probs := make([]float32, Whisper_lang_max_id()+1)
+ if n := int(C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 {
+ return nil, ErrAutoDetectFailed
+ } else {
+ return probs, nil
+ }
+}
+
+func (ctx *Context) Whisper_n_len() int {
+ return int(C.whisper_n_len((*C.struct_whisper_context)(ctx)))
+}
+
+func (ctx *Context) Whisper_n_vocab() int {
+ return int(C.whisper_n_vocab((*C.struct_whisper_context)(ctx)))
+}
+
+func (ctx *Context) Whisper_n_text_ctx() int {
+ return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
+}
+
+func (ctx *Context) Whisper_is_multilingual() int {
+ return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
+}
+
+// The probabilities for the next token
+//func (ctx *Whisper_context) Whisper_get_probs() []float32 {
+// return (*[1 << 30]float32)(unsafe.Pointer(C.whisper_get_probs((*C.struct_whisper_context)(ctx))))[:ctx.Whisper_n_vocab()]
+//}
+
+// Token Id -> String. Uses the vocabulary in the provided context
+func (ctx *Context) Whisper_token_to_str(token Token) string {
+ return C.GoString(C.whisper_token_to_str((*C.struct_whisper_context)(ctx), C.whisper_token(token)))
+}
+
+// Special tokens
+func (ctx *Context) Whisper_token_eot() Token {
+ return Token(C.whisper_token_eot((*C.struct_whisper_context)(ctx)))
+}
+
+// Special tokens
+func (ctx *Context) Whisper_token_sot() Token {
+ return Token(C.whisper_token_sot((*C.struct_whisper_context)(ctx)))
+}
+
+// Special tokens
+func (ctx *Context) Whisper_token_prev() Token {
+ return Token(C.whisper_token_prev((*C.struct_whisper_context)(ctx)))
+}
+
+// Special tokens
+func (ctx *Context) Whisper_token_solm() Token {
+ return Token(C.whisper_token_solm((*C.struct_whisper_context)(ctx)))
+}
+
+// Special tokens
+func (ctx *Context) Whisper_token_not() Token {
+ return Token(C.whisper_token_not((*C.struct_whisper_context)(ctx)))
+}
+
+// Special tokens
+func (ctx *Context) Whisper_token_beg() Token {
+ return Token(C.whisper_token_beg((*C.struct_whisper_context)(ctx)))
+}
+
+// Special tokens
+func (ctx *Context) Whisper_token_lang(lang_id int) Token {
+ return Token(C.whisper_token_lang((*C.struct_whisper_context)(ctx), C.int(lang_id)))
+}
+
+// Task tokens
+func Whisper_token_translate() Token {
+ return Token(C.whisper_token_translate())
+}
+
+// Task tokens
+func Whisper_token_transcribe() Token {
+ return Token(C.whisper_token_transcribe())
+}
+
+// Performance information
+func (ctx *Context) Whisper_print_timings() {
+ C.whisper_print_timings((*C.struct_whisper_context)(ctx))
+}
+
+// Performance information
+func (ctx *Context) Whisper_reset_timings() {
+ C.whisper_reset_timings((*C.struct_whisper_context)(ctx))
+}
+
+// Print system information
+func Whisper_print_system_info() string {
+ return C.GoString(C.whisper_print_system_info())
+}
+
+// Return default parameters for a strategy
+func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Params {
+ // Get default parameters
+ return Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy)))
+}
+
+// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
+// Uses the specified decoding strategy to obtain the text.
+func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
+ registerEncoderBeginCallback(ctx, encoderBeginCallback)
+ registerNewSegmentCallback(ctx, newSegmentCallback)
+ defer registerEncoderBeginCallback(ctx, nil)
+ defer registerNewSegmentCallback(ctx, nil)
+ if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
+ return nil
+ } else {
+ return ErrConversionFailed
+ }
+}
+
+// Split the input audio in chunks and process each chunk separately using whisper_full()
+// It seems this approach can offer some speedup in some cases.
+// However, the transcription accuracy can be worse at the beginning and end of each chunk.
+func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
+ registerEncoderBeginCallback(ctx, encoderBeginCallback)
+ registerNewSegmentCallback(ctx, newSegmentCallback)
+ defer registerEncoderBeginCallback(ctx, nil)
+ defer registerNewSegmentCallback(ctx, nil)
+
+ if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 {
+ return nil
+ } else {
+ return ErrConversionFailed
+ }
+}
+
+// Number of generated text segments.
+// A segment can be a few words, a sentence, or even a paragraph.
+func (ctx *Context) Whisper_full_n_segments() int {
+ return int(C.whisper_full_n_segments((*C.struct_whisper_context)(ctx)))
+}
+
+// Get the start and end time of the specified segment.
+func (ctx *Context) Whisper_full_get_segment_t0(segment int) int64 {
+ return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_context)(ctx), C.int(segment)))
+}
+
+// Get the start and end time of the specified segment.
+func (ctx *Context) Whisper_full_get_segment_t1(segment int) int64 {
+ return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_context)(ctx), C.int(segment)))
+}
+
+// Get the text of the specified segment.
+func (ctx *Context) Whisper_full_get_segment_text(segment int) string {
+ return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_context)(ctx), C.int(segment)))
+}
+
+// Get number of tokens in the specified segment.
+func (ctx *Context) Whisper_full_n_tokens(segment int) int {
+ return int(C.whisper_full_n_tokens((*C.struct_whisper_context)(ctx), C.int(segment)))
+}
+
+// Get the token text of the specified token index in the specified segment.
+func (ctx *Context) Whisper_full_get_token_text(segment int, token int) string {
+ return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
+}
+
+// Get the token of the specified token index in the specified segment.
+func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
+ return Token(C.whisper_full_get_token_id((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
+}
+
+// Get token data for the specified token in the specified segment.
+// This contains probabilities, timestamps, etc.
+func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData {
+ return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
+}
+
+// Get the probability of the specified token in the specified segment.
+func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
+ return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// CALLBACKS
+
+var (
+ cbNewSegment = make(map[unsafe.Pointer]func(int))
+ cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
+)
+
+func registerNewSegmentCallback(ctx *Context, fn func(int)) {
+ if fn == nil {
+ delete(cbNewSegment, unsafe.Pointer(ctx))
+ } else {
+ cbNewSegment[unsafe.Pointer(ctx)] = fn
+ }
+}
+
+func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
+ if fn == nil {
+ delete(cbEncoderBegin, unsafe.Pointer(ctx))
+ } else {
+ cbEncoderBegin[unsafe.Pointer(ctx)] = fn
+ }
+}
+
+//export callNewSegment
+func callNewSegment(user_data unsafe.Pointer, new C.int) {
+ if fn, ok := cbNewSegment[user_data]; ok {
+ fn(int(new))
+ }
+}
+
+//export callEncoderBegin
+func callEncoderBegin(user_data unsafe.Pointer) C.bool {
+ if fn, ok := cbEncoderBegin[user_data]; ok {
+ if fn() {
+ return C.bool(true)
+ } else {
+ return C.bool(false)
+ }
+ }
+ return true
+}
--- /dev/null
+package whisper_test
+
+import (
+ "os"
+ "runtime"
+ "testing"
+ "time"
+
+ // Packages
+ whisper "github.com/ggerganov/whisper.cpp/bindings/go"
+ wav "github.com/go-audio/wav"
+ assert "github.com/stretchr/testify/assert"
+)
+
+const (
+ ModelPath = "models/ggml-small.en.bin"
+ SamplePath = "samples/jfk.wav"
+)
+
+func Test_Whisper_000(t *testing.T) {
+ assert := assert.New(t)
+ if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
+ t.Skip("Skipping test, model not found:", ModelPath)
+ }
+ ctx := whisper.Whisper_init(ModelPath)
+ assert.NotNil(ctx)
+ ctx.Whisper_free()
+}
+
+func Test_Whisper_001(t *testing.T) {
+ assert := assert.New(t)
+ if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
+ t.Skip("Skipping test, model not found:", ModelPath)
+ }
+ if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
+ t.Skip("Skipping test, sample not found:", SamplePath)
+ }
+
+ // Open samples
+ fh, err := os.Open(SamplePath)
+ assert.NoError(err)
+ defer fh.Close()
+
+ // Read samples
+ d := wav.NewDecoder(fh)
+ buf, err := d.FullPCMBuffer()
+ assert.NoError(err)
+
+ // Run whisper
+ ctx := whisper.Whisper_init(ModelPath)
+ assert.NotNil(ctx)
+ defer ctx.Whisper_free()
+ assert.NoError(ctx.Whisper_full(ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY), buf.AsFloat32Buffer().Data, nil, nil))
+
+ // Print out tokens
+ num_segments := ctx.Whisper_full_n_segments()
+ assert.GreaterOrEqual(num_segments, 1)
+ for i := 0; i < num_segments; i++ {
+ str := ctx.Whisper_full_get_segment_text(i)
+ assert.NotEmpty(str)
+ t0 := time.Duration(ctx.Whisper_full_get_segment_t0(i)) * time.Millisecond
+ t1 := time.Duration(ctx.Whisper_full_get_segment_t1(i)) * time.Millisecond
+ t.Logf("[%6s->%-6s] %q", t0, t1, str)
+ }
+}
+
+func Test_Whisper_002(t *testing.T) {
+ assert := assert.New(t)
+ for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
+ str := whisper.Whisper_lang_str(i)
+ assert.NotEmpty(str)
+ t.Log(str)
+ }
+}
+
+func Test_Whisper_003(t *testing.T) {
+ threads := runtime.NumCPU()
+ assert := assert.New(t)
+ if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
+ t.Skip("Skipping test, model not found:", ModelPath)
+ }
+ if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
+ t.Skip("Skipping test, sample not found:", SamplePath)
+ }
+
+ // Open samples
+ fh, err := os.Open(SamplePath)
+ assert.NoError(err)
+ defer fh.Close()
+
+ // Read samples
+ d := wav.NewDecoder(fh)
+ buf, err := d.FullPCMBuffer()
+ assert.NoError(err)
+
+ // Make the model
+ ctx := whisper.Whisper_init(ModelPath)
+ assert.NotNil(ctx)
+ defer ctx.Whisper_free()
+
+ // Get MEL
+ assert.NoError(ctx.Whisper_pcm_to_mel(buf.AsFloat32Buffer().Data, threads))
+
+ // Get Languages
+ languages, err := ctx.Whisper_lang_auto_detect(0, threads)
+ assert.NoError(err)
+ for i, p := range languages {
+ t.Logf("%s: %f", whisper.Whisper_lang_str(i), p)
+ }
+}