* Java needs to call `whisper_full_default_params_by_ref()`, returning struct by val does not seem to work.
* added convenience methods to WhisperFullParams
* Remove unused WhisperJavaParams
include:
- arch: Win32
s2arc: x86
+ jnaPath: win32-x86
- arch: x64
s2arc: x64
+ jnaPath: win32-x86-64
- sdl2: ON
s2ver: 2.26.0
if: matrix.sdl2 == 'ON'
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
+ - name: Upload dll
+ uses: actions/upload-artifact@v3
+ with:
+ name: ${{ matrix.jnaPath }}_whisper.dll
+ path: build/bin/${{ matrix.build }}/whisper.dll
+
- name: Upload binaries
if: matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v1
run: |
cd examples/whisper.android
./gradlew assembleRelease --no-daemon
+
+ java:
+ needs: [ 'windows' ]
+ runs-on: windows-latest
+ steps:
+ - uses: actions/checkout@v1
+
+ - name: Install Java
+ uses: actions/setup-java@v1
+ with:
+ java-version: 17
+
+ - name: Download Windows lib
+ uses: actions/download-artifact@v3
+ with:
+ name: win32-x86-64_whisper.dll
+ path: bindings/java/build/generated/resources/main/win32-x86-64
+
+ - name: Build
+ run: |
+ models\download-ggml-model.cmd tiny.en
+ cd bindings/java
+ chmod +x ./gradlew
+ ./gradlew build
+
+ - name: Upload jar
+ uses: actions/upload-artifact@v3
+ with:
+ name: whispercpp.jar
+ path: bindings/java/build/libs/whispercpp-*.jar
+
+# - name: Publish package
+# if: ${{ github.ref == 'refs/heads/master' }}
+# uses: gradle/gradle-build-action@v2
+# with:
+# arguments: publish
+# env:
+# MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }}
+# MAVEN_PASSWORD: ${{ secrets.OSSRH_TOKEN }}
+++ /dev/null
-cmake_minimum_required(VERSION 3.10)\r
-\r
-project(whisper_java VERSION 1.4.2)\r
-\r
-# Set the target name and source file/s\r
-set(TARGET_NAME whisper_java)\r
-set(SOURCES src/main/cpp/whisper_java.cpp)\r
-\r
-# include <whisper.h>\r
-include_directories(../../)\r
-\r
-# Set the output directory for the DLL/shared library based on the platform as required by JNA\r
-if(WIN32)\r
- set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/win32-x86-64)\r
-elseif(UNIX AND NOT APPLE)\r
- set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/linux-x86-64)\r
-elseif(APPLE)\r
- set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/macos-x86-64)\r
-endif()\r
-\r
-set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OUTPUT_DIR})\r
-set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OUTPUT_DIR})\r
-\r
-# Create the whisper_java library\r
-add_library(${TARGET_NAME} SHARED ${SOURCES})\r
-\r
-# Link against ../../build/Release/whisper.dll (or so/dynlib)\r
-target_link_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/../../../build/${CMAKE_BUILD_TYPE})\r
-target_link_libraries(${TARGET_NAME} PRIVATE whisper)\r
-\r
-# Set the appropriate compiler flags for Windows, Linux, and macOS\r
-if(WIN32)\r
- target_compile_options(${TARGET_NAME} PRIVATE /W4 /D_CRT_SECURE_NO_WARNINGS)\r
-elseif(UNIX AND NOT APPLE)\r
- target_compile_options(${TARGET_NAME} PRIVATE -Wall -Wextra)\r
-elseif(APPLE)\r
- target_compile_options(${TARGET_NAME} PRIVATE -Wall -Wextra)\r
-endif()\r
-\r
-target_compile_definitions(${TARGET_NAME} PRIVATE WHISPER_SHARED)\r
-# add_definitions(-DWHISPER_SHARED)\r
-\r
-# Force CMake to save the libs to build/generated/resources/main/${os}-${arch} as required by JNA\r
-foreach(OUTPUTCONFIG ${CMAKE_CONFIGURATION_TYPES})\r
- string(TOUPPER ${OUTPUTCONFIG} OUTPUTCONFIG)\r
- set_target_properties(${TARGET_NAME} PROPERTIES\r
- RUNTIME_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR}\r
- LIBRARY_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR}\r
- ARCHIVE_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR})\r
-endforeach(OUTPUTCONFIG CMAKE_CONFIGURATION_TYPES)\r
* Ubuntu on x86_64
* Windows on x86_64
-The "low level" bindings are in `WhisperCppJnaLibrary` and `WhisperJavaJnaLibrary` which caches `whisper_full_params` and `whisper_context` in `whisper_java.cpp`.
-
-There are a lot of classes in the `callbacks`, `ggml`, `model` and `params` directories but most of them have not been tested.
-
-The most simple usage is as follows:
+The "low level" bindings are in `WhisperCppJnaLibrary`. The most simple usage is as follows:
```java
import io.github.ggerganov.whispercpp.WhisperCpp;
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp/bindings/java
-mkdir build
-pushd build
-cmake ..
-cmake --build .
-popd
-
./gradlew build
```
}\r
}\r
\r
+tasks.register('copyLibwhisperDynlib', Copy) {\r
+ from '../../build'\r
+ include 'libwhisper.dynlib'\r
+ into 'build/generated/resources/main/darwin'\r
+}\r
+\r
tasks.register('copyLibwhisperSo', Copy) {\r
from '../../build'\r
include 'libwhisper.so'\r
into 'build/generated/resources/main/windows-x86-64'\r
}\r
\r
-tasks.build.dependsOn copyLibwhisperSo, copyWhisperDll\r
+tasks.register('copyLibs') {\r
+ dependsOn copyLibwhisperDynlib, copyLibwhisperSo, copyWhisperDll\r
+}\r
\r
test {\r
systemProperty 'jna.library.path', project.file('build/generated/resources/main').absolutePath\r
+++ /dev/null
-#include <stdio.h>\r
-#include "whisper_java.h"\r
-\r
-struct whisper_full_params default_params;\r
-struct whisper_context * whisper_ctx = nullptr;\r
-\r
-struct void whisper_java_default_params(enum whisper_sampling_strategy strategy) {\r
- default_params = whisper_full_default_params(strategy);\r
-\r
-// struct whisper_java_params result = {};\r
-// return result;\r
- return;\r
-}\r
-\r
-void whisper_java_init_from_file(const char * path_model) {\r
- whisper_ctx = whisper_init_from_file(path_model);\r
- if (0 == default_params.n_threads) {\r
- whisper_java_default_params(WHISPER_SAMPLING_GREEDY);\r
- }\r
-}\r
-\r
-/** Delegates to whisper_full, but without having to pass `whisper_full_params` */\r
-int whisper_java_full(\r
- struct whisper_context * ctx,\r
-// struct whisper_java_params params,\r
- const float * samples,\r
- int n_samples) {\r
- return whisper_full(ctx, default_params, samples, n_samples);\r
-}\r
-\r
-void whisper_java_free() {\r
-// free(default_params);\r
-}\r
+++ /dev/null
-#define WHISPER_BUILD\r
-#include <whisper.h>\r
-\r
-#ifdef __cplusplus\r
-extern "C" {\r
-#endif\r
-\r
-struct whisper_java_params {\r
-};\r
-\r
-WHISPER_API void whisper_java_default_params(enum whisper_sampling_strategy strategy);\r
-\r
-WHISPER_API void whisper_java_init_from_file(const char * path_model);\r
-\r
-WHISPER_API int whisper_java_full(\r
- struct whisper_context * ctx,\r
-// struct whisper_java_params params,\r
- const float * samples,\r
- int n_samples);\r
-\r
-\r
-#ifdef __cplusplus\r
-}\r
-#endif\r
package io.github.ggerganov.whispercpp;\r
\r
+import com.sun.jna.Native;\r
import com.sun.jna.Pointer;\r
-import io.github.ggerganov.whispercpp.params.WhisperJavaParams;\r
+import io.github.ggerganov.whispercpp.params.WhisperFullParams;\r
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;\r
\r
import java.io.File;\r
*/\r
public class WhisperCpp implements AutoCloseable {\r
private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;\r
- private WhisperJavaJnaLibrary javaLib = WhisperJavaJnaLibrary.instance;\r
private Pointer ctx = null;\r
+ private Pointer greedyPointer = null;\r
+ private Pointer beamPointer = null;\r
\r
public File modelDir() {\r
String modelDirPath = System.getenv("XDG_CACHE_HOME");\r
\r
/**\r
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")\r
- * @return a Pointer to the WhisperContext\r
*/\r
- void initContext(String modelPath) throws FileNotFoundException {\r
+ public void initContext(String modelPath) throws FileNotFoundException {\r
if (ctx != null) {\r
lib.whisper_free(ctx);\r
}\r
modelPath = new File(modelDir(), modelPath).getAbsolutePath();\r
}\r
\r
- javaLib.whisper_java_init_from_file(modelPath);\r
ctx = lib.whisper_init_from_file(modelPath);\r
\r
if (ctx == null) {\r
}\r
\r
/**\r
- * Initialises `whisper_full_params` internally in whisper_java.cpp so JNA doesn't have to map everything.\r
- * `whisper_java_init_from_file()` calls `whisper_java_default_params(WHISPER_SAMPLING_GREEDY)` for convenience.\r
+ * Provides default params which can be used with `whisper_full()` etc.\r
+ * Because this function allocates memory for the params, the caller must call either:\r
+ * - call `whisper_free_params()`\r
+ * - `Native.free(Pointer.nativeValue(pointer));`\r
+ *\r
+ * @param strategy - GREEDY\r
*/\r
- public void getDefaultJavaParams(WhisperSamplingStrategy strategy) {\r
- javaLib.whisper_java_default_params(strategy.ordinal());\r
-// return lib.whisper_full_default_params(strategy.value)\r
- }\r
+ public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy) {\r
+ Pointer pointer;\r
\r
-// whisper_full_default_params was too hard to integrate with, so for now we use javaLib.whisper_java_default_params\r
-// fun getDefaultParams(strategy: WhisperSamplingStrategy): WhisperFullParams {\r
-// return lib.whisper_full_default_params(strategy.value)\r
-// }\r
+ // whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.\r
+ if (strategy == WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY) {\r
+ if (greedyPointer == null) {\r
+ greedyPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());\r
+ }\r
+ pointer = greedyPointer;\r
+ } else {\r
+ if (beamPointer == null) {\r
+ beamPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());\r
+ }\r
+ pointer = beamPointer;\r
+ }\r
+\r
+ WhisperFullParams params = new WhisperFullParams(pointer);\r
+ params.read();\r
+ return params;\r
+ }\r
\r
@Override\r
public void close() {\r
freeContext();\r
+ freeParams();\r
System.out.println("Whisper closed");\r
}\r
\r
}\r
}\r
\r
+ private void freeParams() {\r
+ if (greedyPointer != null) {\r
+ Native.free(Pointer.nativeValue(greedyPointer));\r
+ greedyPointer = null;\r
+ }\r
+ if (beamPointer != null) {\r
+ Native.free(Pointer.nativeValue(beamPointer));\r
+ beamPointer = null;\r
+ }\r
+ }\r
+\r
/**\r
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.\r
* Not thread safe for same context\r
* Uses the specified decoding strategy to obtain the text.\r
*/\r
- public String fullTranscribe(/*WhisperJavaParams whisperParams,*/ float[] audioData) throws IOException {\r
+ public String fullTranscribe(WhisperFullParams whisperParams, float[] audioData) throws IOException {\r
if (ctx == null) {\r
throw new IllegalStateException("Model not initialised");\r
}\r
\r
- if (javaLib.whisper_java_full(ctx, /*whisperParams,*/ audioData, audioData.length) != 0) {\r
+ if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {\r
throw new IOException("Failed to process audio");\r
}\r
\r
void whisper_print_timings(Pointer ctx);\r
void whisper_reset_timings(Pointer ctx);\r
\r
+ // Note: Even if `whisper_full_params is stripped back to just 4 ints, JNA throws "Invalid memory access"\r
+ // when `whisper_full_default_params()` tries to return a struct.\r
+ // WhisperFullParams whisper_full_default_params(int strategy);\r
+\r
/**\r
+ * Provides default params which can be used with `whisper_full()` etc.\r
+ * Because this function allocates memory for the params, the caller must call either:\r
+ * - call `whisper_free_params()`\r
+ * - `Native.free(Pointer.nativeValue(pointer));`\r
+ *\r
* @param strategy - WhisperSamplingStrategy.value\r
*/\r
- WhisperFullParams whisper_full_default_params(int strategy);\r
+ Pointer whisper_full_default_params_by_ref(int strategy);\r
+\r
+ void whisper_free_params(Pointer params);\r
\r
/**\r
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text\r
+++ /dev/null
-package io.github.ggerganov.whispercpp;\r
-\r
-import com.sun.jna.Library;\r
-import com.sun.jna.Native;\r
-import com.sun.jna.Pointer;\r
-import io.github.ggerganov.whispercpp.params.WhisperJavaParams;\r
-\r
-interface WhisperJavaJnaLibrary extends Library {\r
- WhisperJavaJnaLibrary instance = Native.load("whisper_java", WhisperJavaJnaLibrary.class);\r
-\r
- void whisper_java_default_params(int strategy);\r
-\r
- void whisper_java_free();\r
-\r
- void whisper_java_init_from_file(String modelPath);\r
-\r
- /**\r
- * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.\r
- * Not thread safe for same context\r
- * Uses the specified decoding strategy to obtain the text.\r
- */\r
- int whisper_java_full(Pointer ctx, /*WhisperJavaParams params, */float[] samples, int nSamples);\r
-}\r
* @param user_data User data.\r
* @return True if the computation should proceed, false otherwise.\r
*/\r
- boolean callback(WhisperContext ctx, WhisperState state, Pointer user_data);\r
+ boolean callback(Pointer ctx, Pointer state, Pointer user_data);\r
}\r
package io.github.ggerganov.whispercpp.callbacks;\r
\r
+import com.sun.jna.Callback;\r
import com.sun.jna.Pointer;\r
-import io.github.ggerganov.whispercpp.WhisperContext;\r
-import io.github.ggerganov.whispercpp.model.WhisperState;\r
import io.github.ggerganov.whispercpp.model.WhisperTokenData;\r
\r
-import javax.security.auth.callback.Callback;\r
-\r
/**\r
* Callback to filter logits.\r
* Can be used to modify the logits before sampling.\r
* @param logits The array of logits.\r
* @param user_data User data.\r
*/\r
- void callback(WhisperContext ctx, WhisperState state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data);\r
+ void callback(Pointer ctx, Pointer state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data);\r
}\r
* @param n_new The number of newly generated text segments.\r
* @param user_data User data.\r
*/\r
- void callback(WhisperContext ctx, WhisperState state, int n_new, Pointer user_data);\r
+ void callback(Pointer ctx, Pointer state, int n_new, Pointer user_data);\r
}\r
package io.github.ggerganov.whispercpp.callbacks;\r
\r
+import com.sun.jna.Callback;\r
import com.sun.jna.Pointer;\r
import io.github.ggerganov.whispercpp.WhisperContext;\r
import io.github.ggerganov.whispercpp.model.WhisperState;\r
\r
-import javax.security.auth.callback.Callback;\r
-\r
/**\r
* Callback for progress updates.\r
*/\r
* @param progress The progress value.\r
* @param user_data User data.\r
*/\r
- void callback(WhisperContext ctx, WhisperState state, int progress, Pointer user_data);\r
+ void callback(Pointer ctx, Pointer state, int progress, Pointer user_data);\r
}\r
--- /dev/null
+package io.github.ggerganov.whispercpp.params;\r
+\r
+import com.sun.jna.Structure;\r
+\r
+import java.util.Arrays;\r
+import java.util.List;\r
+\r
+public class BeamSearchParams extends Structure {\r
+ /** ref: <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265">...</a> */\r
+ public int beam_size;\r
+\r
+ /** ref: <a href="https://arxiv.org/pdf/2204.05424.pdf">...</a> */\r
+ public float patience;\r
+\r
+ @Override\r
+ protected List<String> getFieldOrder() {\r
+ return Arrays.asList("beam_size", "patience");\r
+ }\r
+}\r
--- /dev/null
+package io.github.ggerganov.whispercpp.params;\r
+\r
+import com.sun.jna.IntegerType;\r
+\r
+import java.util.function.BooleanSupplier;\r
+\r
+public class CBool extends IntegerType implements BooleanSupplier {\r
+ public static final int SIZE = 1;\r
+ public static final CBool FALSE = new CBool(0);\r
+ public static final CBool TRUE = new CBool(1);\r
+\r
+\r
+ public CBool() {\r
+ this(0);\r
+ }\r
+\r
+ public CBool(long value) {\r
+ super(SIZE, value, true);\r
+ }\r
+\r
+ @Override\r
+ public boolean getAsBoolean() {\r
+ return intValue() == 1;\r
+ }\r
+\r
+ @Override\r
+ public String toString() {\r
+ return intValue() == 1 ? "true" : "false";\r
+ }\r
+}\r
--- /dev/null
+package io.github.ggerganov.whispercpp.params;\r
+\r
+import com.sun.jna.Structure;\r
+\r
+import java.util.Collections;\r
+import java.util.List;\r
+\r
+public class GreedyParams extends Structure {\r
+ /** <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264">...</a> */\r
+ public int best_of;\r
+\r
+ @Override\r
+ protected List<String> getFieldOrder() {\r
+ return Collections.singletonList("best_of");\r
+ }\r
+}\r
package io.github.ggerganov.whispercpp.params;\r
\r
-import com.sun.jna.Callback;\r
-import com.sun.jna.Pointer;\r
-import com.sun.jna.Structure;\r
+import com.sun.jna.*;\r
import io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback;\r
import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback;\r
import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback;\r
import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;\r
\r
+import java.util.Arrays;\r
+import java.util.List;\r
+\r
/**\r
* Parameters for the whisper_full() function.\r
* If you change the order or add new parameters, make sure to update the default values in whisper.cpp:\r
*/\r
public class WhisperFullParams extends Structure {\r
\r
+ public WhisperFullParams(Pointer p) {\r
+ super(p);\r
+// super(p, ALIGN_MSVC);\r
+// super(p, ALIGN_GNUC);\r
+ }\r
+\r
/** Sampling strategy for whisper_full() function. */\r
public int strategy;\r
\r
- /** Number of threads. */\r
+ /** Number of threads. (default = 4) */\r
public int n_threads;\r
\r
- /** Maximum tokens to use from past text as a prompt for the decoder. */\r
+ /** Maximum tokens to use from past text as a prompt for the decoder. (default = 16384) */\r
public int n_max_text_ctx;\r
\r
- /** Start offset in milliseconds. */\r
+ /** Start offset in milliseconds. (default = 0) */\r
public int offset_ms;\r
\r
- /** Audio duration to process in milliseconds. */\r
+ /** Audio duration to process in milliseconds. (default = 0) */\r
public int duration_ms;\r
\r
- /** Translate flag. */\r
- public boolean translate;\r
+ /** Translate flag. (default = false) */\r
+ public CBool translate;\r
+\r
+ /** The compliment of translateMode() */\r
+ public void transcribeMode() {\r
+ translate = CBool.FALSE;\r
+ }\r
\r
- /** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. */\r
- public boolean no_context;\r
+ /** The compliment of transcribeMode() */\r
+ public void translateMode() {\r
+ translate = CBool.TRUE;\r
+ }\r
\r
- /** Flag to force single segment output (useful for streaming). */\r
- public boolean single_segment;\r
+ /** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */\r
+ public CBool no_context;\r
\r
- /** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). */\r
- public boolean print_special;\r
+ /** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */\r
+ public void enableContext(boolean enable) {\r
+ no_context = enable ? CBool.FALSE : CBool.TRUE;\r
+ }\r
\r
- /** Flag to print progress information. */\r
- public boolean print_progress;\r
+ /** Flag to force single segment output (useful for streaming). (default = false) */\r
+ public CBool single_segment;\r
\r
- /** Flag to print results from within whisper.cpp (avoid it, use callback instead). */\r
- public boolean print_realtime;\r
+ /** Flag to force single segment output (useful for streaming). (default = false) */\r
+ public void singleSegment(boolean single) {\r
+ single_segment = single ? CBool.TRUE : CBool.FALSE;\r
+ }\r
\r
- /** Flag to print timestamps for each text segment when printing realtime. */\r
- public boolean print_timestamps;\r
+ /** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */\r
+ public CBool print_special;\r
\r
- /** [EXPERIMENTAL] Flag to enable token-level timestamps. */\r
- public boolean token_timestamps;\r
+ /** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */\r
+ public void printSpecial(boolean enable) {\r
+ print_special = enable ? CBool.TRUE : CBool.FALSE;\r
+ }\r
+\r
+ /** Flag to print progress information. (default = true) */\r
+ public CBool print_progress;\r
+\r
+ /** Flag to print progress information. (default = true) */\r
+ public void printProgress(boolean enable) {\r
+ print_progress = enable ? CBool.TRUE : CBool.FALSE;\r
+ }\r
+\r
+ /** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */\r
+ public CBool print_realtime;\r
+\r
+ /** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */\r
+ public void printRealtime(boolean enable) {\r
+ print_realtime = enable ? CBool.TRUE : CBool.FALSE;\r
+ }\r
\r
- /** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). */\r
+ /** Flag to print timestamps for each text segment when printing realtime. (default = true) */\r
+ public CBool print_timestamps;\r
+\r
+ /** Flag to print timestamps for each text segment when printing realtime. (default = true) */\r
+ public void printTimestamps(boolean enable) {\r
+ print_timestamps = enable ? CBool.TRUE : CBool.FALSE;\r
+ }\r
+\r
+ /** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */\r
+ public CBool token_timestamps;\r
+\r
+ /** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */\r
+ public void tokenTimestamps(boolean enable) {\r
+ token_timestamps = enable ? CBool.TRUE : CBool.FALSE;\r
+ }\r
+\r
+ /** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). (default = 0.01) */\r
public float thold_pt;\r
\r
/** [EXPERIMENTAL] Timestamp token sum probability threshold (~0.01). */\r
public float thold_ptsum;\r
\r
- /** Maximum segment length in characters. */\r
+ /** Maximum segment length in characters. (default = 0) */\r
public int max_len;\r
\r
- /** Flag to split on word rather than on token (when used with max_len). */\r
- public boolean split_on_word;\r
+ /** Flag to split on word rather than on token (when used with max_len). (default = false) */\r
+ public CBool split_on_word;\r
+\r
+ /** Flag to split on word rather than on token (when used with max_len). (default = false) */\r
+ public void splitOnWord(boolean enable) {\r
+ split_on_word = enable ? CBool.TRUE : CBool.FALSE;\r
+ }\r
\r
- /** Maximum tokens per segment (0 = no limit). */\r
+ /** Maximum tokens per segment (0, default = no limit) */\r
public int max_tokens;\r
\r
- /** Flag to speed up the audio by 2x using Phase Vocoder. */\r
- public boolean speed_up;\r
+ /** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */\r
+ public CBool speed_up;\r
+\r
+ /** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */\r
+ public void speedUp(boolean enable) {\r
+ speed_up = enable ? CBool.TRUE : CBool.FALSE;\r
+ }\r
\r
/** Overwrite the audio context size (0 = use default). */\r
public int audio_ctx;\r
* These are prepended to any existing text context from a previous call. */\r
public String initial_prompt;\r
\r
- /** Prompt tokens. */\r
+ /** Prompt tokens. (int*) */\r
public Pointer prompt_tokens;\r
\r
+ public void setPromptTokens(int[] tokens) {\r
+ Memory mem = new Memory(tokens.length * 4L);\r
+ mem.write(0, tokens, 0, tokens.length);\r
+ prompt_tokens = mem;\r
+ }\r
+\r
/** Number of prompt tokens. */\r
public int prompt_n_tokens;\r
\r
public String language;\r
\r
/** Flag to indicate whether to detect language automatically. */\r
- public boolean detect_language;\r
+ public CBool detect_language;\r
+\r
+ /** Flag to indicate whether to detect language automatically. */\r
+ public void detectLanguage(boolean enable) {\r
+ detect_language = enable ? CBool.TRUE : CBool.FALSE;\r
+ }\r
\r
- /** Common decoding parameters. */\r
+ // Common decoding parameters.\r
\r
/** Flag to suppress blank tokens. */\r
- public boolean suppress_blank;\r
+ public CBool suppress_blank;\r
+\r
+ public void suppressBlanks(boolean enable) {\r
+ suppress_blank = enable ? CBool.TRUE : CBool.FALSE;\r
+ }\r
+\r
+ /** Flag to suppress non-speech tokens. */\r
+ public CBool suppress_non_speech_tokens;\r
\r
/** Flag to suppress non-speech tokens. */\r
- public boolean suppress_non_speech_tokens;\r
+ public void suppressNonSpeechTokens(boolean enable) {\r
+ suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE;\r
+ }\r
\r
/** Initial decoding temperature. */\r
public float temperature;\r
/** Length penalty. */\r
public float length_penalty;\r
\r
- /** Fallback parameters. */\r
+ // Fallback parameters.\r
\r
/** Temperature increment. */\r
public float temperature_inc;\r
/** No speech threshold. */\r
public float no_speech_thold;\r
\r
- class GreedyParams extends Structure {\r
- /** https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264 */\r
- public int best_of;\r
- }\r
-\r
/** Greedy decoding parameters. */\r
public GreedyParams greedy;\r
\r
- class BeamSearchParams extends Structure {\r
- /** ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265 */\r
- int beam_size;\r
-\r
- /** ref: https://arxiv.org/pdf/2204.05424.pdf */\r
- float patience;\r
- }\r
-\r
/**\r
* Beam search decoding parameters.\r
*/\r
public BeamSearchParams beam_search;\r
\r
+ public void setBestOf(int bestOf) {\r
+ if (greedy == null) {\r
+ greedy = new GreedyParams();\r
+ }\r
+ greedy.best_of = bestOf;\r
+ }\r
+\r
+ public void setBeamSize(int beamSize) {\r
+ if (beam_search == null) {\r
+ beam_search = new BeamSearchParams();\r
+ }\r
+ beam_search.beam_size = beamSize;\r
+ }\r
+\r
+ public void setBeamSizeAndPatience(int beamSize, float patience) {\r
+ if (beam_search == null) {\r
+ beam_search = new BeamSearchParams();\r
+ }\r
+ beam_search.beam_size = beamSize;\r
+ beam_search.patience = patience;\r
+ }\r
+\r
/**\r
* Callback for every newly generated text segment.\r
+ * WhisperNewSegmentCallback\r
*/\r
- public WhisperNewSegmentCallback new_segment_callback;\r
+ public Pointer new_segment_callback;\r
\r
/**\r
* User data for the new_segment_callback.\r
\r
/**\r
* Callback on each progress update.\r
+ * WhisperProgressCallback\r
*/\r
- public WhisperProgressCallback progress_callback;\r
+ public Pointer progress_callback;\r
\r
/**\r
* User data for the progress_callback.\r
\r
/**\r
* Callback each time before the encoder starts.\r
+ * WhisperEncoderBeginCallback\r
*/\r
- public WhisperEncoderBeginCallback encoder_begin_callback;\r
+ public Pointer encoder_begin_callback;\r
\r
/**\r
* User data for the encoder_begin_callback.\r
\r
/**\r
* Callback by each decoder to filter obtained logits.\r
+ * WhisperLogitsFilterCallback\r
*/\r
- public WhisperLogitsFilterCallback logits_filter_callback;\r
+ public Pointer logits_filter_callback;\r
\r
/**\r
* User data for the logits_filter_callback.\r
*/\r
public Pointer logits_filter_callback_user_data;\r
-}\r
\r
+\r
+ public void setNewSegmentCallback(WhisperNewSegmentCallback callback) {\r
+ new_segment_callback = CallbackReference.getFunctionPointer(callback);\r
+ }\r
+\r
+ public void setProgressCallback(WhisperProgressCallback callback) {\r
+ progress_callback = CallbackReference.getFunctionPointer(callback);\r
+ }\r
+\r
+ public void setEncoderBeginCallbackeginCallbackCallback(WhisperEncoderBeginCallback callback) {\r
+ encoder_begin_callback = CallbackReference.getFunctionPointer(callback);\r
+ }\r
+\r
+ public void setLogitsFilterCallback(WhisperLogitsFilterCallback callback) {\r
+ logits_filter_callback = CallbackReference.getFunctionPointer(callback);\r
+ }\r
+\r
+ @Override\r
+ protected List<String> getFieldOrder() {\r
+ return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",\r
+ "no_context", "single_segment",\r
+ "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",\r
+ "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",\r
+ "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",\r
+ "suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",\r
+ "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",\r
+ "new_segment_callback", "new_segment_callback_user_data",\r
+ "progress_callback", "progress_callback_user_data",\r
+ "encoder_begin_callback", "encoder_begin_callback_user_data",\r
+ "logits_filter_callback", "logits_filter_callback_user_data");\r
+ }\r
+}\r
+++ /dev/null
-package io.github.ggerganov.whispercpp.params;\r
-\r
-import com.sun.jna.Structure;\r
-\r
-public class WhisperJavaParams extends Structure {\r
-\r
-}\r
\r
import static org.junit.jupiter.api.Assertions.*;\r
\r
-import io.github.ggerganov.whispercpp.params.WhisperJavaParams;\r
+import io.github.ggerganov.whispercpp.params.CBool;\r
+import io.github.ggerganov.whispercpp.params.WhisperFullParams;\r
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;\r
import org.junit.jupiter.api.BeforeAll;\r
import org.junit.jupiter.api.Test;\r
static void init() throws FileNotFoundException {\r
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"\r
// or you can provide the absolute path to the model file.\r
- String modelName = "base.en";\r
+ String modelName = "../../models/ggml-tiny.en.bin";\r
try {\r
whisper.initContext(modelName);\r
- whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);\r
-// whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);\r
+// whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);\r
+// whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);\r
modelInitialised = true;\r
} catch (FileNotFoundException ex) {\r
System.out.println("Model " + modelName + " not found");\r
}\r
\r
@Test\r
- void testGetDefaultJavaParams() {\r
+ void testGetDefaultFullParams_BeamSearch() {\r
// When\r
- whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);\r
+ WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);\r
\r
- // Then if it doesn't throw we've connected to whisper.cpp\r
+ // Then\r
+ assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal(), params.strategy);\r
+ assertNotEquals(0, params.n_threads);\r
+ assertEquals(16384, params.n_max_text_ctx);\r
+ assertFalse(params.translate);\r
+ assertEquals(0.01f, params.thold_pt);\r
+ assertEquals(2, params.beam_search.beam_size);\r
+ assertEquals(-1.0f, params.beam_search.patience);\r
+ }\r
+\r
+ @Test\r
+ void testGetDefaultFullParams_Greedy() {\r
+ // When\r
+ WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);\r
+\r
+ // Then\r
+ assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY.ordinal(), params.strategy);\r
+ assertNotEquals(0, params.n_threads);\r
+ assertEquals(16384, params.n_max_text_ctx);\r
+ assertEquals(2, params.greedy.best_of);\r
}\r
\r
@Test\r
byte[] b = new byte[audioInputStream.available()];\r
float[] floats = new float[b.length / 2];\r
\r
+// WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);\r
+ WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);\r
+ params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));\r
+ params.print_progress = CBool.FALSE;\r
+// params.initial_prompt = "and so my fellow Americans um, like";\r
+\r
+\r
try {\r
audioInputStream.read(b);\r
\r
}\r
\r
// When\r
- String result = whisper.fullTranscribe(/*params,*/ floats);\r
+ String result = whisper.fullTranscribe(params, floats);\r
\r
// Then\r
- System.out.println(result);\r
- assertEquals("And so my fellow Americans, ask not what your country can do for you, " +\r
+ System.err.println(result);\r
+ assertEquals("And so my fellow Americans ask not what your country can do for you " +\r
"ask what you can do for your country.",\r
- result);\r
+ result.replace(",", ""));\r
} finally {\r
audioInputStream.close();\r
}\r
}
}
+void whisper_free_params(struct whisper_full_params * params) {
+ if (params) {
+ delete params;
+ }
+}
+
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
////////////////////////////////////////////////////////////////////////////
+struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) {
+ struct whisper_full_params params = whisper_full_default_params(strategy);
+
+ struct whisper_full_params* result = new whisper_full_params();
+ *result = params;
+ return result;
+}
+
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
struct whisper_full_params result = {
/*.strategy =*/ strategy,
// Frees all allocated memory
WHISPER_API void whisper_free (struct whisper_context * ctx);
WHISPER_API void whisper_free_state(struct whisper_state * state);
+ WHISPER_API void whisper_free_params(struct whisper_full_params * params);
// Convert RAW PCM audio to log mel spectrogram.
// The resulting spectrogram is stored inside the default state of the provided whisper context.
void * logits_filter_callback_user_data;
};
+ // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_params()
+ WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy);
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text