]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
go : improve model download (#2756)
authorRyan Johnson <redacted>
Fri, 7 Mar 2025 08:03:51 +0000 (02:03 -0600)
committerGitHub <redacted>
Fri, 7 Mar 2025 08:03:51 +0000 (10:03 +0200)
* Updated models download URL

* Updated list of models available

All of the high efficiency quantized models are rejected when trying to download. They exist on the server. Let's allow them.

* added path prefix for whisper-cli in message to user. The message is misleading if this script is called from another script in a different folder. So the message has to be fixed.

* undid download URL change I made earlier. Fixed filepath.Join(urlPath, model) bug.

* Undid download URL change I made earlier.

Seems that the old URL works but only when provided a model to download. Still doesn't explain why there's a different download URL that also works. Please elucidate in docs.

* Fixed URLForModel Function's bug

filepath.Join is designed for filesystem paths, and it uses backslashes (\) on Windows. URLs, however, require forward slashes (/), so the use of filepath.Join is inappropriate for constructing URLs.

The fmt.Sprintf function ensures that forward slashes are used.

* Fixed URL trailing / double slash bug

Ensure no double slash by trimming trailing '/' from srcUrl if present

* Fixed bad download URL, missing ggml prefix

Not sure if that was a bug I introduced but it was trying to download without the prefix.

* Added question before downloading all models. Added download size estimate

HEAD Requests:
Efficiently fetches file sizes without downloading the content.
Interactive Workflow:
Allows the user to make informed decisions about downloading all models.
Safe Defaults:
Aborts if the user does not explicitly confirm.

* Fixed Unbuffered channel warning.

warning in context.go : misuse of unbuffered os.Signal channel as argument to signal.

The warning indicates that the unbuffered channel used in signal.Notify in context.go may be misused. In Go, unbuffered channels can cause potential deadlocks if signals are sent faster than they are received.

* Fixed download size calculation, download URL prefix bug, added link to models URL for user.

The URL formatter was prepending the model name to the formatted model name in the URL

* Added logs and exes to gitignore

* Delete bindings/go/examples/go-model-download/go-model-download.exe

* Delete whisper_build.log

.gitignore
bindings/go/examples/go-model-download/context.go
bindings/go/examples/go-model-download/main.go
models/download-ggml-model.cmd

index c1e584dba3269c8780963ce8060cb4ea6816aff7..91368ec577b62d37f7ad11bc40f5d53e7b192cee 100644 (file)
@@ -58,3 +58,5 @@ cmake-build-debug/
 .cxx/
 .gradle/
 local.properties
+.log
+.exe
\ No newline at end of file
index 639d8f5bd96dd2cf141941929250ecf9ff1e302e..7d5f0ddb1df9d352d7cd52792cb0b9c493e5a098 100644 (file)
@@ -9,22 +9,23 @@ import (
 // ContextForSignal returns a context object which is cancelled when a signal
 // is received. It returns nil if no signal parameter is provided
 func ContextForSignal(signals ...os.Signal) context.Context {
-       if len(signals) == 0 {
-               return nil
-       }
+    if len(signals) == 0 {
+        return nil
+    }
 
-       ch := make(chan os.Signal)
-       ctx, cancel := context.WithCancel(context.Background())
+    ch := make(chan os.Signal, 1) // Buffered channel with space for 1 signal
+    ctx, cancel := context.WithCancel(context.Background())
 
-       // Send message on channel when signal received
-       signal.Notify(ch, signals...)
+    // Send message on channel when signal received
+    signal.Notify(ch, signals...)
 
-       // When any signal received, call cancel
-       go func() {
-               <-ch
-               cancel()
-       }()
+    // When any signal is received, call cancel
+    go func() {
+        <-ch
+        cancel()
+    }()
 
-       // Return success
-       return ctx
+    // Return success
+    return ctx
 }
+
index d0c1cc78b1e4bad7afbf48e9f0df5478153f6dce..728c6df53d426e52531ee04e29d0fc3aec13cdf9 100644 (file)
@@ -9,6 +9,7 @@ import (
        "net/url"
        "os"
        "path/filepath"
+       "strings"
        "syscall"
        "time"
 )
@@ -17,14 +18,27 @@ import (
 // CONSTANTS
 
 const (
-       srcUrl  = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main" // The location of the models
-       srcExt  = ".bin"                                                      // Filename extension
-       bufSize = 1024 * 64                                                   // Size of the buffer used for downloading the model
+       srcUrl  = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/" // The location of the models
+       srcExt  = ".bin"                                                       // Filename extension
+       bufSize = 1024 * 64                                                    // Size of the buffer used for downloading the model
 )
 
 var (
        // The models which will be downloaded, if no model is specified as an argument
-       modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large-v2", "ggml-large-v3", "large-v3-turbo"}
+       modelNames = []string{
+               "tiny", "tiny-q5_1", "tiny-q8_0",
+               "tiny.en", "tiny.en-q5_1", "tiny.en-q8_0",
+               "base", "base-q5_1", "base-q8_0",
+               "base.en", "base.en-q5_1", "base.en-q8_0",
+               "small", "small-q5_1", "small-q8_0",
+               "small.en", "small.en-q5_1", "small.en-q8_0",
+               "medium", "medium-q5_0", "medium-q8_0",
+               "medium.en", "medium.en-q5_0", "medium.en-q8_0",
+               "large-v1",
+               "large-v2", "large-v2-q5_0", "large-v2-q8_0",
+               "large-v3", "large-v3-q5_0",
+               "large-v3-turbo", "large-v3-turbo-q5_0", "large-v3-turbo-q8_0",
+       }
 )
 
 var (
@@ -44,7 +58,25 @@ var (
 func main() {
        flag.Usage = func() {
                name := filepath.Base(flag.CommandLine.Name())
-               fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] <model>\n\n", name)
+               fmt.Fprintf(flag.CommandLine.Output(), `
+                       Usage: %s [options] [<model>...]
+
+                       Options:
+                       -out string     Specify the output folder where models will be saved.
+                                       Default: Current working directory.
+                       -timeout duration Set the maximum duration for downloading a model.
+                                     Example: 10m, 1h (default: 30m0s).
+                       -quiet           Suppress all output except errors.
+
+                       Examples:
+                       1. Download a specific model:
+                       %s -out ./models tiny-q8_0
+
+                         2. Download all models:
+                       %s -out ./models
+
+                       `, name, name, name)
+
                flag.PrintDefaults()
        }
        flag.Parse()
@@ -114,23 +146,87 @@ func GetOut() (string, error) {
 // GetModels returns the list of models to download
 func GetModels() []string {
        if flag.NArg() == 0 {
-               return modelNames
-       } else {
-               return flag.Args()
+               fmt.Println("No model specified.")
+               fmt.Println("Preparing to download all models...")
+
+               // Calculate total download size
+               fmt.Println("Calculating total download size...")
+               totalSize, err := CalculateTotalDownloadSize(modelNames)
+               if err != nil {
+                       fmt.Println("Error calculating download sizes:", err)
+                       os.Exit(1)
+               }
+
+               fmt.Println("View available models: https://huggingface.co/ggerganov/whisper.cpp/tree/main")
+               fmt.Printf("Total download size: %.2f GB\n", float64(totalSize)/(1024*1024*1024))
+               fmt.Println("Would you like to download all models? (y/N)")
+
+               // Prompt for user input
+               var response string
+               fmt.Scanln(&response)
+               if response != "y" && response != "Y" {
+                       fmt.Println("Aborting. Specify a model to download.")
+                       os.Exit(0)
+               }
+
+               return modelNames // Return all models if confirmed
        }
+       return flag.Args() // Return specific models if arguments are provided
+}
+
+func CalculateTotalDownloadSize(models []string) (int64, error) {
+       var totalSize int64
+       client := http.Client{}
+
+       for _, model := range models {
+               modelURL, err := URLForModel(model)
+               if err != nil {
+                       return 0, err
+               }
+
+               // Issue a HEAD request to get the file size
+               req, err := http.NewRequest("HEAD", modelURL, nil)
+               if err != nil {
+                       return 0, err
+               }
+
+               resp, err := client.Do(req)
+               if err != nil {
+                       return 0, err
+               }
+               resp.Body.Close()
+
+               if resp.StatusCode != http.StatusOK {
+                       fmt.Printf("Warning: Unable to fetch size for %s (HTTP %d)\n", model, resp.StatusCode)
+                       continue
+               }
+
+               size := resp.ContentLength
+               totalSize += size
+       }
+       return totalSize, nil
 }
 
 // URLForModel returns the URL for the given model on huggingface.co
 func URLForModel(model string) (string, error) {
+       // Ensure "ggml-" prefix is added only once
+       if !strings.HasPrefix(model, "ggml-") {
+               model = "ggml-" + model
+       }
+
+       // Ensure ".bin" extension is added only once
        if filepath.Ext(model) != srcExt {
                model += srcExt
        }
+
+       // Parse the base URL
        url, err := url.Parse(srcUrl)
        if err != nil {
                return "", err
-       } else {
-               url.Path = filepath.Join(url.Path, model)
        }
+
+       // Ensure no trailing slash in the base URL
+       url.Path = fmt.Sprintf("%s/%s", strings.TrimSuffix(url.Path, "/"), model)
        return url.String(), nil
 }
 
index f329011deb1c99a1987ca794c13bffc11df0e25b..566aa1bfe74801b902d7e93b98b3ec919dcedfe2 100644 (file)
@@ -8,7 +8,18 @@ popd
 set argc=0
 for %%x in (%*) do set /A argc+=1
 
-set models=tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large-v3 large-v3-turbo
+set models=tiny tiny-q5_1 tiny-q8_0 ^
+tiny.en tiny.en-q5_1 tiny.en-q8_0 ^
+base base-q5_1 base-q8_0 ^
+base.en base.en-q5_1 base.en-q8_0 ^
+small small-q5_1 small-q8_0 ^
+small.en small.en-q5_1 small.en-q8_0 ^
+medium medium-q5_0 medium-q8_0 ^
+medium.en medium.en-q5_0 medium.en-q8_0 ^
+large-v1 ^
+large-v2 large-v2-q5_0 large-v2-q8_0 ^
+large-v3 large-v3-q5_0 ^
+large-v3-turbo large-v3-turbo-q5_0 large-v3-turbo-q8_0
 
 if %argc% neq 1 (
   echo.
@@ -50,7 +61,7 @@ if %ERRORLEVEL% neq 0 (
 
 echo Done! Model %model% saved in %root_path%\models\ggml-%model%.bin
 echo You can now use it like this:
-echo build\bin\Release\whisper-cli.exe -m %root_path%\models\ggml-%model%.bin -f %root_path%\samples\jfk.wav
+echo %~dp0build\bin\Release\whisper-cli.exe -m %root_path%\models\ggml-%model%.bin -f %root_path%\samples\jfk.wav
 
 goto :eof