"net/url"
"os"
"path/filepath"
+ "strings"
"syscall"
"time"
)
// CONSTANTS
const (
- srcUrl = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main" // The location of the models
- srcExt = ".bin" // Filename extension
- bufSize = 1024 * 64 // Size of the buffer used for downloading the model
+ srcUrl = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/" // The location of the models
+ srcExt = ".bin" // Filename extension
+ bufSize = 1024 * 64 // Size of the buffer used for downloading the model
)
var (
// The models which will be downloaded, if no model is specified as an argument
- modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large-v2", "ggml-large-v3", "large-v3-turbo"}
+ modelNames = []string{
+ "tiny", "tiny-q5_1", "tiny-q8_0",
+ "tiny.en", "tiny.en-q5_1", "tiny.en-q8_0",
+ "base", "base-q5_1", "base-q8_0",
+ "base.en", "base.en-q5_1", "base.en-q8_0",
+ "small", "small-q5_1", "small-q8_0",
+ "small.en", "small.en-q5_1", "small.en-q8_0",
+ "medium", "medium-q5_0", "medium-q8_0",
+ "medium.en", "medium.en-q5_0", "medium.en-q8_0",
+ "large-v1",
+ "large-v2", "large-v2-q5_0", "large-v2-q8_0",
+ "large-v3", "large-v3-q5_0",
+ "large-v3-turbo", "large-v3-turbo-q5_0", "large-v3-turbo-q8_0",
+ }
)
var (
func main() {
flag.Usage = func() {
name := filepath.Base(flag.CommandLine.Name())
- fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] <model>\n\n", name)
+ fmt.Fprintf(flag.CommandLine.Output(), `
+ Usage: %s [options] [<model>...]
+
+ Options:
+ -out string Specify the output folder where models will be saved.
+ Default: Current working directory.
+ -timeout duration Set the maximum duration for downloading a model.
+ Example: 10m, 1h (default: 30m0s).
+ -quiet Suppress all output except errors.
+
+ Examples:
+ 1. Download a specific model:
+ %s -out ./models tiny-q8_0
+
+ 2. Download all models:
+ %s -out ./models
+
+ `, name, name, name)
+
flag.PrintDefaults()
}
flag.Parse()
// GetModels returns the list of models to download
func GetModels() []string {
if flag.NArg() == 0 {
- return modelNames
- } else {
- return flag.Args()
+ fmt.Println("No model specified.")
+ fmt.Println("Preparing to download all models...")
+
+ // Calculate total download size
+ fmt.Println("Calculating total download size...")
+ totalSize, err := CalculateTotalDownloadSize(modelNames)
+ if err != nil {
+ fmt.Println("Error calculating download sizes:", err)
+ os.Exit(1)
+ }
+
+ fmt.Println("View available models: https://huggingface.co/ggerganov/whisper.cpp/tree/main")
+ fmt.Printf("Total download size: %.2f GB\n", float64(totalSize)/(1024*1024*1024))
+ fmt.Println("Would you like to download all models? (y/N)")
+
+ // Prompt for user input
+ var response string
+ fmt.Scanln(&response)
+ if response != "y" && response != "Y" {
+ fmt.Println("Aborting. Specify a model to download.")
+ os.Exit(0)
+ }
+
+ return modelNames // Return all models if confirmed
}
+ return flag.Args() // Return specific models if arguments are provided
+}
+
+func CalculateTotalDownloadSize(models []string) (int64, error) {
+ var totalSize int64
+ client := http.Client{}
+
+ for _, model := range models {
+ modelURL, err := URLForModel(model)
+ if err != nil {
+ return 0, err
+ }
+
+ // Issue a HEAD request to get the file size
+ req, err := http.NewRequest("HEAD", modelURL, nil)
+ if err != nil {
+ return 0, err
+ }
+
+ resp, err := client.Do(req)
+ if err != nil {
+ return 0, err
+ }
+ resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ fmt.Printf("Warning: Unable to fetch size for %s (HTTP %d)\n", model, resp.StatusCode)
+ continue
+ }
+
+ size := resp.ContentLength
+ totalSize += size
+ }
+ return totalSize, nil
}
// URLForModel returns the URL for the given model on huggingface.co
func URLForModel(model string) (string, error) {
+ // Ensure "ggml-" prefix is added only once
+ if !strings.HasPrefix(model, "ggml-") {
+ model = "ggml-" + model
+ }
+
+ // Ensure ".bin" extension is added only once
if filepath.Ext(model) != srcExt {
model += srcExt
}
+
+ // Parse the base URL
url, err := url.Parse(srcUrl)
if err != nil {
return "", err
- } else {
- url.Path = filepath.Join(url.Path, model)
}
+
+ // Ensure no trailing slash in the base URL
+ url.Path = fmt.Sprintf("%s/%s", strings.TrimSuffix(url.Path, "/"), model)
return url.String(), nil
}