]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add context param to disable gpu (#1293)
authorJhen-Jie Hong <redacted>
Mon, 6 Nov 2023 09:04:24 +0000 (17:04 +0800)
committerGitHub <redacted>
Mon, 6 Nov 2023 09:04:24 +0000 (11:04 +0200)
* whisper : check state->ctx_metal not null

* whisper : add whisper_context_params { use_gpu }

* whisper : new API with params & deprecate old API

* examples : use no-gpu param && whisper_init_from_file_with_params

* whisper.objc : enable metal & disable on simulator

* whisper.swiftui, metal : enable metal & support load default.metallib

* whisper.android : use new API

* bindings : use new API

* addon.node : fix build & test

* bindings : updata java binding

* bindings : add missing whisper_context_default_params_by_ref WHISPER_API for java

* metal : use SWIFTPM_MODULE_BUNDLE for GGML_SWIFT and reuse library load

* metal : move bundle var into block

* metal : use SWIFT_PACKAGE instead of GGML_SWIFT

* style : minor updates

---------

Co-authored-by: Georgi Gerganov <redacted>
29 files changed:
bindings/go/whisper.go
bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperContext.java
bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java
bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java [new file with mode: 0644]
bindings/javascript/emscripten.cpp
bindings/ruby/ext/ruby_whisper.cpp
examples/addon.node/__test__/whisper.spec.js
examples/addon.node/addon.cpp
examples/addon.node/index.js
examples/bench.wasm/emscripten.cpp
examples/bench/bench.cpp
examples/command.wasm/emscripten.cpp
examples/command/command.cpp
examples/lsp/lsp.cpp
examples/main/main.cpp
examples/stream.wasm/emscripten.cpp
examples/stream/stream.cpp
examples/talk-llama/talk-llama.cpp
examples/talk.wasm/emscripten.cpp
examples/talk/talk.cpp
examples/whisper.android/app/src/main/jni/whisper/jni.c
examples/whisper.objc/whisper.objc.xcodeproj/project.pbxproj
examples/whisper.objc/whisper.objc/ViewController.m
examples/whisper.swiftui/whisper.cpp.swift/LibWhisper.swift
examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj
examples/whisper.wasm/emscripten.cpp
whisper.cpp
whisper.h

index e605d8e0c85b16d06ec25226e3d15ba27a69d62d..b77e103c4e3cd8ef9efd0741d30250035da8beee 100644 (file)
@@ -103,7 +103,7 @@ var (
 func Whisper_init(path string) *Context {
        cPath := C.CString(path)
        defer C.free(unsafe.Pointer(cPath))
-       if ctx := C.whisper_init_from_file(cPath); ctx != nil {
+       if ctx := C.whisper_init_from_file_with_params(cPath, C.whisper_context_default_params()); ctx != nil {
                return (*Context)(ctx)
        } else {
                return nil
index 22d4ce87fe6284e1fac1e5b80baa0da1074eb650..0498eb4df817f96a06c1180b476a0c2637d500fa 100644 (file)
@@ -4,6 +4,7 @@ import com.sun.jna.Structure;
 import com.sun.jna.ptr.PointerByReference;\r
 import io.github.ggerganov.whispercpp.ggml.GgmlType;\r
 import io.github.ggerganov.whispercpp.WhisperModel;\r
+import io.github.ggerganov.whispercpp.params.WhisperContextParams;\r
 \r
 import java.util.List;\r
 \r
@@ -23,8 +24,9 @@ public class WhisperContext extends Structure {
     public PointerByReference vocab;\r
     public PointerByReference state;\r
 \r
-    /** populated by whisper_init_from_file() */\r
+    /** populated by whisper_init_from_file_with_params() */\r
     String path_model;\r
+    WhisperContextParams params;\r
 \r
 //    public static class ByReference extends WhisperContext implements Structure.ByReference {\r
 //    }\r
index 9bc1a8601a923eb100b1fcd0f36cbc60a932f668..4a25040377ceccc71b9322f57b89d5aa81223579 100644 (file)
@@ -2,6 +2,7 @@ package io.github.ggerganov.whispercpp;
 \r
 import com.sun.jna.Native;\r
 import com.sun.jna.Pointer;\r
+import io.github.ggerganov.whispercpp.params.WhisperContextParams;\r
 import io.github.ggerganov.whispercpp.params.WhisperFullParams;\r
 import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;\r
 \r
@@ -15,8 +16,9 @@ import java.io.IOException;
 public class WhisperCpp implements AutoCloseable {\r
     private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;\r
     private Pointer ctx = null;\r
-    private Pointer greedyPointer = null;\r
-    private Pointer beamPointer = null;\r
+    private Pointer paramsPointer = null;\r
+    private Pointer greedyParamsPointer = null;\r
+    private Pointer beamParamsPointer = null;\r
 \r
     public File modelDir() {\r
         String modelDirPath = System.getenv("XDG_CACHE_HOME");\r
@@ -31,6 +33,18 @@ public class WhisperCpp implements AutoCloseable {
      * @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")\r
      */\r
     public void initContext(String modelPath) throws FileNotFoundException {\r
+        initContextImpl(modelPath, getContextDefaultParams());\r
+    }\r
+\r
+    /**\r
+     * @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")\r
+     * @param params - params to use when initialising the context\r
+     */\r
+    public void initContext(String modelPath, WhisperContextParams params) throws FileNotFoundException {\r
+        initContextImpl(modelPath, params);\r
+    }\r
+\r
+    private void initContextImpl(String modelPath, WhisperContextParams params) throws FileNotFoundException {\r
         if (ctx != null) {\r
             lib.whisper_free(ctx);\r
         }\r
@@ -43,13 +57,26 @@ public class WhisperCpp implements AutoCloseable {
             modelPath = new File(modelDir(), modelPath).getAbsolutePath();\r
         }\r
 \r
-        ctx = lib.whisper_init_from_file(modelPath);\r
+        ctx = lib.whisper_init_from_file_with_params(modelPath, params);\r
 \r
         if (ctx == null) {\r
             throw new FileNotFoundException(modelPath);\r
         }\r
     }\r
 \r
+    /**\r
+     * Provides default params which can be used with `whisper_init_from_file_with_params()` etc.\r
+     * Because this function allocates memory for the params, the caller must call either:\r
+     * - call `whisper_free_context_params()`\r
+     * - `Native.free(Pointer.nativeValue(pointer));`\r
+     */\r
+    public WhisperContextParams getContextDefaultParams() {\r
+        paramsPointer = lib.whisper_context_default_params_by_ref();\r
+        WhisperContextParams params = new WhisperContextParams(paramsPointer);\r
+        params.read();\r
+        return params;\r
+    }\r
+    \r
     /**\r
      * Provides default params which can be used with `whisper_full()` etc.\r
      * Because this function allocates memory for the params, the caller must call either:\r
@@ -63,15 +90,15 @@ public class WhisperCpp implements AutoCloseable {
 \r
         // whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.\r
         if (strategy == WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY) {\r
-            if (greedyPointer == null) {\r
-                greedyPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());\r
+            if (greedyParamsPointer == null) {\r
+                greedyParamsPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());\r
             }\r
-            pointer = greedyPointer;\r
+            pointer = greedyParamsPointer;\r
         } else {\r
-            if (beamPointer == null) {\r
-                beamPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());\r
+            if (beamParamsPointer == null) {\r
+                beamParamsPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());\r
             }\r
-            pointer = beamPointer;\r
+            pointer = beamParamsPointer;\r
         }\r
 \r
         WhisperFullParams params = new WhisperFullParams(pointer);\r
@@ -93,13 +120,17 @@ public class WhisperCpp implements AutoCloseable {
     }\r
 \r
     private void freeParams() {\r
-        if (greedyPointer != null) {\r
-            Native.free(Pointer.nativeValue(greedyPointer));\r
-            greedyPointer = null;\r
+        if (paramsPointer != null) {\r
+            Native.free(Pointer.nativeValue(paramsPointer));\r
+            paramsPointer = null;\r
+        }\r
+        if (greedyParamsPointer != null) {\r
+            Native.free(Pointer.nativeValue(greedyParamsPointer));\r
+            greedyParamsPointer = null;\r
         }\r
-        if (beamPointer != null) {\r
-            Native.free(Pointer.nativeValue(beamPointer));\r
-            beamPointer = null;\r
+        if (beamParamsPointer != null) {\r
+            Native.free(Pointer.nativeValue(beamParamsPointer));\r
+            beamParamsPointer = null;\r
         }\r
     }\r
 \r
index ad9faa0be70485615e3fa8c2353ae55dff9ebe15..56a37380136999952bea8853ccb59ecdc28ffd6a 100644 (file)
@@ -5,6 +5,7 @@ import com.sun.jna.Native;
 import com.sun.jna.Pointer;\r
 import io.github.ggerganov.whispercpp.model.WhisperModelLoader;\r
 import io.github.ggerganov.whispercpp.model.WhisperTokenData;\r
+import io.github.ggerganov.whispercpp.params.WhisperContextParams;\r
 import io.github.ggerganov.whispercpp.params.WhisperFullParams;\r
 \r
 public interface WhisperCppJnaLibrary extends Library {\r
@@ -13,12 +14,31 @@ public interface WhisperCppJnaLibrary extends Library {
     String whisper_print_system_info();\r
 \r
     /**\r
-     * Allocate (almost) all memory needed for the model by loading from a file.\r
+     * DEPRECATED. Allocate (almost) all memory needed for the model by loading from a file.\r
      *\r
      * @param path_model Path to the model file\r
      * @return Whisper context on success, null on failure\r
      */\r
     Pointer whisper_init_from_file(String path_model);\r
+    \r
+    /**\r
+     * Provides default params which can be used with `whisper_init_from_file_with_params()` etc.\r
+     * Because this function allocates memory for the params, the caller must call either:\r
+     * - call `whisper_free_context_params()`\r
+     * - `Native.free(Pointer.nativeValue(pointer));`\r
+     */\r
+    Pointer whisper_context_default_params_by_ref();\r
+\r
+    void whisper_free_context_params(Pointer params);\r
+\r
+    /**\r
+     * Allocate (almost) all memory needed for the model by loading from a file.\r
+     *\r
+     * @param path_model Path to the model file\r
+     * @param params     Pointer to whisper_context_params\r
+     * @return Whisper context on success, null on failure\r
+     */\r
+    Pointer whisper_init_from_file_with_params(String path_model, WhisperContextParams params);\r
 \r
     /**\r
      * Allocate (almost) all memory needed for the model by loading from a buffer.\r
diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java
new file mode 100644 (file)
index 0000000..cf98d2c
--- /dev/null
@@ -0,0 +1,31 @@
+package io.github.ggerganov.whispercpp.params;
+
+import com.sun.jna.*;
+
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Parameters for the whisper_init_from_file_with_params() function.
+ * If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
+ * whisper_context_default_params()
+ */
+public class WhisperContextParams extends Structure {
+
+    public WhisperContextParams(Pointer p) {
+        super(p);
+    }
+
+    /** Use GPU for inference Number (default = true) */
+    public CBool use_gpu;
+
+    /** Use GPU for inference Number (default = true) */
+    public void useGpu(boolean enable) {
+        use_gpu = enable ? CBool.TRUE : CBool.FALSE;
+    }
+
+    @Override
+    protected List<String> getFieldOrder() {
+        return Arrays.asList("use_gpu");
+    }
+}
index 789ad8b51f8cdd58ec4237ee24442d1c33b69e1a..b442c1fcdbecf076f5d64f5475935017934afd6b 100644 (file)
@@ -20,7 +20,7 @@ struct whisper_context * g_context;
 EMSCRIPTEN_BINDINGS(whisper) {
     emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
         if (g_context == nullptr) {
-            g_context = whisper_init_from_file(path_model.c_str());
+            g_context = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
             if (g_context != nullptr) {
                 return true;
             } else {
index 82027d42fa516dba8a8e3468fb9fc4b213f12df8..86af9391e2c31e3695e08c649021909279432a67 100644 (file)
@@ -87,7 +87,7 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
   if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) {
     rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
   }
-  rw->context = whisper_init_from_file(StringValueCStr(whisper_model_file_path));
+  rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
   if (rw->context == nullptr) {
     rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
   }
index 845af2f07908bf68d21d5b255cfc2e3d713a1775..d102fe7624ee13435b4d31d253343f380f334ae9 100644 (file)
@@ -11,6 +11,7 @@ const whisperParamsMock = {
   language: "en",
   model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
   fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
+  use_gpu: true,
 };
 
 describe("Run whisper.node", () => {
index 52e80ad8528c703ef6da2d979d3cb3b86bafd588..30acbc6afd8da8452a057e85cc30f8e85bee8ea3 100644 (file)
@@ -36,6 +36,7 @@ struct whisper_params {
     bool print_colors   = false;
     bool print_progress = false;
     bool no_timestamps  = false;
+    bool use_gpu        = true;
 
     std::string language = "en";
     std::string prompt;
@@ -153,7 +154,9 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
 
     // whisper init
 
-    struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
+    struct whisper_context_params cparams;
+    cparams.use_gpu = params.use_gpu;
+    struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
 
     if (ctx == nullptr) {
         fprintf(stderr, "error: failed to initialize whisper context\n");
@@ -315,10 +318,12 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
   std::string language = whisper_params.Get("language").As<Napi::String>();
   std::string model = whisper_params.Get("model").As<Napi::String>();
   std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
+  bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();
 
   params.language = language;
   params.model = model;
   params.fname_inp.emplace_back(input);
+  params.use_gpu = use_gpu;
 
   Napi::Function callback = info[1].As<Napi::Function>();
   Worker* worker = new Worker(callback, params);
index d511cdc2b673ef3e3e191eba6c7df3340caad11a..3c6429375abe20a305c6122e5199e817f0d3d50e 100644 (file)
@@ -11,6 +11,7 @@ const whisperParams = {
   language: "en",
   model: path.join(__dirname, "../../models/ggml-base.en.bin"),
   fname_inp: "../../samples/jfk.wav",
+  use_gpu: true,
 };
 
 const arguments = process.argv.slice(2);
index 09e9d55d972d9e521afaab7c118521bfb3ad22b4..3624bbc48b1183a198538bf43d7eb1c0870788af 100644 (file)
@@ -57,7 +57,7 @@ EMSCRIPTEN_BINDINGS(bench) {
     emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
         for (size_t i = 0; i < g_contexts.size(); ++i) {
             if (g_contexts[i] == nullptr) {
-                g_contexts[i] = whisper_init_from_file(path_model.c_str());
+                g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
                 if (g_contexts[i] != nullptr) {
                     if (g_worker.joinable()) {
                         g_worker.join();
index dfb1d11bc29ca9c698c2294282b768141e242db9..9f50b3b622400f75ea34cc83dcbd82914a8bac13 100644 (file)
@@ -11,6 +11,8 @@ struct whisper_params {
     int32_t what = 0; // what to benchmark: 0 - whisper ecoder, 1 - memcpy, 2 - ggml_mul_mat
 
     std::string model = "models/ggml-base.en.bin";
+
+    bool use_gpu = true;
 };
 
 void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -23,9 +25,10 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             whisper_print_usage(argc, argv, params);
             exit(0);
         }
-        else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
-        else if (arg == "-m" || arg == "--model")   { params.model     = argv[++i]; }
-        else if (arg == "-w" || arg == "--what")    { params.what     = atoi(argv[++i]); }
+        else if (arg == "-t"  || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
+        else if (arg == "-m"  || arg == "--model")   { params.model     = argv[++i]; }
+        else if (arg == "-w"  || arg == "--what")    { params.what      = atoi(argv[++i]); }
+        else if (arg == "-ng" || arg == "--no-gpu")  { params.use_gpu   = false; }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             whisper_print_usage(argc, argv, params);
@@ -45,6 +48,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -t N,     --threads N   [%-7d] number of threads to use during computation\n", params.n_threads);
     fprintf(stderr, "  -m FNAME, --model FNAME [%-7s] model path\n",                                  params.model.c_str());
     fprintf(stderr, "  -w N,     --what N      [%-7d] what to benchmark:\n",                          params.what);
+    fprintf(stderr, "  -ng,      --no-gpu      [%-7s] disable GPU\n",                                 params.use_gpu ? "false" : "true");
     fprintf(stderr, "                           %-7s  0 - whisper\n",                                 "");
     fprintf(stderr, "                           %-7s  1 - memcpy\n",                                  "");
     fprintf(stderr, "                           %-7s  2 - ggml_mul_mat\n",                            "");
@@ -54,7 +58,10 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
 int whisper_bench_full(const whisper_params & params) {
     // whisper init
 
-    struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
+    struct whisper_context_params cparams;
+    cparams.use_gpu = params.use_gpu;
+
+    struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
 
     {
         fprintf(stderr, "\n");
index e739656dc6ddf84c4928108d88cdf58bc6f3cd12..528ff6ab5534aa92ddb4550814143959efa36cbd 100644 (file)
@@ -243,7 +243,7 @@ EMSCRIPTEN_BINDINGS(command) {
     emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
         for (size_t i = 0; i < g_contexts.size(); ++i) {
             if (g_contexts[i] == nullptr) {
-                g_contexts[i] = whisper_init_from_file(path_model.c_str());
+                g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
                 if (g_contexts[i] != nullptr) {
                     g_running = true;
                     if (g_worker.joinable()) {
index d39af7309a29cc4e25a0290a9f7913a42b1e7b77..7045f5ff81ee40888755e5dc6d7167ec6c0523d3 100644 (file)
@@ -38,6 +38,7 @@ struct whisper_params {
     bool print_special = false;
     bool print_energy  = false;
     bool no_timestamps = true;
+    bool use_gpu       = true;
 
     std::string language  = "en";
     std::string model     = "models/ggml-base.en.bin";
@@ -68,6 +69,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-tr"  || arg == "--translate")     { params.translate     = true; }
         else if (arg == "-ps"  || arg == "--print-special") { params.print_special = true; }
         else if (arg == "-pe"  || arg == "--print-energy")  { params.print_energy  = true; }
+        else if (arg == "-ng"  || arg == "--no-gpu")        { params.use_gpu       = false; }
         else if (arg == "-l"   || arg == "--language")      { params.language      = argv[++i]; }
         else if (arg == "-m"   || arg == "--model")         { params.model         = argv[++i]; }
         else if (arg == "-f"   || arg == "--file")          { params.fname_out     = argv[++i]; }
@@ -101,6 +103,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -tr,        --translate      [%-7s] translate from source language to english\n",   params.translate ? "true" : "false");
     fprintf(stderr, "  -ps,        --print-special  [%-7s] print special tokens\n",                        params.print_special ? "true" : "false");
     fprintf(stderr, "  -pe,        --print-energy   [%-7s] print sound energy (for debugging)\n",          params.print_energy ? "true" : "false");
+    fprintf(stderr, "  -ng,        --no-gpu         [%-7s] disable GPU\n",                                 params.use_gpu ? "false" : "true");
     fprintf(stderr, "  -l LANG,    --language LANG  [%-7s] spoken language\n",                             params.language.c_str());
     fprintf(stderr, "  -m FNAME,   --model FNAME    [%-7s] model path\n",                                  params.model.c_str());
     fprintf(stderr, "  -f FNAME,   --file FNAME     [%-7s] text output file name\n",                       params.fname_out.c_str());
@@ -610,7 +613,10 @@ int main(int argc, char ** argv) {
 
     // whisper init
 
-    struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
+    struct whisper_context_params cparams;
+    cparams.use_gpu = params.use_gpu;
+
+    struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
 
     // print some info about the processing
     {
index b8001b95702f1c1b141b22a5327806458477478d..8d8b6ffa238892cdadcddadc1b6af7cd9311ed31 100644 (file)
@@ -30,6 +30,7 @@ struct whisper_params {
     bool translate     = false;
     bool print_special = false;
     bool print_energy  = false;
+    bool use_gpu       = true;
 
     std::string language  = "en";
     std::string model     = "models/ggml-base.en.bin";
@@ -72,6 +73,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-tr"  || arg == "--translate")     { params.translate     = true; }
         else if (arg == "-ps"  || arg == "--print-special") { params.print_special = true; }
         else if (arg == "-pe"  || arg == "--print-energy")  { params.print_energy  = true; }
+        else if (arg == "-ng"  || arg == "--no-gpu")        { params.use_gpu       = false; }
         else if (arg == "-l"   || arg == "--language")      { params.language      = argv[++i]; }
         else if (arg == "-m"   || arg == "--model")         { params.model         = argv[++i]; }
         else {
@@ -102,6 +104,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -tr,        --translate      [%-7s] translate from source language to english\n",   params.translate ? "true" : "false");
     fprintf(stderr, "  -ps,        --print-special  [%-7s] print special tokens\n",                        params.print_special ? "true" : "false");
     fprintf(stderr, "  -pe,        --print-energy   [%-7s] print sound energy (for debugging)\n",          params.print_energy ? "true" : "false");
+    fprintf(stderr, "  -ng,        --no-gpu         [%-7s] disable GPU\n",                                 params.use_gpu ? "false" : "true");
     fprintf(stderr, "  -l LANG,    --language LANG  [%-7s] spoken language\n",                             params.language.c_str());
     fprintf(stderr, "  -m FNAME,   --model FNAME    [%-7s] model path\n",                                  params.model.c_str());
     fprintf(stderr, "\n");
@@ -432,7 +435,9 @@ int main(int argc, char ** argv) {
     }
 
     // whisper init
-    struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
+    struct whisper_context_params cparams;
+    cparams.use_gpu = params.use_gpu;
+    struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
     // init audio
 
     audio_async audio(30*1000);
index bed0789f9ad06ad4b4cd04d2b8bdff586abb1c98..e43dfe3f948f916836fb6022f4379355bd5d2bcf 100644 (file)
@@ -90,6 +90,7 @@ struct whisper_params {
     bool print_progress  = false;
     bool no_timestamps   = false;
     bool log_score       = false;
+    bool use_gpu         = true;
 
     std::string language  = "en";
     std::string prompt;
@@ -165,6 +166,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-f"    || arg == "--file")            { params.fname_inp.emplace_back(argv[++i]); }
         else if (arg == "-oved" || arg == "--ov-e-device")     { params.openvino_encode_device = argv[++i]; }
         else if (arg == "-ls"   || arg == "--log-score")       { params.log_score = true; }
+        else if (arg == "-ng"   || arg == "--no-gpu")          { params.use_gpu = false; }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             whisper_print_usage(argc, argv, params);
@@ -221,6 +223,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -f FNAME,  --file FNAME        [%-7s] input WAV file path\n",                            "");
     fprintf(stderr, "  -oved D,   --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n",  params.openvino_encode_device.c_str());
     fprintf(stderr, "  -ls,       --log-score         [%-7s] log best decoder scores of tokens\n",              params.log_score?"true":"false");
+    fprintf(stderr, "  -ng,       --no-gpu            [%-7s] disable GPU\n",                                    params.use_gpu ? "false" : "true");
     fprintf(stderr, "\n");
 }
 
@@ -877,7 +880,10 @@ int main(int argc, char ** argv) {
 
     // whisper init
 
-    struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
+    struct whisper_context_params cparams;
+    cparams.use_gpu = params.use_gpu;
+
+    struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
 
     if (ctx == nullptr) {
         fprintf(stderr, "error: failed to initialize whisper context\n");
index 144a14d268fee0402ee903c935aa59b149261cc8..71acffba296c59abecd14f1158741f1ff2838b4c 100644 (file)
@@ -132,7 +132,7 @@ EMSCRIPTEN_BINDINGS(stream) {
     emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
         for (size_t i = 0; i < g_contexts.size(); ++i) {
             if (g_contexts[i] == nullptr) {
-                g_contexts[i] = whisper_init_from_file(path_model.c_str());
+                g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
                 if (g_contexts[i] != nullptr) {
                     g_running = true;
                     if (g_worker.joinable()) {
index c8a452d1267374ead8bcae31872dadce6fc232e0..47f1780b4ea0f0d9ca26daabe9c0b53edf0ffae7 100644 (file)
@@ -48,11 +48,12 @@ struct whisper_params {
     bool no_context    = true;
     bool no_timestamps = false;
     bool tinydiarize   = false;
+    bool save_audio    = false; // save audio to wav file
+    bool use_gpu       = true;
 
     std::string language  = "en";
     std::string model     = "models/ggml-base.en.bin";
     std::string fname_out;
-    bool save_audio = false; // save audio to wav file
 };
 
 void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -65,25 +66,26 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             whisper_print_usage(argc, argv, params);
             exit(0);
         }
-        else if (arg == "-t"   || arg == "--threads")       { params.n_threads     = std::stoi(argv[++i]); }
-        else if (                 arg == "--step")          { params.step_ms       = std::stoi(argv[++i]); }
-        else if (                 arg == "--length")        { params.length_ms     = std::stoi(argv[++i]); }
-        else if (                 arg == "--keep")          { params.keep_ms       = std::stoi(argv[++i]); }
-        else if (arg == "-c"   || arg == "--capture")       { params.capture_id    = std::stoi(argv[++i]); }
-        else if (arg == "-mt"  || arg == "--max-tokens")    { params.max_tokens    = std::stoi(argv[++i]); }
-        else if (arg == "-ac"  || arg == "--audio-ctx")     { params.audio_ctx     = std::stoi(argv[++i]); }
-        else if (arg == "-vth" || arg == "--vad-thold")     { params.vad_thold     = std::stof(argv[++i]); }
-        else if (arg == "-fth" || arg == "--freq-thold")    { params.freq_thold    = std::stof(argv[++i]); }
-        else if (arg == "-su"  || arg == "--speed-up")      { params.speed_up      = true; }
-        else if (arg == "-tr"  || arg == "--translate")     { params.translate     = true; }
-        else if (arg == "-nf"  || arg == "--no-fallback")   { params.no_fallback   = true; }
-        else if (arg == "-ps"  || arg == "--print-special") { params.print_special = true; }
-        else if (arg == "-kc"  || arg == "--keep-context")  { params.no_context    = false; }
-        else if (arg == "-l"   || arg == "--language")      { params.language      = argv[++i]; }
-        else if (arg == "-m"   || arg == "--model")         { params.model         = argv[++i]; }
-        else if (arg == "-f"   || arg == "--file")          { params.fname_out     = argv[++i]; }
-        else if (arg == "-tdrz" || arg == "--tinydiarize")  { params.tinydiarize   = true; }
-        else if (arg == "-sa"  || arg == "--save-audio")    { params.save_audio    = true; }
+        else if (arg == "-t"    || arg == "--threads")       { params.n_threads     = std::stoi(argv[++i]); }
+        else if (                  arg == "--step")          { params.step_ms       = std::stoi(argv[++i]); }
+        else if (                  arg == "--length")        { params.length_ms     = std::stoi(argv[++i]); }
+        else if (                  arg == "--keep")          { params.keep_ms       = std::stoi(argv[++i]); }
+        else if (arg == "-c"    || arg == "--capture")       { params.capture_id    = std::stoi(argv[++i]); }
+        else if (arg == "-mt"   || arg == "--max-tokens")    { params.max_tokens    = std::stoi(argv[++i]); }
+        else if (arg == "-ac"   || arg == "--audio-ctx")     { params.audio_ctx     = std::stoi(argv[++i]); }
+        else if (arg == "-vth"  || arg == "--vad-thold")     { params.vad_thold     = std::stof(argv[++i]); }
+        else if (arg == "-fth"  || arg == "--freq-thold")    { params.freq_thold    = std::stof(argv[++i]); }
+        else if (arg == "-su"   || arg == "--speed-up")      { params.speed_up      = true; }
+        else if (arg == "-tr"   || arg == "--translate")     { params.translate     = true; }
+        else if (arg == "-nf"   || arg == "--no-fallback")   { params.no_fallback   = true; }
+        else if (arg == "-ps"   || arg == "--print-special") { params.print_special = true; }
+        else if (arg == "-kc"   || arg == "--keep-context")  { params.no_context    = false; }
+        else if (arg == "-l"    || arg == "--language")      { params.language      = argv[++i]; }
+        else if (arg == "-m"    || arg == "--model")         { params.model         = argv[++i]; }
+        else if (arg == "-f"    || arg == "--file")          { params.fname_out     = argv[++i]; }
+        else if (arg == "-tdrz" || arg == "--tinydiarize")   { params.tinydiarize   = true; }
+        else if (arg == "-sa"   || arg == "--save-audio")    { params.save_audio    = true; }
+        else if (arg == "-ng"   || arg == "--no-gpu")        { params.use_gpu       = false; }
 
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
@@ -118,8 +120,9 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -l LANG,  --language LANG [%-7s] spoken language\n",                                params.language.c_str());
     fprintf(stderr, "  -m FNAME, --model FNAME   [%-7s] model path\n",                                     params.model.c_str());
     fprintf(stderr, "  -f FNAME, --file FNAME    [%-7s] text output file name\n",                          params.fname_out.c_str());
-    fprintf(stderr, "  -tdrz,     --tinydiarize  [%-7s] enable tinydiarize (requires a tdrz model)\n",     params.tinydiarize ? "true" : "false");
+    fprintf(stderr, "  -tdrz,    --tinydiarize   [%-7s] enable tinydiarize (requires a tdrz model)\n",     params.tinydiarize ? "true" : "false");
     fprintf(stderr, "  -sa,      --save-audio    [%-7s] save the recorded audio to a file\n",              params.save_audio ? "true" : "false");
+    fprintf(stderr, "  -ng,      --no-gpu        [%-7s] disable GPU inference\n",                          params.use_gpu ? "false" : "true");
     fprintf(stderr, "\n");
 }
 
@@ -163,7 +166,10 @@ int main(int argc, char ** argv) {
         exit(0);
     }
 
-    struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
+    struct whisper_context_params cparams;
+    cparams.use_gpu = params.use_gpu;
+
+    struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
 
     std::vector<float> pcmf32    (n_samples_30s, 0.0f);
     std::vector<float> pcmf32_old;
@@ -424,4 +430,4 @@ int main(int argc, char ** argv) {
     whisper_free(ctx);
 
     return 0;
-}
\ No newline at end of file
+}
index e497690e4ade0e25c98aeb03d3d03742e57be44b..6cc30c1653ecb01305fe6366b952229e2890bb73 100644 (file)
@@ -63,6 +63,7 @@ struct whisper_params {
     bool print_energy   = false;
     bool no_timestamps  = true;
     bool verbose_prompt = false;
+    bool use_gpu        = true;
 
     std::string person      = "Georgi";
     std::string language    = "en";
@@ -84,25 +85,26 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             whisper_print_usage(argc, argv, params);
             exit(0);
         }
-        else if (arg == "-t"   || arg == "--threads")       { params.n_threads     = std::stoi(argv[++i]); }
-        else if (arg == "-vms" || arg == "--voice-ms")      { params.voice_ms      = std::stoi(argv[++i]); }
-        else if (arg == "-c"   || arg == "--capture")       { params.capture_id    = std::stoi(argv[++i]); }
-        else if (arg == "-mt"  || arg == "--max-tokens")    { params.max_tokens    = std::stoi(argv[++i]); }
-        else if (arg == "-ac"  || arg == "--audio-ctx")     { params.audio_ctx     = std::stoi(argv[++i]); }
-        else if (arg == "-vth" || arg == "--vad-thold")     { params.vad_thold     = std::stof(argv[++i]); }
-        else if (arg == "-fth" || arg == "--freq-thold")    { params.freq_thold    = std::stof(argv[++i]); }
-        else if (arg == "-su"  || arg == "--speed-up")      { params.speed_up      = true; }
-        else if (arg == "-tr"  || arg == "--translate")     { params.translate     = true; }
-        else if (arg == "-ps"  || arg == "--print-special") { params.print_special = true; }
-        else if (arg == "-pe"  || arg == "--print-energy")  { params.print_energy  = true; }
-        else if (arg == "--verbose-prompt")                 { params.verbose_prompt = true; }
-        else if (arg == "-p"   || arg == "--person")        { params.person        = argv[++i]; }
-        else if (arg == "--session")                        { params.path_session  = argv[++i];}
-        else if (arg == "-l"   || arg == "--language")      { params.language      = argv[++i]; }
-        else if (arg == "-mw"  || arg == "--model-whisper") { params.model_wsp     = argv[++i]; }
-        else if (arg == "-ml"  || arg == "--model-llama")   { params.model_llama   = argv[++i]; }
-        else if (arg == "-s"   || arg == "--speak")         { params.speak         = argv[++i]; }
-        else if (arg == "--prompt-file")                    {
+        else if (arg == "-t"   || arg == "--threads")        { params.n_threads      = std::stoi(argv[++i]); }
+        else if (arg == "-vms" || arg == "--voice-ms")       { params.voice_ms       = std::stoi(argv[++i]); }
+        else if (arg == "-c"   || arg == "--capture")        { params.capture_id     = std::stoi(argv[++i]); }
+        else if (arg == "-mt"  || arg == "--max-tokens")     { params.max_tokens     = std::stoi(argv[++i]); }
+        else if (arg == "-ac"  || arg == "--audio-ctx")      { params.audio_ctx      = std::stoi(argv[++i]); }
+        else if (arg == "-vth" || arg == "--vad-thold")      { params.vad_thold      = std::stof(argv[++i]); }
+        else if (arg == "-fth" || arg == "--freq-thold")     { params.freq_thold     = std::stof(argv[++i]); }
+        else if (arg == "-su"  || arg == "--speed-up")       { params.speed_up       = true; }
+        else if (arg == "-tr"  || arg == "--translate")      { params.translate      = true; }
+        else if (arg == "-ps"  || arg == "--print-special")  { params.print_special  = true; }
+        else if (arg == "-pe"  || arg == "--print-energy")   { params.print_energy   = true; }
+        else if (arg == "-vp"  || arg == "--verbose-prompt") { params.verbose_prompt = true; }
+        else if (arg == "-ng"  || arg == "--no-gpu")         { params.use_gpu        = false; }
+        else if (arg == "-p"   || arg == "--person")         { params.person         = argv[++i]; }
+        else if (arg == "--session")                         { params.path_session   = argv[++i];}
+        else if (arg == "-l"   || arg == "--language")       { params.language       = argv[++i]; }
+        else if (arg == "-mw"  || arg == "--model-whisper")  { params.model_wsp      = argv[++i]; }
+        else if (arg == "-ml"  || arg == "--model-llama")    { params.model_llama    = argv[++i]; }
+        else if (arg == "-s"   || arg == "--speak")          { params.speak          = argv[++i]; }
+        else if (arg == "--prompt-file")                     {
             std::ifstream file(argv[++i]);
             std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
             if (params.prompt.back() == '\n') {
@@ -110,6 +112,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             }
         }
         else if (arg == "-f"   || arg == "--file")          { params.fname_out     = argv[++i]; }
+        else if (arg == "-ng"  || arg == "--no-gpu")        { params.use_gpu       = false; }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             whisper_print_usage(argc, argv, params);
@@ -125,27 +128,28 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "usage: %s [options]\n", argv[0]);
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
-    fprintf(stderr, "  -h,       --help          [default] show this help message and exit\n");
-    fprintf(stderr, "  -t N,     --threads N     [%-7d] number of threads to use during computation\n", params.n_threads);
-    fprintf(stderr, "  -vms N,   --voice-ms N    [%-7d] voice duration in milliseconds\n",              params.voice_ms);
-    fprintf(stderr, "  -c ID,    --capture ID    [%-7d] capture device ID\n",                           params.capture_id);
-    fprintf(stderr, "  -mt N,    --max-tokens N  [%-7d] maximum number of tokens per audio chunk\n",    params.max_tokens);
-    fprintf(stderr, "  -ac N,    --audio-ctx N   [%-7d] audio context size (0 - all)\n",                params.audio_ctx);
-    fprintf(stderr, "  -vth N,   --vad-thold N   [%-7.2f] voice activity detection threshold\n",        params.vad_thold);
-    fprintf(stderr, "  -fth N,   --freq-thold N  [%-7.2f] high-pass frequency cutoff\n",                params.freq_thold);
-    fprintf(stderr, "  -su,      --speed-up      [%-7s] speed up audio by x2 (reduced accuracy)\n",     params.speed_up ? "true" : "false");
-    fprintf(stderr, "  -tr,      --translate     [%-7s] translate from source language to english\n",   params.translate ? "true" : "false");
-    fprintf(stderr, "  -ps,      --print-special [%-7s] print special tokens\n",                        params.print_special ? "true" : "false");
-    fprintf(stderr, "  -pe,      --print-energy  [%-7s] print sound energy (for debugging)\n",          params.print_energy ? "true" : "false");
-    fprintf(stderr, "  -p NAME,  --person NAME   [%-7s] person name (for prompt selection)\n",          params.person.c_str());
-    fprintf(stderr, "  -l LANG,  --language LANG [%-7s] spoken language\n",                             params.language.c_str());
-    fprintf(stderr, "  -mw FILE, --model-whisper [%-7s] whisper model file\n",                          params.model_wsp.c_str());
-    fprintf(stderr, "  -ml FILE, --model-llama   [%-7s] llama model file\n",                            params.model_llama.c_str());
-    fprintf(stderr, "  -s FILE,  --speak TEXT    [%-7s] command for TTS\n",                             params.speak.c_str());
-    fprintf(stderr, "  --prompt-file FNAME       [%-7s] file with custom prompt to start dialog\n",     "");
-    fprintf(stderr, "  --session FNAME       file to cache model state in (may be large!) (default: none)\n");
-    fprintf(stderr, "  --verbose-prompt          [%-7s] print prompt at start\n",                       params.verbose_prompt ? "true" : "false");
-    fprintf(stderr, "  -f FNAME, --file FNAME    [%-7s] text output file name\n",                       params.fname_out.c_str());
+    fprintf(stderr, "  -h,       --help           [default] show this help message and exit\n");
+    fprintf(stderr, "  -t N,     --threads N      [%-7d] number of threads to use during computation\n", params.n_threads);
+    fprintf(stderr, "  -vms N,   --voice-ms N     [%-7d] voice duration in milliseconds\n",              params.voice_ms);
+    fprintf(stderr, "  -c ID,    --capture ID     [%-7d] capture device ID\n",                           params.capture_id);
+    fprintf(stderr, "  -mt N,    --max-tokens N   [%-7d] maximum number of tokens per audio chunk\n",    params.max_tokens);
+    fprintf(stderr, "  -ac N,    --audio-ctx N    [%-7d] audio context size (0 - all)\n",                params.audio_ctx);
+    fprintf(stderr, "  -vth N,   --vad-thold N    [%-7.2f] voice activity detection threshold\n",        params.vad_thold);
+    fprintf(stderr, "  -fth N,   --freq-thold N   [%-7.2f] high-pass frequency cutoff\n",                params.freq_thold);
+    fprintf(stderr, "  -su,      --speed-up       [%-7s] speed up audio by x2 (reduced accuracy)\n",     params.speed_up ? "true" : "false");
+    fprintf(stderr, "  -tr,      --translate      [%-7s] translate from source language to english\n",   params.translate ? "true" : "false");
+    fprintf(stderr, "  -ps,      --print-special  [%-7s] print special tokens\n",                        params.print_special ? "true" : "false");
+    fprintf(stderr, "  -pe,      --print-energy   [%-7s] print sound energy (for debugging)\n",          params.print_energy ? "true" : "false");
+    fprintf(stderr, "  -vp,      --verbose-prompt [%-7s] print prompt at start\n",                       params.verbose_prompt ? "true" : "false");
+    fprintf(stderr, "  -ng,      --no-gpu         [%-7s] disable GPU\n",                                 params.use_gpu ? "false" : "true");
+    fprintf(stderr, "  -p NAME,  --person NAME    [%-7s] person name (for prompt selection)\n",          params.person.c_str());
+    fprintf(stderr, "  -l LANG,  --language LANG  [%-7s] spoken language\n",                             params.language.c_str());
+    fprintf(stderr, "  -mw FILE, --model-whisper  [%-7s] whisper model file\n",                          params.model_wsp.c_str());
+    fprintf(stderr, "  -ml FILE, --model-llama    [%-7s] llama model file\n",                            params.model_llama.c_str());
+    fprintf(stderr, "  -s FILE,  --speak TEXT     [%-7s] command for TTS\n",                             params.speak.c_str());
+    fprintf(stderr, "  --prompt-file FNAME        [%-7s] file with custom prompt to start dialog\n",     "");
+    fprintf(stderr, "  --session FNAME                   file to cache model state in (may be large!) (default: none)\n");
+    fprintf(stderr, "  -f FNAME, --file FNAME     [%-7s] text output file name\n",                       params.fname_out.c_str());
     fprintf(stderr, "\n");
 }
 
@@ -252,7 +256,10 @@ int main(int argc, char ** argv) {
 
     // whisper init
 
-    struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
+    struct whisper_context_params cparams;
+    cparams.use_gpu = params.use_gpu;
+
+    struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
 
     // llama init
 
@@ -269,6 +276,9 @@ int main(int argc, char ** argv) {
     lcparams.seed       = 1;
     lcparams.f16_kv     = true;
     lcparams.n_threads  = params.n_threads;
+    if (!params.use_gpu) {
+        lcparams.n_gpu_layers = 0;
+    }
 
     struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lcparams);
 
index 1ea970295ac066f3aa06b874e4b2ad808b6861eb..6d30b295832885cd311c6627aefad2310f9a258e 100644 (file)
@@ -271,7 +271,7 @@ EMSCRIPTEN_BINDINGS(talk) {
     emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
         for (size_t i = 0; i < g_contexts.size(); ++i) {
             if (g_contexts[i] == nullptr) {
-                g_contexts[i] = whisper_init_from_file(path_model.c_str());
+                g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
                 if (g_contexts[i] != nullptr) {
                     g_running = true;
                     if (g_worker.joinable()) {
index 346d9d483fedd4bbde85e0ff3bce9e9c5795437c..cdb1a230b7d6655fb649a4a6f6cc7909fb3e7a3a 100644 (file)
@@ -31,6 +31,7 @@ struct whisper_params {
     bool print_special = false;
     bool print_energy  = false;
     bool no_timestamps = true;
+    bool use_gpu       = true;
 
     std::string person    = "Santa";
     std::string language  = "en";
@@ -61,6 +62,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-tr"  || arg == "--translate")     { params.translate     = true; }
         else if (arg == "-ps"  || arg == "--print-special") { params.print_special = true; }
         else if (arg == "-pe"  || arg == "--print-energy")  { params.print_energy  = true; }
+        else if (arg == "-ng"  || arg == "--no-gpu")        { params.use_gpu       = false; }
         else if (arg == "-p"   || arg == "--person")        { params.person        = argv[++i]; }
         else if (arg == "-l"   || arg == "--language")      { params.language      = argv[++i]; }
         else if (arg == "-mw"  || arg == "--model-whisper") { params.model_wsp     = argv[++i]; }
@@ -94,6 +96,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -tr,      --translate     [%-7s] translate from source language to english\n",   params.translate ? "true" : "false");
     fprintf(stderr, "  -ps,      --print-special [%-7s] print special tokens\n",                        params.print_special ? "true" : "false");
     fprintf(stderr, "  -pe,      --print-energy  [%-7s] print sound energy (for debugging)\n",          params.print_energy ? "true" : "false");
+    fprintf(stderr, "  -ng,      --no-gpu        [%-7s] disable GPU\n",                                 params.use_gpu ? "false" : "true");
     fprintf(stderr, "  -p NAME,  --person NAME   [%-7s] person name (for prompt selection)\n",          params.person.c_str());
     fprintf(stderr, "  -l LANG,  --language LANG [%-7s] spoken language\n",                             params.language.c_str());
     fprintf(stderr, "  -mw FILE, --model-whisper [%-7s] whisper model file\n",                          params.model_wsp.c_str());
@@ -181,8 +184,10 @@ int main(int argc, char ** argv) {
     }
 
     // whisper init
+    struct whisper_context_params cparams;
+    cparams.use_gpu = params.use_gpu;
 
-    struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
+    struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
 
     // gpt init
 
index c437d0990f1a6fca75982e41788288a71e2d0619..a8b3ded4a32c3c68c40ed82a142095b70559f478 100644 (file)
@@ -127,7 +127,7 @@ static struct whisper_context *whisper_init_from_asset(
             .close = &asset_close
     };
 
-    return whisper_init(&loader);
+    return whisper_init_with_params(&loader, whisper_context_default_params());
 }
 
 JNIEXPORT jlong JNICALL
@@ -147,7 +147,7 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContext(
     UNUSED(thiz);
     struct whisper_context *context = NULL;
     const char *model_path_chars = (*env)->GetStringUTFChars(env, model_path_str, NULL);
-    context = whisper_init_from_file(model_path_chars);
+    context = whisper_init_from_file_with_params(model_path_chars, whisper_context_default_params());
     (*env)->ReleaseStringUTFChars(env, model_path_str, model_path_chars);
     return (jlong) context;
 }
index 06af23e60e317633ee2104a01659738157d18e5a..fd884cf379fdfd719ecff7fc8c8e0638ce9aeabc 100644 (file)
@@ -17,8 +17,8 @@
                18627C8629052BE000BD2A04 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8529052BE000BD2A04 /* Assets.xcassets */; };
                18627C8929052BE000BD2A04 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8729052BE000BD2A04 /* LaunchScreen.storyboard */; };
                18627C8C29052BE000BD2A04 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C8B29052BE000BD2A04 /* main.m */; };
-               18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML"; }; };
-               18627C9629052C5800BD2A04 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9529052C5800BD2A04 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE"; }; };
+               18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML -DWHISPER_COREML_ALLOW_FALLBACK -DGGML_USE_METAL"; }; };
+               18627C9629052C5800BD2A04 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9529052C5800BD2A04 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE -DGGML_USE_METAL"; }; };
                18627C9B29052CFF00BD2A04 /* ggml-base.en.bin in Resources */ = {isa = PBXBuildFile; fileRef = 18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */; };
                18ABE15A2AF556340044A204 /* ggml-backend.c in Sources */ = {isa = PBXBuildFile; fileRef = 18ABE1572AF556340044A204 /* ggml-backend.c */; };
                18ABE15B2AF556340044A204 /* ggml-quants.c in Sources */ = {isa = PBXBuildFile; fileRef = 18ABE1592AF556340044A204 /* ggml-quants.c */; };
index 8a1e876c395c6451126851d40166682abba6f9ce..151b05d9c99f77365b3919515347813fca9d22a7 100644 (file)
@@ -61,7 +61,13 @@ void AudioInputCallback(void * inUserData,
         NSLog(@"Loading model from %@", modelPath);
 
         // create ggml context
-        stateInp.ctx = whisper_init_from_file([modelPath UTF8String]);
+
+        struct whisper_context_params cparams = whisper_context_default_params();
+#if TARGET_OS_SIMULATOR
+        cparams.use_gpu = false;
+        NSLog(@"Running on simulator, using CPU");
+#endif
+        stateInp.ctx = whisper_init_from_file_with_params([modelPath UTF8String], cparams);
 
         // check if the model was loaded successfully
         if (stateInp.ctx == NULL) {
index e9645b34f7466c3c0275e0e5f24a9ffdd22ad004..95e1aeefbc3863b9d15aadf2c005c0240420f770 100644 (file)
@@ -55,7 +55,12 @@ actor WhisperContext {
     }
     
     static func createContext(path: String) throws -> WhisperContext {
-        let context = whisper_init_from_file(path)
+        var params = whisper_context_default_params()
+#if targetEnvironment(simulator)
+        params.use_gpu = false
+        print("Running on the simulator, using CPU")
+#endif
+        let context = whisper_init_from_file_with_params(path, params)
         if let context {
             return WhisperContext(context: context)
         } else {
index 832a2a1bda5578f9f95f27c423089c39280a8434..605240da179ce68f5ed7625c171f4727c4a98101 100644 (file)
                0AAC5D9D29539CCF003032C3 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5D9C29539CCF003032C3 /* ContentView.swift */; };
                0AAC5D9F29539CD0003032C3 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 0AAC5D9E29539CD0003032C3 /* Assets.xcassets */; };
                0AAC5DA329539CD0003032C3 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 0AAC5DA229539CD0003032C3 /* Preview Assets.xcassets */; };
-               0AAC5DCB29539EB1003032C3 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC729539EB0003032C3 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-Wno-shorten-64-to-32"; }; };
-               0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC929539EB0003032C3 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE -Wno-shorten-64-to-32"; }; };
+               0AAC5DCB29539EB1003032C3 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC729539EB0003032C3 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DGGML_USE_METAL -Wno-shorten-64-to-32"; }; };
+               0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC929539EB0003032C3 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE -DGGML_USE_METAL -Wno-shorten-64-to-32"; }; };
                0AAC5DCE2953A05C003032C3 /* WhisperState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DCD2953A05C003032C3 /* WhisperState.swift */; };
                0AAC5DD12953A394003032C3 /* LibWhisper.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DD02953A394003032C3 /* LibWhisper.swift */; };
                18ABE1522AF555FA0044A204 /* ggml-backend.c in Sources */ = {isa = PBXBuildFile; fileRef = 18ABE14C2AF555FA0044A204 /* ggml-backend.c */; };
                18ABE1532AF555FA0044A204 /* ggml-quants.c in Sources */ = {isa = PBXBuildFile; fileRef = 18ABE1512AF555FA0044A204 /* ggml-quants.c */; };
                18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 18AED47F2AB21F2B009D854F /* ggml-alloc.c */; };
+               7FCB08262ACFA3A400AF3530 /* ggml-metal.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FCB08252ACFA3A400AF3530 /* ggml-metal.m */; settings = {COMPILER_FLAGS = "-framework Foundation -framework Metal -framework MetalKit -fno-objc-arc"; }; };
+               7FCB08282ACFA48500AF3530 /* ggml-metal.metal in Sources */ = {isa = PBXBuildFile; fileRef = 7FCB08272ACFA48500AF3530 /* ggml-metal.metal */; };
 /* End PBXBuildFile section */
 
 /* Begin PBXFileReference section */
@@ -52,6 +54,9 @@
                18ABE1512AF555FA0044A204 /* ggml-quants.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = "ggml-quants.c"; sourceTree = "<group>"; };
                18AED47F2AB21F2B009D854F /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = "ggml-alloc.c"; sourceTree = "<group>"; };
                18AED4802AB21F2B009D854F /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-alloc.h"; sourceTree = "<group>"; };
+               7FCB081E2ACFA04400AF3530 /* ggml-metal.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-metal.h"; sourceTree = "<group>"; };
+               7FCB08252ACFA3A400AF3530 /* ggml-metal.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = "ggml-metal.m"; sourceTree = "<group>"; };
+               7FCB08272ACFA48500AF3530 /* ggml-metal.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = "ggml-metal.metal"; sourceTree = "<group>"; };
 /* End PBXFileReference section */
 
 /* Begin PBXFrameworksBuildPhase section */
                0AAC5DC529539E89003032C3 /* whisper.cpp */ = {
                        isa = PBXGroup;
                        children = (
+                               7FCB08272ACFA48500AF3530 /* ggml-metal.metal */,
+                               7FCB081E2ACFA04400AF3530 /* ggml-metal.h */,
+                               7FCB08252ACFA3A400AF3530 /* ggml-metal.m */,
                                18ABE14E2AF555FA0044A204 /* ggml-backend-impl.h */,
                                18ABE14C2AF555FA0044A204 /* ggml-backend.c */,
                                18ABE14D2AF555FA0044A204 /* ggml-backend.h */,
                                0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */,
                                18ABE1532AF555FA0044A204 /* ggml-quants.c in Sources */,
                                0AAC5DCE2953A05C003032C3 /* WhisperState.swift in Sources */,
+                               7FCB08282ACFA48500AF3530 /* ggml-metal.metal in Sources */,
                                0AAC5DD12953A394003032C3 /* LibWhisper.swift in Sources */,
                                0AA7514C2953B569001EE061 /* RiffWaveUtils.swift in Sources */,
                                0AAC5DCB29539EB1003032C3 /* whisper.cpp in Sources */,
                                0AA7514E2953D958001EE061 /* Recorder.swift in Sources */,
+                               7FCB08262ACFA3A400AF3530 /* ggml-metal.m in Sources */,
                                18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */,
                                18ABE1522AF555FA0044A204 /* ggml-backend.c in Sources */,
                        );
index db1ff789e5fcff94dc7050b26da79ecddd81be70..b84893dee734f87ae443cf5e1cb46d0da80ea67e 100644 (file)
@@ -24,7 +24,7 @@ EMSCRIPTEN_BINDINGS(whisper) {
 
         for (size_t i = 0; i < g_contexts.size(); ++i) {
             if (g_contexts[i] == nullptr) {
-                g_contexts[i] = whisper_init_from_file(path_model.c_str());
+                g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
                 if (g_contexts[i] != nullptr) {
                     return i + 1;
                 } else {
index 3e36d362054e4ae7ab3eefe9a50d3c5c380bb007..704460490e660002d244267674c4356336b4d397 100644 (file)
@@ -736,7 +736,7 @@ struct whisper_state {
 
     int lang_id = 0; // english by default
 
-    std::string path_model; // populated by whisper_init_from_file()
+    std::string path_model; // populated by whisper_init_from_file_with_params()
 #ifdef WHISPER_USE_COREML
     whisper_coreml_context * ctx_coreml = nullptr;
 #endif
@@ -770,7 +770,8 @@ struct whisper_context {
     whisper_vocab vocab;
     whisper_state * state = nullptr;
 
-    std::string path_model; // populated by whisper_init_from_file()
+    std::string path_model; // populated by whisper_init_from_file_with_params()
+    whisper_context_params params;
 };
 
 static void whisper_default_log(const char * text) {
@@ -2930,59 +2931,64 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     }
 
 #ifdef GGML_USE_METAL
-    state->ctx_metal = ggml_metal_init(1);
-    if (!state->ctx_metal) {
-        log("%s: ggml_metal_init() failed\n", __func__);
-        delete state;
-        return nullptr;
+    if (ctx->params.use_gpu) {
+        state->ctx_metal = ggml_metal_init(1);
+        if (!state->ctx_metal) {
+            log("%s: ggml_metal_init() failed\n", __func__);
+            delete state;
+            return nullptr;
+        }
     }
 
-    log("%s: Metal context initialized\n", __func__);
+    if (state->ctx_metal) {
+        log("%s: Metal context initialized\n", __func__);
 
-    // this allocates all Metal resources and memory buffers
+        // this allocates all Metal resources and memory buffers
 
-    void * data_ptr  = NULL;
-    size_t data_size = 0;
+        void * data_ptr  = NULL;
+        size_t data_size = 0;
 
-    // TODO: add mmap support
-    //if (params.use_mmap) {
-    //    data_ptr  = ctx->model.mapping->addr;
-    //    data_size = ctx->model.mapping->size;
-    //} else {
-    //    data_ptr  = ggml_get_mem_buffer(ctx->model.ctx);
-    //    data_size = ggml_get_mem_size  (ctx->model.ctx);
-    //}
+        // TODO: add mmap support
+        //if (params.use_mmap) {
+        //    data_ptr  = ctx->model.mapping->addr;
+        //    data_size = ctx->model.mapping->size;
+        //} else {
+        //    data_ptr  = ggml_get_mem_buffer(ctx->model.ctx);
+        //    data_size = ggml_get_mem_size  (ctx->model.ctx);
+        //}
 
-    data_ptr  = ggml_get_mem_buffer(ctx->model.ctx);
-    data_size = ggml_get_mem_size  (ctx->model.ctx);
+        data_ptr  = ggml_get_mem_buffer(ctx->model.ctx);
+        data_size = ggml_get_mem_size  (ctx->model.ctx);
 
-    const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
+        const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
 
-    log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
+        log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
 
 #define WHISPER_METAL_CHECK_BUF(result)              \
-    if (!(result)) {                                 \
-        log("%s: failed to add metal buffer\n", __func__); \
-        delete state;                                \
-        return nullptr;                              \
-    }
+        if (!(result)) {                                 \
+            log("%s: failed to add metal buffer\n", __func__); \
+            delete state;                                \
+            return nullptr;                              \
+        }
 
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
+        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
 
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv",   state->alloc_conv.meta.data(),   state->alloc_conv.meta.size(),   0));
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross",  state->alloc_cross.meta.data(),  state->alloc_cross.meta.size(),  0));
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
+        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv",   state->alloc_conv.meta.data(),   state->alloc_conv.meta.size(),   0));
+        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
+        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross",  state->alloc_cross.meta.data(),  state->alloc_cross.meta.size(),  0));
+        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
 
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv",   state->alloc_conv.data.data(),   state->alloc_conv.data.size(),   0));
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross",  state->alloc_cross.data.data(),  state->alloc_cross.data.size(),  0));
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
+        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv",   state->alloc_conv.data.data(),   state->alloc_conv.data.size(),   0));
+        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
+        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross",  state->alloc_cross.data.data(),  state->alloc_cross.data.size(),  0));
+        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
 
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross",  state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
+        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross",  state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
 
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
+        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
 #undef WHISPER_METAL_CHECK_BUF
+
+    }
 #endif
 
     state->rng = std::mt19937(0);
@@ -3039,7 +3045,14 @@ int whisper_ctx_init_openvino_encoder(
 #endif
 }
 
-struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
+struct whisper_context_params whisper_context_default_params() {
+    struct whisper_context_params result = {
+        /*.use_gpu    =*/ true,
+    };
+    return result;
+}
+
+struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
     log("%s: loading model from '%s'\n", __func__, path_model);
 
     auto fin = std::ifstream(path_model, std::ios::binary);
@@ -3068,7 +3081,7 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
         fin->close();
     };
 
-    auto ctx = whisper_init_no_state(&loader);
+    auto ctx = whisper_init_with_params_no_state(&loader, params);
 
     if (ctx) {
         ctx->path_model = path_model;
@@ -3077,7 +3090,7 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
     return ctx;
 }
 
-struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
+struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params) {
     struct buf_context {
         uint8_t* buffer;
         size_t size;
@@ -3111,13 +3124,14 @@ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t
 
     loader.close = [](void * /*ctx*/) { };
 
-    return whisper_init_no_state(&loader);
+    return whisper_init_with_params_no_state(&loader, params);
 }
 
-struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) {
+struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
     ggml_time_init();
 
     whisper_context * ctx = new whisper_context;
+    ctx->params = params;
 
     if (!whisper_model_load(loader, *ctx)) {
         loader->close(loader->context);
@@ -3131,8 +3145,8 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
     return ctx;
 }
 
-struct whisper_context * whisper_init_from_file(const char * path_model) {
-    whisper_context * ctx = whisper_init_from_file_no_state(path_model);
+struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params) {
+    whisper_context * ctx = whisper_init_from_file_with_params_no_state(path_model, params);
     if (!ctx) {
         return nullptr;
     }
@@ -3146,8 +3160,8 @@ struct whisper_context * whisper_init_from_file(const char * path_model) {
     return ctx;
 }
 
-struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
-    whisper_context * ctx = whisper_init_from_buffer_no_state(buffer, buffer_size);
+struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params) {
+    whisper_context * ctx = whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, params);
     if (!ctx) {
         return nullptr;
     }
@@ -3161,8 +3175,8 @@ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_s
     return ctx;
 }
 
-struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
-    whisper_context * ctx = whisper_init_no_state(loader);
+struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params) {
+    whisper_context * ctx = whisper_init_with_params_no_state(loader, params);
     if (!ctx) {
         return nullptr;
     }
@@ -3176,6 +3190,30 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
     return ctx;
 }
 
+struct whisper_context * whisper_init_from_file(const char * path_model) {
+    return whisper_init_from_file_with_params(path_model, whisper_context_default_params());
+}
+
+struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
+    return whisper_init_from_buffer_with_params(buffer, buffer_size, whisper_context_default_params());
+}
+
+struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
+    return whisper_init_with_params(loader, whisper_context_default_params());
+}
+
+struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
+    return whisper_init_from_file_with_params_no_state(path_model, whisper_context_default_params());
+}
+
+struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
+    return whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, whisper_context_default_params());
+}
+
+struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) {
+    return whisper_init_with_params_no_state(loader, whisper_context_default_params());
+}
+
 void whisper_free_state(struct whisper_state * state)
 {
     if (state) {
@@ -3230,6 +3268,12 @@ void whisper_free(struct whisper_context * ctx) {
     }
 }
 
+void whisper_free_context_params(struct whisper_context_params * params) {
+    if (params) {
+        delete params;
+    }
+}
+
 void whisper_free_params(struct whisper_full_params * params) {
     if (params) {
         delete params;
@@ -3698,6 +3742,14 @@ const char * whisper_print_system_info(void) {
 
 ////////////////////////////////////////////////////////////////////////////
 
+struct whisper_context_params * whisper_context_default_params_by_ref() {
+    struct whisper_context_params params = whisper_context_default_params();
+
+    struct whisper_context_params* result = new whisper_context_params();
+    *result = params;
+    return result;
+}
+
 struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) {
     struct whisper_full_params params = whisper_full_default_params(strategy);
 
@@ -4507,17 +4559,19 @@ int whisper_full_with_state(
 
             // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0
 #ifdef GGML_USE_METAL
+            if (state->ctx_metal) {
 #define WHISPER_METAL_CHECK_BUF(result)              \
-            if (!(result)) {                                 \
-                log("%s: failed to add metal buffer\n", __func__); \
-                return 0;                              \
-            }
+                if (!(result)) {                                 \
+                    log("%s: failed to add metal buffer\n", __func__); \
+                    return 0;                              \
+                }
 
-            const std::string kv_name = "kv_self_" + std::to_string(j);
-            auto & kv_self = decoder.kv_self;
+                const std::string kv_name = "kv_self_" + std::to_string(j);
+                auto & kv_self = decoder.kv_self;
 
-            WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
+                WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
 #undef WHISPER_METAL_CHECK_BUF
+            }
 #endif
         }
     }
index c3118c9c99b7ef0d6b217f4f6ff8ceae1d0a2b8f..300fc4bac375e3eea94753d989eeed7c0e293a7e 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -5,6 +5,14 @@
 #include <stdint.h>
 #include <stdbool.h>
 
+#ifdef __GNUC__
+#    define WHISPER_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
+#elif defined(_MSC_VER)
+#    define WHISPER_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
+#else
+#    define WHISPER_DEPRECATED(func, hint) func
+#endif
+
 #ifdef WHISPER_SHARED
 #    ifdef _WIN32
 #        ifdef WHISPER_BUILD
@@ -71,6 +79,10 @@ extern "C" {
 
     typedef int whisper_token;
 
+    struct whisper_context_params {
+        bool  use_gpu;
+    };
+
     typedef struct whisper_token_data {
         whisper_token id;  // token id
         whisper_token tid; // forced timestamp token id
@@ -99,15 +111,40 @@ extern "C" {
     // Various functions for loading a ggml whisper model.
     // Allocate (almost) all memory needed for the model.
     // Return NULL on failure
-    WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model);
-    WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
-    WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
+    WHISPER_API struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params);
 
     // These are the same as the above, but the internal state of the context is not allocated automatically
     // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523)
-    WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model);
-    WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size);
-    WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader);
+    WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params);
+
+    WHISPER_DEPRECATED(
+        WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model),
+        "use whisper_init_from_file_with_params instead"
+    );
+    WHISPER_DEPRECATED(
+        WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size),
+        "use whisper_init_from_buffer_with_params instead"
+    );
+    WHISPER_DEPRECATED(
+        WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader),
+        "use whisper_init_with_params instead"
+    );
+    WHISPER_DEPRECATED(
+        WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model),
+        "use whisper_init_from_file_with_params_no_state instead"
+    );
+    WHISPER_DEPRECATED(
+        WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size),
+        "use whisper_init_from_buffer_with_params_no_state instead"
+    );
+    WHISPER_DEPRECATED(
+        WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader),
+        "use whisper_init_with_params_no_state instead"
+    );
 
     WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx);
 
@@ -132,6 +169,7 @@ extern "C" {
     WHISPER_API void whisper_free      (struct whisper_context * ctx);
     WHISPER_API void whisper_free_state(struct whisper_state * state);
     WHISPER_API void whisper_free_params(struct whisper_full_params * params);
+    WHISPER_API void whisper_free_context_params(struct whisper_context_params * params);
 
     // Convert RAW PCM audio to log mel spectrogram.
     // The resulting spectrogram is stored inside the default state of the provided whisper context.
@@ -442,7 +480,9 @@ extern "C" {
         void * logits_filter_callback_user_data;
     };
 
-    // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_params()
+    // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()
+    WHISPER_API struct whisper_context_params * whisper_context_default_params_by_ref();
+    WHISPER_API struct whisper_context_params whisper_context_default_params(void);
     WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy);
     WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);