]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Feature/java bindings2 (#944)
authorNicholas Albion <redacted>
Sun, 28 May 2023 23:38:58 +0000 (09:38 +1000)
committerGitHub <redacted>
Sun, 28 May 2023 23:38:58 +0000 (09:38 +1000)
* Java needs to call `whisper_full_default_params_by_ref()`, returning struct by val does not seem to work.
* added convenience methods to WhisperFullParams
* Remove unused WhisperJavaParams

21 files changed:
.github/workflows/build.yml
bindings/java/CMakeLists.txt [deleted file]
bindings/java/README.md
bindings/java/build.gradle
bindings/java/src/main/cpp/whisper_java.cpp [deleted file]
bindings/java/src/main/cpp/whisper_java.h [deleted file]
bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java
bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java
bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperJavaJnaLibrary.java [deleted file]
bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperEncoderBeginCallback.java
bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperLogitsFilterCallback.java
bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperNewSegmentCallback.java
bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperProgressCallback.java
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/BeamSearchParams.java [new file with mode: 0644]
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/CBool.java [new file with mode: 0644]
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/GreedyParams.java [new file with mode: 0644]
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java
bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperJavaParams.java [deleted file]
bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java
whisper.cpp
whisper.h

index 657570077302bfdd0a5909b57f83c567d7a9c861..2e25ef62eb2d8a99387da7bae3c397d9ce6a50d6 100644 (file)
@@ -125,8 +125,10 @@ jobs:
         include:
           - arch: Win32
             s2arc: x86
+            jnaPath: win32-x86
           - arch: x64
             s2arc: x64
+            jnaPath: win32-x86-64
           - sdl2: ON
             s2ver: 2.26.0
 
@@ -159,6 +161,12 @@ jobs:
         if: matrix.sdl2 == 'ON'
         run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
 
+      - name: Upload dll
+        uses: actions/upload-artifact@v3
+        with:
+          name: ${{ matrix.jnaPath }}_whisper.dll
+          path: build/bin/${{ matrix.build }}/whisper.dll
+
       - name: Upload binaries
         if: matrix.sdl2 == 'ON'
         uses: actions/upload-artifact@v1
@@ -363,3 +371,42 @@ jobs:
         run: |
           cd examples/whisper.android
           ./gradlew assembleRelease --no-daemon
+
+  java:
+    needs: [ 'windows' ]
+    runs-on: windows-latest
+    steps:
+      - uses: actions/checkout@v1
+
+      - name: Install Java
+        uses: actions/setup-java@v1
+        with:
+          java-version: 17
+
+      - name: Download Windows lib
+        uses: actions/download-artifact@v3
+        with:
+          name: win32-x86-64_whisper.dll
+          path: bindings/java/build/generated/resources/main/win32-x86-64
+
+      - name: Build
+        run: |
+          models\download-ggml-model.cmd tiny.en
+          cd bindings/java
+          chmod +x ./gradlew
+          ./gradlew build
+
+      - name: Upload jar
+        uses: actions/upload-artifact@v3
+        with:
+          name: whispercpp.jar
+          path: bindings/java/build/libs/whispercpp-*.jar
+
+#      - name: Publish package
+#        if: ${{ github.ref == 'refs/heads/master' }}
+#        uses: gradle/gradle-build-action@v2
+#        with:
+#          arguments: publish
+#        env:
+#          MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }}
+#          MAVEN_PASSWORD: ${{ secrets.OSSRH_TOKEN }}
diff --git a/bindings/java/CMakeLists.txt b/bindings/java/CMakeLists.txt
deleted file mode 100644 (file)
index 7e47bb3..0000000
+++ /dev/null
@@ -1,50 +0,0 @@
-cmake_minimum_required(VERSION 3.10)\r
-\r
-project(whisper_java VERSION 1.4.2)\r
-\r
-# Set the target name and source file/s\r
-set(TARGET_NAME whisper_java)\r
-set(SOURCES src/main/cpp/whisper_java.cpp)\r
-\r
-# include <whisper.h>\r
-include_directories(../../)\r
-\r
-# Set the output directory for the DLL/shared library based on the platform as required by JNA\r
-if(WIN32)\r
-    set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/win32-x86-64)\r
-elseif(UNIX AND NOT APPLE)\r
-    set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/linux-x86-64)\r
-elseif(APPLE)\r
-    set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/macos-x86-64)\r
-endif()\r
-\r
-set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OUTPUT_DIR})\r
-set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OUTPUT_DIR})\r
-\r
-# Create the whisper_java library\r
-add_library(${TARGET_NAME} SHARED ${SOURCES})\r
-\r
-# Link against ../../build/Release/whisper.dll (or so/dynlib)\r
-target_link_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/../../../build/${CMAKE_BUILD_TYPE})\r
-target_link_libraries(${TARGET_NAME} PRIVATE whisper)\r
-\r
-# Set the appropriate compiler flags for Windows, Linux, and macOS\r
-if(WIN32)\r
-    target_compile_options(${TARGET_NAME} PRIVATE /W4 /D_CRT_SECURE_NO_WARNINGS)\r
-elseif(UNIX AND NOT APPLE)\r
-    target_compile_options(${TARGET_NAME} PRIVATE -Wall -Wextra)\r
-elseif(APPLE)\r
-    target_compile_options(${TARGET_NAME} PRIVATE -Wall -Wextra)\r
-endif()\r
-\r
-target_compile_definitions(${TARGET_NAME} PRIVATE WHISPER_SHARED)\r
-# add_definitions(-DWHISPER_SHARED)\r
-\r
-# Force CMake to save the libs to build/generated/resources/main/${os}-${arch} as required by JNA\r
-foreach(OUTPUTCONFIG ${CMAKE_CONFIGURATION_TYPES})\r
-    string(TOUPPER ${OUTPUTCONFIG} OUTPUTCONFIG)\r
-    set_target_properties(${TARGET_NAME} PROPERTIES\r
-                          RUNTIME_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR}\r
-                          LIBRARY_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR}\r
-                          ARCHIVE_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR})\r
-endforeach(OUTPUTCONFIG CMAKE_CONFIGURATION_TYPES)\r
index 429287c0e37494946ea7ee8cc53f847ce451f58e..24c461ea63852e5c44569347310fb0224f3a7023 100644 (file)
@@ -6,11 +6,7 @@ This package provides Java JNI bindings for whisper.cpp. They have been tested o
   * Ubuntu on x86_64
   * Windows on x86_64
 
