}\r
\r
tasks.register('copyLibwhisperDynlib', Copy) {\r
- from '../../build'\r
- include 'libwhisper.dynlib'\r
+ from '../../build/src'\r
+ include 'libwhisper.dylib'\r
into 'build/generated/resources/main/darwin'\r
}\r
\r
tasks.register('copyLibwhisperSo', Copy) {\r
- from '../../build'\r
+ from '../../build/src'\r
include 'libwhisper.so'\r
into 'build/generated/resources/main/linux-x86-64'\r
}\r
withJavadocJar()\r
}\r
\r
+sourcesJar() {\r
+ dependsOn copyLibs\r
+}\r
+\r
jar {\r
+ dependsOn copyLibs\r
exclude '**/whisper_java.exp', '**/whisper_java.lib'\r
}\r
\r
useJUnitPlatform()\r
}\r
\r
+test.dependsOn copyLibs\r
+processResources.dependsOn copyLibs\r
+\r
dependencies {\r
implementation "net.java.dev.jna:jna:5.13.0"\r
testImplementation "org.junit.jupiter:junit-jupiter:5.9.2"\r
--- /dev/null
+package io.github.ggerganov.whispercpp;
+
+/**
+ * Presets for alignment heads in DTW token timestamps
+ */
+public class WhisperConstants {
+ // Alignment heads presets
+ public static final int WHISPER_AHEADS_NONE = 0;
+ public static final int WHISPER_AHEADS_TINY_EN = 1;
+ public static final int WHISPER_AHEADS_TINY = 2;
+ public static final int WHISPER_AHEADS_BASE_EN = 3;
+ public static final int WHISPER_AHEADS_BASE = 4;
+ public static final int WHISPER_AHEADS_SMALL_EN = 5;
+ public static final int WHISPER_AHEADS_SMALL = 6;
+ public static final int WHISPER_AHEADS_MEDIUM_EN = 7;
+ public static final int WHISPER_AHEADS_MEDIUM = 8;
+ public static final int WHISPER_AHEADS_LARGE_V1 = 9;
+ public static final int WHISPER_AHEADS_LARGE_V2 = 10;
+ public static final int WHISPER_AHEADS_LARGE_V3 = 11;
+ public static final int WHISPER_AHEADS_LARGE_V3_TURBO = 12;
+ public static final int WHISPER_AHEADS_CUSTOM = 13;
+ public static final int WHISPER_AHEADS_N_TOP_MOST = 14;
+ public static final int WHISPER_AHEADS_COUNT = 15;
+}
package io.github.ggerganov.whispercpp;\r
\r
+import com.sun.jna.NativeLong;\r
import com.sun.jna.Structure;\r
import com.sun.jna.ptr.PointerByReference;\r
+import com.sun.jna.Pointer;\r
import io.github.ggerganov.whispercpp.ggml.GgmlType;\r
import io.github.ggerganov.whispercpp.WhisperModel;\r
import io.github.ggerganov.whispercpp.params.WhisperContextParams;\r
import java.util.List;\r
\r
public class WhisperContext extends Structure {\r
- int t_load_us = 0;\r
- int t_start_us = 0;\r
+ public NativeLong t_load_us;\r
+ public NativeLong t_start_us;\r
\r
/** weight type (FP32 / FP16 / QX) */\r
- GgmlType wtype = GgmlType.GGML_TYPE_F16;\r
+ public GgmlType wtype = GgmlType.GGML_TYPE_F16;\r
/** intermediate type (FP32 or FP16) */\r
- GgmlType itype = GgmlType.GGML_TYPE_F16;\r
+ public GgmlType itype = GgmlType.GGML_TYPE_F16;\r
\r
-// WhisperModel model;\r
- public PointerByReference model;\r
-// whisper_vocab vocab;\r
-// whisper_state * state = nullptr;\r
- public PointerByReference vocab;\r
- public PointerByReference state;\r
+ public WhisperContextParams.ByValue params;\r
+\r
+ public Pointer model;\r
+ public Pointer vocab;\r
+ public Pointer state;\r
\r
/** populated by whisper_init_from_file_with_params() */\r
- String path_model;\r
- WhisperContextParams params;\r
-\r
-// public static class ByReference extends WhisperContext implements Structure.ByReference {\r
-// }\r
-//\r
-// public static class ByValue extends WhisperContext implements Structure.ByValue {\r
-// }\r
-//\r
-// @Override\r
-// protected List<String> getFieldOrder() {\r
-// return List.of("t_load_us", "t_start_us", "wtype", "itype", "model", "vocab", "state", "path_model");\r
-// }\r
+ public Pointer path_model;\r
+\r
+ @Override\r
+ protected List<String> getFieldOrder() {\r
+ return List.of("t_load_us", "t_start_us", "wtype", "itype",\r
+ "params", "model", "vocab", "state", "path_model");\r
+ }\r
}\r
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")\r
* @param params - params to use when initialising the context\r
*/\r
- public void initContext(String modelPath, WhisperContextParams params) throws FileNotFoundException {\r
+ public void initContext(String modelPath, WhisperContextParams.ByValue params) throws FileNotFoundException {\r
initContextImpl(modelPath, params);\r
}\r
\r
- private void initContextImpl(String modelPath, WhisperContextParams params) throws FileNotFoundException {\r
+ private void initContextImpl(String modelPath, WhisperContextParams.ByValue params) throws FileNotFoundException {\r
if (ctx != null) {\r
lib.whisper_free(ctx);\r
}\r
\r
/**\r
* Provides default params which can be used with `whisper_init_from_file_with_params()` etc.\r
- * Because this function allocates memory for the params, the caller must call either:\r
- * - call `whisper_free_context_params()`\r
- * - `Native.free(Pointer.nativeValue(pointer));`\r
+ * Returns a ByValue instance to ensure proper parameter passing to native code.\r
*/\r
- public WhisperContextParams getContextDefaultParams() {\r
- paramsPointer = lib.whisper_context_default_params_by_ref();\r
- WhisperContextParams params = new WhisperContextParams(paramsPointer);\r
- params.read();\r
- return params;\r
+ public WhisperContextParams.ByValue getContextDefaultParams() {\r
+ WhisperContextParams.ByValue valueParams = new WhisperContextParams.ByValue(\r
+ lib.whisper_context_default_params_by_ref());\r
+ valueParams.read();\r
+ return valueParams;\r
}\r
\r
/**\r
*\r
* @param strategy - GREEDY\r
*/\r
- public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy) {\r
+ public WhisperFullParams.ByValue getFullDefaultParams(WhisperSamplingStrategy strategy) {\r
Pointer pointer;\r
\r
// whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.\r
pointer = beamParamsPointer;\r
}\r
\r
- WhisperFullParams params = new WhisperFullParams(pointer);\r
+ WhisperFullParams.ByValue params = new WhisperFullParams.ByValue(pointer);\r
params.read();\r
return params;\r
}\r
}\r
\r
/**\r
- * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.\r
+ * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.\r
* Not thread safe for same context\r
* Uses the specified decoding strategy to obtain the text.\r
*/\r
- public String fullTranscribe(WhisperFullParams whisperParams, float[] audioData) throws IOException {\r
+ public String fullTranscribe(WhisperFullParams.ByValue whisperParams, float[] audioData) throws IOException {\r
if (ctx == null) {\r
throw new IllegalStateException("Model not initialised");\r
}\r
\r
+ /*\r
+ WhisperFullParams.ByValue valueParams = new WhisperFullParams.ByValue(\r
+ lib.whisper_full_default_params_by_ref(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal()));\r
+ valueParams.read();\r
+ */\r
+\r
if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {\r
throw new IOException("Failed to process audio");\r
}\r
\r
return str.toString().trim();\r
}\r
+\r
public List<WhisperSegment> fullTranscribeWithTime(WhisperFullParams whisperParams, float[] audioData) throws IOException {\r
if (ctx == null) {\r
throw new IllegalStateException("Model not initialised");\r
}\r
\r
- if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {\r
+ WhisperFullParams.ByValue valueParams = new WhisperFullParams.ByValue(\r
+ lib.whisper_full_default_params_by_ref(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal()));\r
+ valueParams.read();\r
+\r
+ if (lib.whisper_full(ctx, valueParams, audioData, audioData.length) != 0) {\r
throw new IOException("Failed to process audio");\r
}\r
\r
* @param params Pointer to whisper_context_params\r
* @return Whisper context on success, null on failure\r
*/\r
- Pointer whisper_init_from_file_with_params(String path_model, WhisperContextParams params);\r
+ Pointer whisper_init_from_file_with_params(String path_model, WhisperContextParams.ByValue params);\r
\r
/**\r
* Allocate (almost) all memory needed for the model by loading from a buffer.\r
/**\r
* @return the id of the specified language, returns -1 if not found.\r
* Examples:\r
- * "de" -> 2\r
- * "german" -> 2\r
+ * "de" -> 2\r
+ * "german" -> 2\r
*/\r
int whisper_lang_id(String lang);\r
\r
- /** @return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found */\r
+ /** @return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found */\r
String whisper_lang_str(int id);\r
\r
/**\r
void whisper_free_params(Pointer params);\r
\r
/**\r
- * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text\r
+ * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text\r
* Not thread safe for same context\r
* Uses the specified decoding strategy to obtain the text.\r
*/\r
- int whisper_full(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples);\r
+ int whisper_full(Pointer ctx, WhisperFullParams.ByValue params, final float[] samples, int n_samples);\r
\r
- int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams params, final float[] samples, int n_samples);\r
+ public int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams.ByValue params, float[] samples, int n_samples);\r
+ //int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams params, final float[] samples, int n_samples);\r
\r
// Split the input audio in chunks and process each chunk separately using whisper_full_with_state()\r
// Result is stored in the default state of the context\r
// Not thread safe if executed in parallel on the same context.\r
// It seems this approach can offer some speedup in some cases.\r
// However, the transcription accuracy can be worse at the beginning and end of each chunk.\r
- int whisper_full_parallel(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples, int n_processors);\r
+ int whisper_full_parallel(Pointer ctx, WhisperFullParams.ByValue params, final float[] samples, int n_samples, int n_processors);\r
\r
/**\r
* Number of generated text segments.\r
--- /dev/null
+package io.github.ggerganov.whispercpp.callbacks;
+
+import com.sun.jna.Callback;
+
+/**
+ * Callback for aborting GGML computation
+ * Maps to the C typedef: bool (*ggml_abort_callback)(void * data)
+ */
+public interface GgmlAbortCallback extends Callback {
+ /**
+ * Return true to abort the computation, false to continue
+ *
+ * @param data User data passed to the callback
+ * @return true to abort, false to continue
+ */
+ boolean invoke(com.sun.jna.Pointer data);
+}
--- /dev/null
+package io.github.ggerganov.whispercpp.params;
+import com.sun.jna.*;
+import java.util.Arrays;
+import java.util.List;
+
+public class WhisperAhead extends Structure {
+
+ public int n_text_layer;
+
+ public int n_head;
+
+ public WhisperAhead() {
+ super();
+ }
+
+ public WhisperAhead(int textLayer, int head) {
+ super();
+ this.n_text_layer = textLayer;
+ this.n_head = head;
+ }
+
+ @Override
+ protected List<String> getFieldOrder() {
+ return Arrays.asList("n_text_layer", "n_head");
+ }
+
+ public static class ByReference extends WhisperAhead implements Structure.ByReference {}
+
+ public static class ByValue extends WhisperAhead implements Structure.ByValue {}
+}
--- /dev/null
+package io.github.ggerganov.whispercpp.params;
+import com.sun.jna.*;
+import java.util.Arrays;
+import java.util.List;
+
+public class WhisperAheads extends Structure {
+ public NativeLong n_heads;
+
+ public Pointer heads;
+
+ public WhisperAheads() {
+ super();
+ }
+
+ /**
+ * Create alignment heads from an array of WhisperAhead objects
+ */
+ public void setHeads(WhisperAhead[] aheadsArray) {
+ this.n_heads = new NativeLong(aheadsArray.length);
+
+ int structSize = aheadsArray[0].size();
+ Memory mem = new Memory(structSize * aheadsArray.length);
+
+ for (int i = 0; i < aheadsArray.length; i++) {
+ aheadsArray[i].write();
+ byte[] buffer = aheadsArray[i].getPointer().getByteArray(0, structSize);
+ mem.write(i * structSize, buffer, 0, buffer.length);
+ }
+
+ this.heads = mem;
+ }
+
+ @Override
+ protected List<String> getFieldOrder() {
+ return Arrays.asList("n_heads", "heads");
+ }
+
+ public static class ByReference extends WhisperAheads implements Structure.ByReference {}
+
+ public static class ByValue extends WhisperAheads implements Structure.ByValue {}
+}
package io.github.ggerganov.whispercpp.params;
-
import com.sun.jna.*;
-
import java.util.Arrays;
import java.util.List;
* whisper_context_default_params()
*/
public class WhisperContextParams extends Structure {
-
public WhisperContextParams(Pointer p) {
super(p);
}
- /** Use GPU for inference Number (default = true) */
+ public WhisperContextParams() {
+ super();
+ }
+
+ /** Use GPU for inference (default = true) */
public CBool use_gpu;
- /** Use GPU for inference Number (default = true) */
+ /** Use flash attention (default = false) */
+ public CBool flash_attn;
+
+ /** CUDA device to use (default = 0) */
+ public int gpu_device;
+
+ /** [EXPERIMENTAL] Enable token-level timestamps with DTW (default = false) */
+ public CBool dtw_token_timestamps;
+
+ /** [EXPERIMENTAL] Alignment heads preset for DTW */
+ public int dtw_aheads_preset;
+
+ /** Number of top layers to use for DTW when using WHISPER_AHEADS_N_TOP_MOST preset */
+ public int dtw_n_top;
+
+ public WhisperAheads.ByValue dtw_aheads;
+
+ /** DTW memory size (internal use) */
+ public NativeLong dtw_mem_size;
+
+ /** Use GPU for inference */
public void useGpu(boolean enable) {
use_gpu = enable ? CBool.TRUE : CBool.FALSE;
}
+ /** Use flash attention */
+ public void useFlashAttn(boolean enable) {
+ flash_attn = enable ? CBool.TRUE : CBool.FALSE;
+ }
+
+ /** Enable DTW token-level timestamps */
+ public void enableDtwTokenTimestamps(boolean enable) {
+ dtw_token_timestamps = enable ? CBool.TRUE : CBool.FALSE;
+ }
+
+ /** Set DTW alignment heads preset */
+ public void setDtwAheadsPreset(int preset) {
+ dtw_aheads_preset = preset;
+ }
+
@Override
protected List<String> getFieldOrder() {
- return Arrays.asList("use_gpu");
+ return Arrays.asList(
+ "use_gpu",
+ "flash_attn",
+ "gpu_device",
+ "dtw_token_timestamps",
+ "dtw_aheads_preset",
+ "dtw_n_top",
+ "dtw_aheads",
+ "dtw_mem_size"
+ );
+ }
+
+ public static class ByValue extends WhisperContextParams implements Structure.ByValue {
+ public ByValue() { super(); }
+ public ByValue(Pointer p) { super(p); }
}
}
import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback;\r
import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback;\r
import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;\r
+import io.github.ggerganov.whispercpp.callbacks.GgmlAbortCallback;\r
\r
import java.util.Arrays;\r
import java.util.List;\r
*/\r
public class WhisperFullParams extends Structure {\r
\r
+ public WhisperFullParams() {\r
+ super();\r
+ }\r
+\r
public WhisperFullParams(Pointer p) {\r
super(p);\r
-// super(p, ALIGN_MSVC);\r
-// super(p, ALIGN_GNUC);\r
}\r
\r
/** Sampling strategy for whisper_full() function. */\r
single_segment = single ? CBool.TRUE : CBool.FALSE;\r
}\r
\r
- /** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */\r
+ /** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */\r
public CBool print_special;\r
\r
- /** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */\r
+ /** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */\r
public void printSpecial(boolean enable) {\r
print_special = enable ? CBool.TRUE : CBool.FALSE;\r
}\r
/** Maximum tokens per segment (0, default = no limit) */\r
public int max_tokens;\r
\r
+ /** [EXPERIMENTAL] Enable debug mode for extra info */\r
+ public CBool debug_mode;\r
+\r
+ /** Enable debug mode */\r
+ public void enableDebugMode(boolean enable) {\r
+ debug_mode = enable ? CBool.TRUE : CBool.FALSE;\r
+ }\r
+\r
/** Overwrite the audio context size (0 = use default). */\r
public int audio_ctx;\r
\r
*/\r
public Pointer encoder_begin_callback_user_data;\r
\r
+ /** Callback used to abort GGML computation */\r
+ public Pointer abort_callback;\r
+\r
+ /** User data for the abort_callback */\r
+ public Pointer abort_callback_user_data;\r
+\r
+ public void setAbortCallback(GgmlAbortCallback callback) {\r
+ abort_callback = CallbackReference.getFunctionPointer(callback);\r
+ }\r
+\r
/**\r
* Callback by each decoder to filter obtained logits.\r
* WhisperLogitsFilterCallback\r
\r
@Override\r
protected List<String> getFieldOrder() {\r
- return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",\r
- "no_context", "single_segment", "no_timestamps",\r
- "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",\r
- "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx",\r
- "tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",\r
- "suppress_blank", "suppress_nst", "temperature", "max_initial_ts", "length_penalty",\r
- "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",\r
- "new_segment_callback", "new_segment_callback_user_data",\r
+ return Arrays.asList("strategy", "n_threads", "n_max_text_ctx",\r
+ "offset_ms", "duration_ms", "translate", "no_context",\r
+ "no_timestamps", "single_segment", "print_special",\r
+ "print_progress", "print_realtime", "print_timestamps",\r
+ "token_timestamps", "thold_pt", "thold_ptsum", "max_len",\r
+ "split_on_word", "max_tokens", "debug_mode", "audio_ctx", \r
+ "tdrz_enable", "suppress_regex", "initial_prompt",\r
+ "prompt_tokens", "prompt_n_tokens", "language", "detect_language",\r
+ "suppress_blank", "suppress_nst", "temperature",\r
+ "max_initial_ts", "length_penalty", "temperature_inc",\r
+ "entropy_thold", "logprob_thold", "no_speech_thold", "greedy",\r
+ "beam_search", "new_segment_callback", "new_segment_callback_user_data",\r
"progress_callback", "progress_callback_user_data",\r
"encoder_begin_callback", "encoder_begin_callback_user_data",\r
+ "abort_callback", "abort_callback_user_data",\r
"logits_filter_callback", "logits_filter_callback_user_data",\r
"grammar_rules", "n_grammar_rules", "i_start_rule", "grammar_penalty");\r
}\r
+\r
+ public static class ByValue extends WhisperFullParams implements Structure.ByValue {\r
+ public ByValue() { super(); }\r
+ public ByValue(Pointer p) { super(p); }\r
+ }\r
+\r
}\r
float[] floats = new float[b.length / 2];\r
\r
//WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);\r
- WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);\r
+ WhisperFullParams.ByValue params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);\r
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));\r
params.print_progress = CBool.FALSE;\r
//params.initial_prompt = "and so my fellow Americans um, like";\r