-The "low level" bindings are in `WhisperCppJnaLibrary` and `WhisperJavaJnaLibrary` which caches `whisper_full_params` and `whisper_context` in `whisper_java.cpp`. 
-
-There are a lot of classes in the `callbacks`, `ggml`, `model` and `params` directories but most of them have not been tested. 
-
-The most simple usage is as follows:
+The "low level" bindings are in `WhisperCppJnaLibrary`. The most simple usage is as follows:
 
 ```java
 import io.github.ggerganov.whispercpp.WhisperCpp;
@@ -48,12 +44,6 @@ In order to build, you need to have the JDK 8 or higher installed. Run the tests
 git clone https://github.com/ggerganov/whisper.cpp.git
 cd whisper.cpp/bindings/java
 
-mkdir build
-pushd build
-cmake ..
-cmake --build .
-popd
-
 ./gradlew build
 ```
 
index 4a9b02f15858aeb1497d38ee02eae1f8e77787a8..3028f6f6a586e39be51f8b4e48dab3550008e10b 100644 (file)
@@ -22,6 +22,12 @@ sourceSets {
     }\r
 }\r
 \r
+tasks.register('copyLibwhisperDynlib', Copy) {\r
+    from '../../build'\r
+    include 'libwhisper.dynlib'\r
+    into 'build/generated/resources/main/darwin'\r
+}\r
+\r
 tasks.register('copyLibwhisperSo', Copy) {\r
     from '../../build'\r
     include 'libwhisper.so'\r
@@ -34,7 +40,9 @@ tasks.register('copyWhisperDll', Copy) {
     into 'build/generated/resources/main/windows-x86-64'\r
 }\r
 \r
-tasks.build.dependsOn copyLibwhisperSo, copyWhisperDll\r
+tasks.register('copyLibs') {\r
+    dependsOn copyLibwhisperDynlib, copyLibwhisperSo, copyWhisperDll\r
+}\r
 \r
 test {\r
     systemProperty 'jna.library.path', project.file('build/generated/resources/main').absolutePath\r
diff --git a/bindings/java/src/main/cpp/whisper_java.cpp b/bindings/java/src/main/cpp/whisper_java.cpp
deleted file mode 100644 (file)
index 9e06aa0..0000000
+++ /dev/null
@@ -1,33 +0,0 @@
-#include <stdio.h>\r
-#include "whisper_java.h"\r
-\r
-struct whisper_full_params default_params;\r
-struct whisper_context * whisper_ctx = nullptr;\r
-\r
-struct void whisper_java_default_params(enum whisper_sampling_strategy strategy) {\r
-    default_params = whisper_full_default_params(strategy);\r
-\r
-//    struct whisper_java_params result = {};\r
-//    return result;\r
-    return;\r
-}\r
-\r
-void whisper_java_init_from_file(const char * path_model) {\r
-    whisper_ctx = whisper_init_from_file(path_model);\r
-    if (0 == default_params.n_threads) {\r
-        whisper_java_default_params(WHISPER_SAMPLING_GREEDY);\r
-    }\r
-}\r
-\r
-/** Delegates to whisper_full, but without having to pass `whisper_full_params` */\r
-int whisper_java_full(\r
-          struct whisper_context * ctx,\r
-//      struct whisper_java_params   params,\r
-                     const float * samples,\r
-                             int   n_samples) {\r
-    return whisper_full(ctx, default_params, samples, n_samples);\r
-}\r
-\r
-void whisper_java_free() {\r
-//    free(default_params);\r
-}\r
diff --git a/bindings/java/src/main/cpp/whisper_java.h b/bindings/java/src/main/cpp/whisper_java.h
deleted file mode 100644 (file)
index d64866b..0000000
+++ /dev/null
@@ -1,24 +0,0 @@
-#define WHISPER_BUILD\r
-#include <whisper.h>\r
-\r
-#ifdef __cplusplus\r
-extern "C" {\r
-#endif\r
-\r
-struct whisper_java_params {\r
-};\r
-\r
-WHISPER_API void whisper_java_default_params(enum whisper_sampling_strategy strategy);\r
-\r
-WHISPER_API void whisper_java_init_from_file(const char * path_model);\r
-\r
-WHISPER_API int whisper_java_full(\r
-          struct whisper_context * ctx,\r
-//      struct whisper_java_params   params,\r
-                     const float * samples,\r
-                             int   n_samples);\r
-\r
-\r
-#ifdef __cplusplus\r
-}\r
-#endif\r
index f014407086be99c806533ab23f49cbc8e0edd326..9bc1a8601a923eb100b1fcd0f36cbc60a932f668 100644 (file)
@@ -1,7 +1,8 @@
 package io.github.ggerganov.whispercpp;\r
 \r
+import com.sun.jna.Native;\r
 import com.sun.jna.Pointer;\r
-import io.github.ggerganov.whispercpp.params.WhisperJavaParams;\r
+import io.github.ggerganov.whispercpp.params.WhisperFullParams;\r
 import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;\r
 \r
 import java.io.File;\r
@@ -13,8 +14,9 @@ import java.io.IOException;
  */\r
 public class WhisperCpp implements AutoCloseable {\r
     private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;\r
-    private WhisperJavaJnaLibrary javaLib = WhisperJavaJnaLibrary.instance;\r
     private Pointer ctx = null;\r
+    private Pointer greedyPointer = null;\r
+    private Pointer beamPointer = null;\r
 \r
     public File modelDir() {\r
         String modelDirPath = System.getenv("XDG_CACHE_HOME");\r
@@ -27,9 +29,8 @@ public class WhisperCpp implements AutoCloseable {
 \r
     /**\r
      * @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")\r
-     * @return a Pointer to the WhisperContext\r
      */\r
-    void initContext(String modelPath) throws FileNotFoundException {\r
+    public void initContext(String modelPath) throws FileNotFoundException {\r
         if (ctx != null) {\r
             lib.whisper_free(ctx);\r
         }\r
@@ -42,7 +43,6 @@ public class WhisperCpp implements AutoCloseable {
             modelPath = new File(modelDir(), modelPath).getAbsolutePath();\r
         }\r
 \r
-        javaLib.whisper_java_init_from_file(modelPath);\r
         ctx = lib.whisper_init_from_file(modelPath);\r
 \r
         if (ctx == null) {\r
@@ -51,22 +51,38 @@ public class WhisperCpp implements AutoCloseable {
     }\r
 \r
     /**\r
-     * Initialises `whisper_full_params` internally in whisper_java.cpp so JNA doesn't have to map everything.\r
-     * `whisper_java_init_from_file()` calls `whisper_java_default_params(WHISPER_SAMPLING_GREEDY)` for convenience.\r
+     * Provides default params which can be used with `whisper_full()` etc.\r
+     * Because this function allocates memory for the params, the caller must call either:\r
+     * - call `whisper_free_params()`\r
+     * - `Native.free(Pointer.nativeValue(pointer));`\r
+     *\r
+     * @param strategy - GREEDY\r
      */\r
-    public void getDefaultJavaParams(WhisperSamplingStrategy strategy) {\r
-        javaLib.whisper_java_default_params(strategy.ordinal());\r
-//        return lib.whisper_full_default_params(strategy.value)\r
-    }\r
+    public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy) {\r
+        Pointer pointer;\r
 \r
-// whisper_full_default_params was too hard to integrate with, so for now we use javaLib.whisper_java_default_params\r
-//    fun getDefaultParams(strategy: WhisperSamplingStrategy): WhisperFullParams {\r
-//        return lib.whisper_full_default_params(strategy.value)\r
-//    }\r
+        // whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.\r
+        if (strategy == WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY) {\r
+            if (greedyPointer == null) {\r
+                greedyPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());\r
+            }\r
+            pointer = greedyPointer;\r
+        } else {\r
+            if (beamPointer == null) {\r
+                beamPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());\r
+            }\r
+            pointer = beamPointer;\r
+        }\r
+\r
+        WhisperFullParams params = new WhisperFullParams(pointer);\r
+        params.read();\r
+        return params;\r
+    }\r
 \r
     @Override\r
     public void close() {\r
         freeContext();\r
+        freeParams();\r
         System.out.println("Whisper closed");\r
     }\r
 \r
@@ -76,17 +92,28 @@ public class WhisperCpp implements AutoCloseable {
         }\r
     }\r
 \r
+    private void freeParams() {\r
+        if (greedyPointer != null) {\r
+            Native.free(Pointer.nativeValue(greedyPointer));\r
+            greedyPointer = null;\r
+        }\r
+        if (beamPointer != null) {\r
+            Native.free(Pointer.nativeValue(beamPointer));\r
+            beamPointer = null;\r
+        }\r
+    }\r
+\r
     /**\r
      * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.\r
      * Not thread safe for same context\r
      * Uses the specified decoding strategy to obtain the text.\r
      */\r
-    public String fullTranscribe(/*WhisperJavaParams whisperParams,*/ float[] audioData) throws IOException {\r
+    public String fullTranscribe(WhisperFullParams whisperParams, float[] audioData) throws IOException {\r
         if (ctx == null) {\r
             throw new IllegalStateException("Model not initialised");\r
         }\r
 \r
-        if (javaLib.whisper_java_full(ctx, /*whisperParams,*/ audioData, audioData.length) != 0) {\r
+        if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {\r
             throw new IOException("Failed to process audio");\r
         }\r
 \r
index 66025657c6ad1648f3b18d56d38375ace51a2561..c1fb4f8e3b0f4e57588723c21762e9931d633f43 100644 (file)
@@ -231,10 +231,21 @@ public interface WhisperCppJnaLibrary extends Library {
     void whisper_print_timings(Pointer ctx);\r
     void whisper_reset_timings(Pointer ctx);\r
 \r
+    // Note: Even if `whisper_full_params is stripped back to just 4 ints, JNA throws "Invalid memory access"\r
+    //       when `whisper_full_default_params()` tries to return a struct.\r
+    // WhisperFullParams whisper_full_default_params(int strategy);\r
+\r
     /**\r
+     * Provides default params which can be used with `whisper_full()` etc.\r
+     * Because this function allocates memory for the params, the caller must call either:\r
+     * - call `whisper_free_params()`\r
+     * - `Native.free(Pointer.nativeValue(pointer));`\r
+     *\r
      * @param strategy - WhisperSamplingStrategy.value\r
      */\r
-    WhisperFullParams whisper_full_default_params(int strategy);\r
+    Pointer whisper_full_default_params_by_ref(int strategy);\r
+\r
+    void whisper_free_params(Pointer params);\r
 \r
     /**\r
      * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text\r
diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperJavaJnaLibrary.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperJavaJnaLibrary.java
deleted file mode 100644 (file)
index 74f8459..0000000
+++ /dev/null
@@ -1,23 +0,0 @@
-package io.github.ggerganov.whispercpp;\r
-\r
-import com.sun.jna.Library;\r
-import com.sun.jna.Native;\r
-import com.sun.jna.Pointer;\r
-import io.github.ggerganov.whispercpp.params.WhisperJavaParams;\r
-\r
-interface WhisperJavaJnaLibrary extends Library {\r
-    WhisperJavaJnaLibrary instance = Native.load("whisper_java", WhisperJavaJnaLibrary.class);\r
-\r
-    void whisper_java_default_params(int strategy);\r
-\r
-    void whisper_java_free();\r
-\r
-    void whisper_java_init_from_file(String modelPath);\r
-\r
-    /**\r
-     * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.\r
-     * Not thread safe for same context\r
-     * Uses the specified decoding strategy to obtain the text.\r
-     */\r
-    int whisper_java_full(Pointer ctx, /*WhisperJavaParams params, */float[] samples, int nSamples);\r
-}\r
index b5e9797a9230f48b0f1a791b08c9c25bdc436d2a..3d228cbeb39d010d7acd32e9550ac5d9e8745946 100644 (file)
@@ -20,5 +20,5 @@ public interface WhisperEncoderBeginCallback extends Callback {
      * @param user_data  User data.\r
      * @return True if the computation should proceed, false otherwise.\r
      */\r
-    boolean callback(WhisperContext ctx, WhisperState state, Pointer user_data);\r
+    boolean callback(Pointer ctx, Pointer state, Pointer user_data);\r
 }\r
index 5377b4ebb94a8b31a13df14b2af2f8dbd14f191a..9777c76353e39c16683f4a21d3462d94437e689a 100644 (file)
@@ -1,12 +1,9 @@
 package io.github.ggerganov.whispercpp.callbacks;\r
 \r
+import com.sun.jna.Callback;\r
 import com.sun.jna.Pointer;\r
-import io.github.ggerganov.whispercpp.WhisperContext;\r
-import io.github.ggerganov.whispercpp.model.WhisperState;\r
 import io.github.ggerganov.whispercpp.model.WhisperTokenData;\r
 \r
-import javax.security.auth.callback.Callback;\r
-\r
 /**\r
  * Callback to filter logits.\r
  * Can be used to modify the logits before sampling.\r
@@ -24,5 +21,5 @@ public interface WhisperLogitsFilterCallback extends Callback {
      * @param logits     The array of logits.\r
      * @param user_data  User data.\r
      */\r
-    void callback(WhisperContext ctx, WhisperState state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data);\r
+    void callback(Pointer ctx, Pointer state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data);\r
 }\r
index 95ca346721cf1f5b026755b4ef962cd39e4f43fa..27b1c6152fa97fc66ff161d53bca56d773516a98 100644 (file)
@@ -20,5 +20,5 @@ public interface WhisperNewSegmentCallback extends Callback {
      * @param n_new      The number of newly generated text segments.\r
      * @param user_data  User data.\r
      */\r
-    void callback(WhisperContext ctx, WhisperState state, int n_new, Pointer user_data);\r
+    void callback(Pointer ctx, Pointer state, int n_new, Pointer user_data);\r
 }\r
index 88662152263af9c9427163a7301ce4d7772e70bc..c64f0ab932e27d8e5a2300504558fa78fc53056b 100644 (file)
@@ -1,11 +1,10 @@
 package io.github.ggerganov.whispercpp.callbacks;\r
 \r
+import com.sun.jna.Callback;\r
 import com.sun.jna.Pointer;\r
 import io.github.ggerganov.whispercpp.WhisperContext;\r
 import io.github.ggerganov.whispercpp.model.WhisperState;\r
 \r
-import javax.security.auth.callback.Callback;\r
-\r
 /**\r
  * Callback for progress updates.\r
  */\r
@@ -19,5 +18,5 @@ public interface WhisperProgressCallback extends Callback {
      * @param progress   The progress value.\r
      * @param user_data  User data.\r
      */\r
-    void callback(WhisperContext ctx, WhisperState state, int progress, Pointer user_data);\r
+    void callback(Pointer ctx, Pointer state, int progress, Pointer user_data);\r
 }\r
diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/BeamSearchParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/BeamSearchParams.java
new file mode 100644 (file)
index 0000000..fd621dd
--- /dev/null
@@ -0,0 +1,19 @@
+package io.github.ggerganov.whispercpp.params;\r
+\r
+import com.sun.jna.Structure;\r
+\r
+import java.util.Arrays;\r
+import java.util.List;\r
+\r
+public class BeamSearchParams extends Structure {\r
+    /** ref: <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265">...</a> */\r
+    public int beam_size;\r
+\r
+    /** ref: <a href="https://arxiv.org/pdf/2204.05424.pdf">...</a> */\r
+    public float patience;\r
+\r
+    @Override\r
+    protected List<String> getFieldOrder() {\r
+        return Arrays.asList("beam_size", "patience");\r
+    }\r
+}\r
diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/CBool.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/CBool.java
new file mode 100644 (file)
index 0000000..1f6814b
--- /dev/null
@@ -0,0 +1,30 @@
+package io.github.ggerganov.whispercpp.params;\r
+\r
+import com.sun.jna.IntegerType;\r
+\r
+import java.util.function.BooleanSupplier;\r
+\r
+public class CBool extends IntegerType implements BooleanSupplier {\r
+    public static final int SIZE = 1;\r
+    public static final CBool FALSE = new CBool(0);\r
+    public static final CBool TRUE = new CBool(1);\r
+\r
+\r
+    public CBool() {\r
+        this(0);\r
+    }\r
+\r
+    public CBool(long value) {\r
+        super(SIZE, value, true);\r
+    }\r
+\r
+    @Override\r
+    public boolean getAsBoolean() {\r
+        return intValue() == 1;\r
+    }\r
+\r
+    @Override\r
+    public String toString() {\r
+        return intValue() == 1 ? "true" : "false";\r
+    }\r
+}\r
diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/GreedyParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/GreedyParams.java
new file mode 100644 (file)
index 0000000..e3b0138
--- /dev/null
@@ -0,0 +1,16 @@
+package io.github.ggerganov.whispercpp.params;\r
+\r
+import com.sun.jna.Structure;\r
+\r
+import java.util.Collections;\r
+import java.util.List;\r
+\r
+public class GreedyParams extends Structure {\r
+    /** <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264">...</a> */\r
+    public int best_of;\r
+\r
+    @Override\r
+    protected List<String> getFieldOrder() {\r
+        return Collections.singletonList("best_of");\r
+    }\r
+}\r
index ea0bccf728c1a00b05bcac5da11693ec7729bbca..07e68948ef82f6edc4924782a6275bbb2b065890 100644 (file)
@@ -1,13 +1,14 @@
 package io.github.ggerganov.whispercpp.params;\r
 \r
-import com.sun.jna.Callback;\r
-import com.sun.jna.Pointer;\r
-import com.sun.jna.Structure;\r
+import com.sun.jna.*;\r
 import io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback;\r
 import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback;\r
 import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback;\r
 import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;\r
 \r
+import java.util.Arrays;\r
+import java.util.List;\r
+\r
 /**\r
  * Parameters for the whisper_full() function.\r
  * If you change the order or add new parameters, make sure to update the default values in whisper.cpp:\r
@@ -15,62 +16,123 @@ import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
  */\r
 public class WhisperFullParams extends Structure {\r
 \r
+    public WhisperFullParams(Pointer p) {\r
+        super(p);\r
+//        super(p, ALIGN_MSVC);\r
+//        super(p, ALIGN_GNUC);\r
+    }\r
+\r
     /** Sampling strategy for whisper_full() function. */\r
     public int strategy;\r
 \r
-    /** Number of threads. */\r
+    /** Number of threads. (default = 4) */\r
     public int n_threads;\r
 \r
-    /** Maximum tokens to use from past text as a prompt for the decoder. */\r
+    /** Maximum tokens to use from past text as a prompt for the decoder. (default = 16384) */\r
     public int n_max_text_ctx;\r
 \r
-    /** Start offset in milliseconds. */\r
+    /** Start offset in milliseconds. (default = 0) */\r
     public int offset_ms;\r
 \r
-    /** Audio duration to process in milliseconds. */\r
+    /** Audio duration to process in milliseconds. (default = 0) */\r
     public int duration_ms;\r
 \r
-    /** Translate flag. */\r
-    public boolean translate;\r
+    /** Translate flag. (default = false) */\r
+    public CBool translate;\r
+\r
+    /** The compliment of translateMode() */\r
+    public void transcribeMode() {\r
+        translate = CBool.FALSE;\r
+    }\r
 \r
-    /** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. */\r
-    public boolean no_context;\r
+    /** The compliment of transcribeMode() */\r
+    public void translateMode() {\r
+        translate = CBool.TRUE;\r
+    }\r
 \r
-    /** Flag to force single segment output (useful for streaming). */\r
-    public boolean single_segment;\r
+    /** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */\r
+    public CBool no_context;\r
 \r
-    /** Flag to print special tokens (e.g., &lt;SOT>, &lt;EOT>, &lt;BEG>, etc.). */\r
-    public boolean print_special;\r
+    /** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */\r
+    public void enableContext(boolean enable) {\r
+        no_context = enable ? CBool.FALSE : CBool.TRUE;\r
+    }\r
 \r
-    /** Flag to print progress information. */\r
-    public boolean print_progress;\r
+    /** Flag to force single segment output (useful for streaming). (default = false) */\r
+    public CBool single_segment;\r
 \r
-    /** Flag to print results from within whisper.cpp (avoid it, use callback instead). */\r
-    public boolean print_realtime;\r
+    /** Flag to force single segment output (useful for streaming). (default = false) */\r
+    public void singleSegment(boolean single) {\r
+        single_segment = single ? CBool.TRUE : CBool.FALSE;\r
+    }\r
 \r
-    /** Flag to print timestamps for each text segment when printing realtime. */\r
-    public boolean print_timestamps;\r
+    /** Flag to print special tokens (e.g., &lt;SOT>, &lt;EOT>, &lt;BEG>, etc.). (default = false) */\r
+    public CBool print_special;\r
 \r
-    /** [EXPERIMENTAL] Flag to enable token-level timestamps. */\r
-    public boolean token_timestamps;\r
+    /** Flag to print special tokens (e.g., &lt;SOT>, &lt;EOT>, &lt;BEG>, etc.). (default = false) */\r
+    public void printSpecial(boolean enable) {\r
+        print_special = enable ? CBool.TRUE : CBool.FALSE;\r
+    }\r
+\r
+    /** Flag to print progress information. (default = true) */\r
+    public CBool print_progress;\r
+\r
+    /** Flag to print progress information. (default = true) */\r
+    public void printProgress(boolean enable) {\r
+        print_progress = enable ? CBool.TRUE : CBool.FALSE;\r
+    }\r
+\r
+    /** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */\r
+    public CBool print_realtime;\r
+\r
+    /** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */\r
+    public void printRealtime(boolean enable) {\r
+        print_realtime = enable ? CBool.TRUE : CBool.FALSE;\r
+    }\r
 \r
-    /** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). */\r
+    /** Flag to print timestamps for each text segment when printing realtime. (default = true) */\r
+    public CBool print_timestamps;\r
+\r
+    /** Flag to print timestamps for each text segment when printing realtime. (default = true) */\r
+    public void printTimestamps(boolean enable) {\r
+        print_timestamps = enable ? CBool.TRUE : CBool.FALSE;\r
+    }\r
+\r
+    /** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */\r
+    public CBool token_timestamps;\r
+\r
+    /** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */\r
+    public void tokenTimestamps(boolean enable) {\r
+        token_timestamps = enable ? CBool.TRUE : CBool.FALSE;\r
+    }\r
+\r
+    /** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). (default = 0.01) */\r
     public float thold_pt;\r
 \r
     /** [EXPERIMENTAL] Timestamp token sum probability threshold (~0.01). */\r
     public float thold_ptsum;\r
 \r
-    /** Maximum segment length in characters. */\r
+    /** Maximum segment length in characters. (default = 0) */\r
     public int max_len;\r
 \r
-    /** Flag to split on word rather than on token (when used with max_len). */\r
-    public boolean split_on_word;\r
+    /** Flag to split on word rather than on token (when used with max_len). (default = false) */\r
+    public CBool split_on_word;\r
+\r
+    /** Flag to split on word rather than on token (when used with max_len). (default = false) */\r
+    public void splitOnWord(boolean enable) {\r
+        split_on_word = enable ? CBool.TRUE : CBool.FALSE;\r
+    }\r
 \r
-    /** Maximum tokens per segment (0 = no limit). */\r
+    /** Maximum tokens per segment (0, default = no limit) */\r
     public int max_tokens;\r
 \r
-    /** Flag to speed up the audio by 2x using Phase Vocoder. */\r
-    public boolean speed_up;\r
+    /** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */\r
+    public CBool speed_up;\r
+\r
+    /** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */\r
+    public void speedUp(boolean enable) {\r
+        speed_up = enable ? CBool.TRUE : CBool.FALSE;\r
+    }\r
 \r
     /** Overwrite the audio context size (0 = use default). */\r
     public int audio_ctx;\r
@@ -79,9 +141,15 @@ public class WhisperFullParams extends Structure {
      * These are prepended to any existing text context from a previous call. */\r
     public String initial_prompt;\r
 \r
-    /** Prompt tokens. */\r
+    /** Prompt tokens. (int*) */\r
     public Pointer prompt_tokens;\r
 \r
+    public void setPromptTokens(int[] tokens) {\r
+        Memory mem = new Memory(tokens.length * 4L);\r
+        mem.write(0, tokens, 0, tokens.length);\r
+        prompt_tokens = mem;\r
+    }\r
+\r
     /** Number of prompt tokens. */\r
     public int prompt_n_tokens;\r
 \r
@@ -90,15 +158,29 @@ public class WhisperFullParams extends Structure {
     public String language;\r
 \r
     /** Flag to indicate whether to detect language automatically. */\r
-    public boolean detect_language;\r
+    public CBool detect_language;\r
+\r
+    /** Flag to indicate whether to detect language automatically. */\r
+    public void detectLanguage(boolean enable) {\r
+        detect_language = enable ? CBool.TRUE : CBool.FALSE;\r
+    }\r
 \r
-    /** Common decoding parameters. */\r
+    // Common decoding parameters.\r
 \r
     /** Flag to suppress blank tokens. */\r
-    public boolean suppress_blank;\r
+    public CBool suppress_blank;\r
+\r
+    public void suppressBlanks(boolean enable) {\r
+        suppress_blank = enable ? CBool.TRUE : CBool.FALSE;\r
+    }\r
+\r
+    /** Flag to suppress non-speech tokens. */\r
+    public CBool suppress_non_speech_tokens;\r
 \r
     /** Flag to suppress non-speech tokens. */\r
-    public boolean suppress_non_speech_tokens;\r
+    public void suppressNonSpeechTokens(boolean enable) {\r
+        suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE;\r
+    }\r
 \r
     /** Initial decoding temperature. */\r
     public float temperature;\r
@@ -109,7 +191,7 @@ public class WhisperFullParams extends Structure {
     /** Length penalty. */\r
     public float length_penalty;\r
 \r
-    /** Fallback parameters. */\r
+    // Fallback parameters.\r
 \r
     /** Temperature increment. */\r
     public float temperature_inc;\r
@@ -123,31 +205,41 @@ public class WhisperFullParams extends Structure {
     /** No speech threshold. */\r
     public float no_speech_thold;\r
 \r
-    class GreedyParams extends Structure {\r
-        /** https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264 */\r
-        public int best_of;\r
-    }\r
-\r
     /** Greedy decoding parameters. */\r
     public GreedyParams greedy;\r
 \r
-    class BeamSearchParams extends Structure {\r
-        /** ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265 */\r
-        int beam_size;\r
-\r
-        /** ref: https://arxiv.org/pdf/2204.05424.pdf */\r
-        float patience;\r
-    }\r
-\r
     /**\r
      * Beam search decoding parameters.\r
      */\r
     public BeamSearchParams beam_search;\r
 \r
+    public void setBestOf(int bestOf) {\r
+        if (greedy == null) {\r
+            greedy = new GreedyParams();\r
+        }\r
+        greedy.best_of = bestOf;\r
+    }\r
+\r
+    public void setBeamSize(int beamSize) {\r
+        if (beam_search == null) {\r
+            beam_search = new BeamSearchParams();\r
+        }\r
+        beam_search.beam_size = beamSize;\r
+    }\r
+\r
+    public void setBeamSizeAndPatience(int beamSize, float patience) {\r
+        if (beam_search == null) {\r
+            beam_search = new BeamSearchParams();\r
+        }\r
+        beam_search.beam_size = beamSize;\r
+        beam_search.patience = patience;\r
+    }\r
+\r
     /**\r
      * Callback for every newly generated text segment.\r
+     * WhisperNewSegmentCallback\r
      */\r
-    public WhisperNewSegmentCallback new_segment_callback;\r
+    public Pointer new_segment_callback;\r
 \r
     /**\r
      * User data for the new_segment_callback.\r
@@ -156,8 +248,9 @@ public class WhisperFullParams extends Structure {
 \r
     /**\r
      * Callback on each progress update.\r
+     * WhisperProgressCallback\r
      */\r
-    public WhisperProgressCallback progress_callback;\r
+    public Pointer progress_callback;\r
 \r
     /**\r
      * User data for the progress_callback.\r
@@ -166,8 +259,9 @@ public class WhisperFullParams extends Structure {
 \r
     /**\r
      * Callback each time before the encoder starts.\r
+     * WhisperEncoderBeginCallback\r
      */\r
-    public WhisperEncoderBeginCallback encoder_begin_callback;\r
+    public Pointer encoder_begin_callback;\r
 \r
     /**\r
      * User data for the encoder_begin_callback.\r
@@ -176,12 +270,44 @@ public class WhisperFullParams extends Structure {
 \r
     /**\r
      * Callback by each decoder to filter obtained logits.\r
+     * WhisperLogitsFilterCallback\r
      */\r
-    public WhisperLogitsFilterCallback logits_filter_callback;\r
+    public Pointer logits_filter_callback;\r
 \r
     /**\r
      * User data for the logits_filter_callback.\r
      */\r
     public Pointer logits_filter_callback_user_data;\r
-}\r
 \r
+\r
+    public void setNewSegmentCallback(WhisperNewSegmentCallback callback) {\r
+        new_segment_callback = CallbackReference.getFunctionPointer(callback);\r
+    }\r
+\r
+    public void setProgressCallback(WhisperProgressCallback callback) {\r
+        progress_callback = CallbackReference.getFunctionPointer(callback);\r
+    }\r
+\r
+    public void setEncoderBeginCallbackeginCallbackCallback(WhisperEncoderBeginCallback callback) {\r
+        encoder_begin_callback = CallbackReference.getFunctionPointer(callback);\r
+    }\r
+\r
+    public void setLogitsFilterCallback(WhisperLogitsFilterCallback callback) {\r
+        logits_filter_callback = CallbackReference.getFunctionPointer(callback);\r
+    }\r
+\r
+    @Override\r
+    protected List<String> getFieldOrder() {\r
+        return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",\r
+                "no_context", "single_segment",\r
+                "print_special", "print_progress", "print_realtime", "print_timestamps",  "token_timestamps",\r
+                "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",\r
+                "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",\r
+                "suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",\r
+                "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",\r
+                "new_segment_callback", "new_segment_callback_user_data",\r
+                "progress_callback", "progress_callback_user_data",\r
+                "encoder_begin_callback", "encoder_begin_callback_user_data",\r
+                "logits_filter_callback", "logits_filter_callback_user_data");\r
+    }\r
+}\r
diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperJavaParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperJavaParams.java
deleted file mode 100644 (file)
index 728485c..0000000
+++ /dev/null
@@ -1,7 +0,0 @@
-package io.github.ggerganov.whispercpp.params;\r
-\r
-import com.sun.jna.Structure;\r
-\r
-public class WhisperJavaParams extends Structure {\r
-\r
-}\r
index 98390aa9a317d3e0988930464e84f9f7bfcb5c56..66e18f9a93647a36543fe37cebdaa2a07a47680c 100644 (file)
@@ -2,7 +2,8 @@ package io.github.ggerganov.whispercpp;
 \r
 import static org.junit.jupiter.api.Assertions.*;\r
 \r
-import io.github.ggerganov.whispercpp.params.WhisperJavaParams;\r
+import io.github.ggerganov.whispercpp.params.CBool;\r
+import io.github.ggerganov.whispercpp.params.WhisperFullParams;\r
 import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;\r
 import org.junit.jupiter.api.BeforeAll;\r
 import org.junit.jupiter.api.Test;\r
@@ -19,11 +20,11 @@ class WhisperCppTest {
     static void init() throws FileNotFoundException {\r
         // By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"\r
         // or you can provide the absolute path to the model file.\r
-        String modelName = "base.en";\r
+        String modelName = "../../models/ggml-tiny.en.bin";\r
         try {\r
             whisper.initContext(modelName);\r
-            whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);\r
-//            whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);\r
+//            whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);\r
+//            whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);\r
             modelInitialised = true;\r
         } catch (FileNotFoundException ex) {\r
             System.out.println("Model " + modelName + " not found");\r
@@ -31,11 +32,30 @@ class WhisperCppTest {
     }\r
 \r
     @Test\r
-    void testGetDefaultJavaParams() {\r
+    void testGetDefaultFullParams_BeamSearch() {\r
         // When\r
-        whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);\r
+        WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);\r
 \r
-        // Then if it doesn't throw we've connected to whisper.cpp\r
+        // Then\r
+        assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal(), params.strategy);\r
+        assertNotEquals(0, params.n_threads);\r
+        assertEquals(16384, params.n_max_text_ctx);\r
+        assertFalse(params.translate);\r
+        assertEquals(0.01f, params.thold_pt);\r
+        assertEquals(2, params.beam_search.beam_size);\r
+        assertEquals(-1.0f, params.beam_search.patience);\r
+    }\r
+\r
+    @Test\r
+    void testGetDefaultFullParams_Greedy() {\r
+        // When\r
+        WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);\r
+\r
+        // Then\r
+        assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY.ordinal(), params.strategy);\r
+        assertNotEquals(0, params.n_threads);\r
+        assertEquals(16384, params.n_max_text_ctx);\r
+        assertEquals(2, params.greedy.best_of);\r
     }\r
 \r
     @Test\r
@@ -52,6 +72,13 @@ class WhisperCppTest {
         byte[] b = new byte[audioInputStream.available()];\r
         float[] floats = new float[b.length / 2];\r
 \r
+//        WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);\r
+        WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);\r
+        params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));\r
+        params.print_progress = CBool.FALSE;\r
+//        params.initial_prompt = "and so my fellow Americans um, like";\r
+\r
+\r
         try {\r
             audioInputStream.read(b);\r
 \r
@@ -61,13 +88,13 @@ class WhisperCppTest {
             }\r
 \r
             // When\r
-            String result = whisper.fullTranscribe(/*params,*/ floats);\r
+            String result = whisper.fullTranscribe(params, floats);\r
 \r
             // Then\r
-            System.out.println(result);\r
-            assertEquals("And so my fellow Americans, ask not what your country can do for you, " +\r
+            System.err.println(result);\r
+            assertEquals("And so my fellow Americans ask not what your country can do for you " +\r
                     "ask what you can do for your country.",\r
-                    result);\r
+                    result.replace(",", ""));\r
         } finally {\r
             audioInputStream.close();\r
         }\r
index 6faa3f2f543cdcc3069038b8752e7d4f4069dad0..0cdd4a1d49fe0e9cc0cdd504bd7a739602ea648b 100644 (file)
@@ -2852,6 +2852,12 @@ void whisper_free(struct whisper_context * ctx) {
     }
 }
 
+void whisper_free_params(struct whisper_full_params * params) {
+    if (params) {
+        delete params;
+    }
+}
+
 int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
     if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
         fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
@@ -3285,6 +3291,14 @@ const char * whisper_print_system_info(void) {
 
 ////////////////////////////////////////////////////////////////////////////
 
+struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) {
+    struct whisper_full_params params = whisper_full_default_params(strategy);
+
+    struct whisper_full_params* result = new whisper_full_params();
+    *result = params;
+    return result;
+}
+
 struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
     struct whisper_full_params result = {
         /*.strategy         =*/ strategy,
index 2d5b3eb98579811e86e531d8bbab9b2d75403d9e..e983c7d4fa323f65ac9912e1e53982367cf2a8ea 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -113,6 +113,7 @@ extern "C" {
     // Frees all allocated memory
     WHISPER_API void whisper_free      (struct whisper_context * ctx);
     WHISPER_API void whisper_free_state(struct whisper_state * state);
+    WHISPER_API void whisper_free_params(struct whisper_full_params * params);
 
     // Convert RAW PCM audio to log mel spectrogram.
     // The resulting spectrogram is stored inside the default state of the provided whisper context.
@@ -409,6 +410,8 @@ extern "C" {
         void * logits_filter_callback_user_data;
     };
 
+    // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_params()
+    WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy);
     WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
 
     // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text