]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vad : add initial Voice Activity Detection (VAD) support (#3065)
authorDaniel Bevenius <redacted>
Mon, 12 May 2025 14:10:11 +0000 (16:10 +0200)
committerGitHub <redacted>
Mon, 12 May 2025 14:10:11 +0000 (16:10 +0200)
* vad : add initial Voice Activity Detection (VAD) support

This commit add support for Voice Activity Detection (VAD). When enabled
this feature will process the audio input and detect speech segments.
This information is then used to reduce the number of samples that need
to be processed by whisper_full.

Resolves: https://github.com/ggml-org/whisper.cpp/issues/3003

---------

Co-authored-by: Georgi Gerganov <redacted>
.github/workflows/build.yml
README.md
examples/cli/cli.cpp
include/whisper.h
models/convert-silero-vad-to-ggml.py [new file with mode: 0644]
models/for-tests-silero-v5.1.2-ggml.bin [new file with mode: 0644]
src/whisper-arch.h
src/whisper.cpp
tests/CMakeLists.txt
tests/test-vad-full.cpp [new file with mode: 0644]
tests/test-vad.cpp [new file with mode: 0644]

index 7e8d461f0770c9b455b009f8c997c831f67fafbe..ada1a312636f1d31647f6c4a666d48d779f97718 100644 (file)
@@ -1253,3 +1253,23 @@ jobs:
           source venv/bin/activate
           pip install ane_transformers openai-whisper coremltools
           ./models/generate-coreml-model.sh ${{ env.MODEL_NAME }}
+
+  vad:
+    if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
+            github.event.inputs.run_type == 'full-ci' }}
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Build
+        shell: bash
+        run: |
+          cmake -B build
+          cmake --build build --config Release
+
+      - name: Test
+        shell: bash
+        run: |
+          ctest -R ^test-vad$ --test-dir build --output-on-failure -VV
index 860aa6083cb942acb11d04d58d3e5644d344ad43..d0ead52eac429fefced8bba683ee076516c89051 100644 (file)
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
 - [Ascend NPU Support](#ascend-npu-support)
 - [Moore Threads GPU Support](#moore-threads-gpu-support)
 - [C-style API](https://github.com/ggml-org/whisper.cpp/blob/master/include/whisper.h)
+- [Voice Activity Detection (VAD)](#voice-activity-detection-vad)
 
 Supported platforms:
 
@@ -732,6 +733,64 @@ let package = Package(
 )
 ```
 
+### Voice Activity Detection (VAD)
+Support for Voice Activity Detection (VAD) can be enabled using the `--vad`
+argument to `whisper-cli`. In addition to this option a VAD model is also
+required.
+
+The way this works is that first the audio samples are passed through
+the VAD model which will detect speech segments. Using this information the
+only the speech segments that are detected are extracted from the original audio
+input and passed to whisper for processing. This reduces the amount of audio
+data that needs to be processed by whisper and can significantly speed up the
+transcription process.
+
+The following VAD models are currently supported:
+
+#### Silero-VAD
+[Silero-vad](https://github.com/snakers4/silero-vad) is a lightweight VAD model
+written in Python that is fast and accurate.
+
+This model can be converted to ggml using the following command:
+```console
+$ python3 -m venv venv && source venv/bin/activate
+$ (venv) pip install silero-vad
+$ (venv) $ python models/convert-silero-vad-to-ggml.py --output models/silero.bin
+Saving GGML Silero-VAD model to models/silero-v5.1.2-ggml.bin
+```
+And it can then be used with whisper as follows:
+```console
+$ ./build/bin/whisper-cli \
+   --file ./samples/jfk.wav \
+   --model ./models/ggml-base.en.bin \
+   --vad \
+   --vad-model ./models/silero-v5.1.2-ggml.bin
+```
+
+#### VAD Options
+
+* --vad-threshold: Threshold probability for speech detection. A probability
+for a speech segment/frame above this threshold will be considered as speech.
+
+* --vad-min-speech-duration-ms: Minimum speech duration in milliseconds. Speech
+segments shorter than this value will be discarded to filter out brief noise or
+false positives.
+
+* --vad-min-silence-duration-ms: Minimum silence duration in milliseconds. Silence
+periods must be at least this long to end a speech segment. Shorter silence
+periods will be ignored and included as part of the speech.
+
+* --vad-max-speech-duration-s: Maximum speech duration in seconds. Speech segments
+longer than this will be automatically split into multiple segments at silence
+points exceeding 98ms to prevent excessively long segments.
+
+* --vad-speech-pad-ms: Speech padding in milliseconds. Adds this amount of padding
+before and after each detected speech segment to avoid cutting off speech edges.
+
+* --vad-samples-overlap: Amount of audio to extend from each speech segment into
+the next one, in seconds (e.g., 0.10 = 100ms overlap). This ensures speech isn't
+cut off abruptly between segments when they're concatenated together.
+
 ## Examples
 
 There are various examples of using the library for different projects in the [examples](examples) folder.
index 0cc3d38f3f4d4a48469a8c662f5714d201d59211..28dba7067a4632726474601c37bb991c5bd2400e 100644 (file)
@@ -11,6 +11,7 @@
 #include <thread>
 #include <vector>
 #include <cstring>
+#include <cfloat>
 
 #if defined(_WIN32)
 #ifndef NOMINMAX
@@ -97,6 +98,16 @@ struct whisper_params {
     std::vector<std::string> fname_out = {};
 
     grammar_parser::parse_state grammar_parsed;
+
+    // Voice Activity Detection (VAD) parameters
+    bool        vad           = false;
+    std::string vad_model     = "";
+    float       vad_threshold = 0.5f;
+    int         vad_min_speech_duration_ms = 250;
+    int         vad_min_silence_duration_ms = 100;
+    float       vad_max_speech_duration_s = FLT_MAX;
+    int         vad_speech_pad_ms = 30;
+    float       vad_samples_overlap = 0.1f;
 };
 
 static void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -185,6 +196,15 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
         else if (                  arg == "--grammar")         { params.grammar         = ARGV_NEXT; }
         else if (                  arg == "--grammar-rule")    { params.grammar_rule    = ARGV_NEXT; }
         else if (                  arg == "--grammar-penalty") { params.grammar_penalty = std::stof(ARGV_NEXT); }
+        // Voice Activity Detection (VAD)
+        else if (arg == "-v"    || arg == "--vad")                         { params.vad                         = true; }
+        else if (arg == "-vm"   || arg == "--vad-model")                   { params.vad_model                   = ARGV_NEXT; }
+        else if (arg == "-vt"   || arg == "--vad-threshold")               { params.vad_threshold               = std::stof(ARGV_NEXT); }
+        else if (arg == "-vsd"  || arg == "--vad-min-speech-duration-ms")  { params.vad_min_speech_duration_ms  = std::stoi(ARGV_NEXT); }
+        else if (arg == "-vsd"  || arg == "--vad-min-silence-duration-ms") { params.vad_min_speech_duration_ms  = std::stoi(ARGV_NEXT); }
+        else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s")   { params.vad_max_speech_duration_s   = std::stof(ARGV_NEXT); }
+        else if (arg == "-vp"   || arg == "--vad-speech-pad-ms")           { params.vad_speech_pad_ms           = std::stoi(ARGV_NEXT); }
+        else if (arg == "-vo"   || arg == "--vad-samples-overlap")         { params.vad_samples_overlap         = std::stof(ARGV_NEXT); }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             whisper_print_usage(argc, argv, params);
@@ -254,6 +274,18 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
     fprintf(stderr, "  --grammar GRAMMAR              [%-7s] GBNF grammar to guide decoding\n",                 params.grammar.c_str());
     fprintf(stderr, "  --grammar-rule RULE            [%-7s] top-level GBNF grammar rule name\n",               params.grammar_rule.c_str());
     fprintf(stderr, "  --grammar-penalty N            [%-7.1f] scales down logits of nongrammar tokens\n",      params.grammar_penalty);
+    // Voice Activity Detection (VAD) parameters
+    fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n");
+    fprintf(stderr, "  -v,        --vad                           [%-7s] enable Voice Activity Detection (VAD)\n",            params.vad ? "true" : "false");
+    fprintf(stderr, "  -vm FNAME, --vad-model FNAME               [%-7s] VAD model path\n",                                   params.vad_model.c_str());
+    fprintf(stderr, "  -vt N,     --vad-threshold N               [%-7.2f] VAD threshold for speech recognition\n",           params.vad_threshold);
+    fprintf(stderr, "  -vspd N,   --vad-min-speech-duration-ms  N [%-7d] VAD min speech duration (0.0-1.0)\n",                params.vad_min_speech_duration_ms);
+    fprintf(stderr, "  -vsd N,    --vad-min-silence-duration-ms N [%-7d] VAD min silence duration (to split segments)\n",      params.vad_min_silence_duration_ms);
+    fprintf(stderr, "  -vmsd N,   --vad-max-speech-duration-s   N [%-7s] VAD max speech duration (auto-split longer)\n",      params.vad_max_speech_duration_s == FLT_MAX ?
+                                                                                                                                  std::string("FLT_MAX").c_str() :
+                                                                                                                                  std::to_string(params.vad_max_speech_duration_s).c_str());
+    fprintf(stderr, "  -vp N,     --vad-speech-pad-ms           N [%-7d] VAD speech padding (extend segments)\n",             params.vad_speech_pad_ms);
+    fprintf(stderr, "  -vo N,     --vad-samples-overlap         N [%-7.2f] VAD samples overlap (seconds between segments)\n", params.vad_samples_overlap);
     fprintf(stderr, "\n");
 }
 
@@ -1134,6 +1166,16 @@ int main(int argc, char ** argv) {
 
             wparams.suppress_nst     = params.suppress_nst;
 
+            wparams.vad            = params.vad;
+            wparams.vad_model_path = params.vad_model.c_str();
+
+            wparams.vad_params.threshold               = params.vad_threshold;
+            wparams.vad_params.min_speech_duration_ms  = params.vad_min_speech_duration_ms;
+            wparams.vad_params.min_silence_duration_ms = params.vad_min_silence_duration_ms;
+            wparams.vad_params.max_speech_duration_s   = params.vad_max_speech_duration_s;
+            wparams.vad_params.speech_pad_ms           = params.vad_speech_pad_ms;
+            wparams.vad_params.samples_overlap         = params.vad_samples_overlap;
+
             whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
 
             const auto & grammar_parsed = params.grammar_parsed;
index 1e1375033adbd3306ba87687084d661e48abd524..4aeda98f334e6de8c03232a3762938d5ec7406b4 100644 (file)
@@ -189,6 +189,15 @@ extern "C" {
         uint32_t             value; // Unicode code point or rule ID
     } whisper_grammar_element;
 
+    typedef struct whisper_vad_params {
+        float threshold;               // Probability threshold to consider as speech.
+        int   min_speech_duration_ms;  // Min duration for a valid speech segment.
+        int   min_silence_duration_ms; // Min silence duration to consider speech as ended.
+        float max_speech_duration_s;   // Max duration of a speech segment before forcing a new segment.
+        int   speech_pad_ms;           // Padding added before and after speech segments.
+        float samples_overlap;         // Overlap in seconds when copying audio samples from speech segment.
+    } whisper_vad_params;
+
     // Various functions for loading a ggml whisper model.
     // Allocate (almost) all memory needed for the model.
     // Return NULL on failure
@@ -570,11 +579,18 @@ extern "C" {
         size_t                           n_grammar_rules;
         size_t                           i_start_rule;
         float                            grammar_penalty;
+
+        // Voice Activity Detection (VAD) params
+        bool         vad;                         // Enable VAD
+        const char * vad_model_path;              // Path to VAD model
+
+        whisper_vad_params vad_params;
     };
 
     // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()
     WHISPER_API struct whisper_context_params * whisper_context_default_params_by_ref(void);
     WHISPER_API struct whisper_context_params   whisper_context_default_params       (void);
+
     WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy);
     WHISPER_API struct whisper_full_params   whisper_full_default_params       (enum whisper_sampling_strategy strategy);
 
@@ -652,6 +668,53 @@ extern "C" {
     WHISPER_API float whisper_full_get_token_p           (struct whisper_context * ctx, int i_segment, int i_token);
     WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token);
 
+    //
+    // Voice Activity Detection (VAD)
+    //
+
+    struct whisper_vad_context;
+
+    WHISPER_API struct whisper_vad_params whisper_vad_default_params(void);
+
+    struct whisper_vad_context_params {
+        int   n_threads;  // The number of threads to use for processing.
+        bool  use_gpu;
+        int   gpu_device; // CUDA device
+    };
+
+    WHISPER_API struct whisper_vad_context_params whisper_vad_default_context_params(void);
+
+    WHISPER_API struct whisper_vad_context * whisper_vad_init_from_file_with_params(const char * path_model,              struct whisper_vad_context_params params);
+    WHISPER_API struct whisper_vad_context * whisper_vad_init_with_params          (struct whisper_model_loader * loader, struct whisper_vad_context_params params);
+
+    WHISPER_API bool whisper_vad_detect_speech(
+            struct whisper_vad_context * vctx,
+                           const float * samples,
+                                   int   n_samples);
+
+    WHISPER_API int     whisper_vad_n_probs(struct whisper_vad_context * vctx);
+    WHISPER_API float * whisper_vad_probs  (struct whisper_vad_context * vctx);
+
+    struct whisper_vad_segments;
+
+    WHISPER_API struct whisper_vad_segments * whisper_vad_segments_from_probs(
+            struct whisper_vad_context * vctx,
+            struct whisper_vad_params    params);
+
+    WHISPER_API struct whisper_vad_segments * whisper_vad_segments_from_samples(
+            struct whisper_vad_context * vctx,
+            struct whisper_vad_params    params,
+                           const float * samples,
+                                   int   n_samples);
+
+    WHISPER_API int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments);
+
+    WHISPER_API float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments * segments, int i_segment);
+    WHISPER_API float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments * segments, int i_segment);
+
+    WHISPER_API void whisper_vad_free_segments(struct whisper_vad_segments * segments);
+    WHISPER_API void whisper_vad_free         (struct whisper_vad_context  * ctx);
+
     ////////////////////////////////////////////////////////////////////////////
 
     // Temporary helpers needed for exposing ggml interface
diff --git a/models/convert-silero-vad-to-ggml.py b/models/convert-silero-vad-to-ggml.py
new file mode 100644 (file)
index 0000000..078131c
--- /dev/null
@@ -0,0 +1,196 @@
+import os
+import struct
+import argparse
+import torch
+import numpy as np
+from silero_vad import load_silero_vad, __version__ as silero_version
+
+def convert_silero_vad(output_path, print_tensors=True):
+    model = load_silero_vad()
+    state_dict = model.state_dict()
+
+    # Clean up state dict keys - filter out 8k model
+    cleaned_dict = {}
+    for key, value in state_dict.items():
+        # Skip 8k model
+        if "_8k" not in key:
+            clean_key = key
+            if not key.startswith("_model."):
+                clean_key = "_model." + key
+            cleaned_dict[clean_key] = value
+
+    base, ext = os.path.splitext(output_path)
+    output_file = f"{base}-v{silero_version}-ggml{ext}"
+    print(f"Saving GGML Silero-VAD model to {output_file}")
+
+    print("\nTensor info for debugging:")
+    for key, tensor in cleaned_dict.items():
+        print(f"  - {key}: {tensor.shape} ({tensor.dtype})")
+    print()
+
+    with open(output_file, "wb") as fout:
+        # Write magic and version
+        fout.write(struct.pack("i", 0x67676d6c))
+
+        model_type = "silero-16k"
+        str_len = len(model_type)
+        fout.write(struct.pack("i", str_len))
+        fout.write(model_type.encode('utf-8'))
+
+        version_parts = silero_version.split('.')
+        major, minor, patch = map(int, version_parts)
+        print(f"Version: {major}.{minor}.{patch}")
+        fout.write(struct.pack("i", major))
+        fout.write(struct.pack("i", minor))
+        fout.write(struct.pack("i", patch))
+
+        # Write model architecture parameters
+        window_size = 512
+        fout.write(struct.pack("i", window_size))
+        context_size = 64
+        fout.write(struct.pack("i", context_size))
+
+        n_encoder_layers = 4
+        fout.write(struct.pack("i", n_encoder_layers))
+
+        # Write encoder dimensions
+        input_channels = 129
+        encoder_in_channels = [input_channels, 128, 64, 64]
+        encoder_out_channels = [128, 64, 64, 128]
+        kernel_size = 3
+
+        for i in range(n_encoder_layers):
+            fout.write(struct.pack("i", encoder_in_channels[i]))
+            fout.write(struct.pack("i", encoder_out_channels[i]))
+            fout.write(struct.pack("i", kernel_size))
+
+        # Write LSTM dimensions
+        lstm_input_size = 128
+        lstm_hidden_size = 128
+        fout.write(struct.pack("i", lstm_input_size))
+        fout.write(struct.pack("i", lstm_hidden_size))
+
+        # Write final conv dimensions
+        final_conv_in = 128
+        final_conv_out = 1
+        fout.write(struct.pack("i", final_conv_in))
+        fout.write(struct.pack("i", final_conv_out))
+
+        # Define tensor keys to write
+        tensor_keys = []
+
+        # Encoder weights
+        for i in range(n_encoder_layers):
+            weight_key = f"_model.encoder.{i}.reparam_conv.weight"
+            bias_key = f"_model.encoder.{i}.reparam_conv.bias"
+            if weight_key in cleaned_dict and bias_key in cleaned_dict:
+                tensor_keys.append(weight_key)
+                tensor_keys.append(bias_key)
+
+        # LSTM weights
+        lstm_keys = [
+            "_model.decoder.rnn.weight_ih",
+            "_model.decoder.rnn.weight_hh",
+            "_model.decoder.rnn.bias_ih",
+            "_model.decoder.rnn.bias_hh"
+        ]
+        tensor_keys.extend([k for k in lstm_keys if k in cleaned_dict])
+
+        # Final conv weights
+        final_keys = [
+            "_model.decoder.decoder.2.weight",
+            "_model.decoder.decoder.2.bias"
+        ]
+        tensor_keys.extend([k for k in final_keys if k in cleaned_dict])
+
+        # STFT basis - add this last
+        stft_tensor = "_model.stft.forward_basis_buffer"
+        tensor_keys.append(stft_tensor)
+
+        print(f"Writing {len(tensor_keys)} tensors:")
+        for key in tensor_keys:
+            if key in cleaned_dict:
+                print(f"  - {key}: {cleaned_dict[key].shape}")
+            else:
+                print(f"  - {key}: MISSING")
+
+        # Process each tensor
+        for key in tensor_keys:
+            if key not in cleaned_dict:
+                print(f"Warning: Missing tensor {key}, skipping")
+                continue
+
+            tensor = cleaned_dict[key]
+
+            # Special handling for STFT tensor
+            if key == "_model.stft.forward_basis_buffer":
+                # Get the original numpy array without squeezing
+                data = tensor.detach().cpu().numpy()
+                # Ensure it has the expected shape
+                print(f"STFT tensor original shape: {data.shape}")
+                n_dims = 3
+                tensor_shape = [data.shape[2], data.shape[1], data.shape[0]]
+                is_conv_weight = True
+            else:
+                # For other tensors, we can use standard processing
+                data = tensor.detach().cpu().squeeze().numpy()
+                tensor_shape = list(data.shape)
+
+                # Ensure we have at most 4 dimensions for GGML
+                n_dims = min(len(tensor_shape), 4)
+
+                # Reverse dimensions for GGML
+                tensor_shape = tensor_shape[:n_dims]
+                tensor_shape.reverse()
+
+                # Check if this is a convolution weight tensor
+                is_conv_weight = "weight" in key and ("encoder" in key or "_model.decoder.decoder.2" in key)
+
+            # Convert to float16 for convolution weights
+            if is_conv_weight:
+                data = data.astype(np.float16)
+                ftype = 1  # float16
+            else:
+                ftype = 0  # float32
+
+            # Debug printing of tensor info
+            print(f"\nWriting tensor: {key}")
+            print(f"  Original shape: {tensor.shape}")
+            print(f"  Processed shape: {data.shape}")
+            print(f"  GGML dimensions: {n_dims}")
+            print(f"  GGML shape: {tensor_shape}")
+            print(f"  Type: {'float16' if ftype == 1 else 'float32'}")
+
+            # Convert tensor name to bytes
+            name_bytes = key.encode('utf-8')
+            name_length = len(name_bytes)
+
+            # Write tensor header
+            fout.write(struct.pack("i", n_dims))
+            fout.write(struct.pack("i", name_length))
+            fout.write(struct.pack("i", ftype))
+
+            # Write tensor dimensions
+            for i in range(n_dims):
+                size = tensor_shape[i] if i < len(tensor_shape) else 1
+                fout.write(struct.pack("i", size))
+                print(f"  Writing dimension {i}: {size}")
+
+            # Write tensor name
+            fout.write(name_bytes)
+
+            # Write tensor data
+            data.tofile(fout)
+
+            print(f"  Wrote {data.size * (2 if ftype==1 else 4)} bytes")
+
+    print(f"\nDone! Model has been converted to GGML format: {output_file}")
+    print(f"File size: {os.path.getsize(output_file)} bytes")
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Convert Silero-VAD PyTorch model to GGML format")
+    parser.add_argument("--output", type=str, required=True, help="Path to output GGML model file")
+    parser.add_argument("--print-tensors", action="store_true", help="Print tensor values", default=True)
+    args = parser.parse_args()
+
+    convert_silero_vad(args.output, args.print_tensors)
diff --git a/models/for-tests-silero-v5.1.2-ggml.bin b/models/for-tests-silero-v5.1.2-ggml.bin
new file mode 100644 (file)
index 0000000..c5ddfb5
Binary files /dev/null and b/models/for-tests-silero-v5.1.2-ggml.bin differ
index ea2cfd6013c5aceb3de7288546d825134761f542..3a65ff35aa070e254d5a1bf174f8914f20eb50df 100644 (file)
@@ -139,3 +139,59 @@ static const std::map<asr_tensor, ggml_op> ASR_TENSOR_INFO = {
     {ASR_TENSOR_ATTN_OUT_WEIGHT,       GGML_OP_MUL_MAT},
     {ASR_TENSOR_ATTN_OUT_BIAS,         GGML_OP_ADD},
 };
+
+enum vad_tensor {
+    VAD_TENSOR_STFT_BASIS,
+    VAD_TENSOR_ENC_0_WEIGHT,
+    VAD_TENSOR_ENC_0_BIAS,
+    VAD_TENSOR_ENC_1_WEIGHT,
+    VAD_TENSOR_ENC_1_BIAS,
+    VAD_TENSOR_ENC_2_WEIGHT,
+    VAD_TENSOR_ENC_2_BIAS,
+    VAD_TENSOR_ENC_3_WEIGHT,
+    VAD_TENSOR_ENC_3_BIAS,
+    VAD_TENSOR_LSTM_WEIGHT_IH,
+    VAD_TENSOR_LSTM_WEIGHT_HH,
+    VAD_TENSOR_LSTM_BIAS_IH,
+    VAD_TENSOR_LSTM_BIAS_HH,
+    VAD_TENSOR_FINAL_CONV_WEIGHT,
+    VAD_TENSOR_FINAL_CONV_BIAS,
+};
+
+static const std::map<vad_tensor, ggml_op> VAD_TENSOR_OPS = {
+    {VAD_TENSOR_STFT_BASIS,          GGML_OP_IM2COL},
+    {VAD_TENSOR_ENC_0_WEIGHT,        GGML_OP_IM2COL},
+    {VAD_TENSOR_ENC_0_BIAS,          GGML_OP_ADD},
+    {VAD_TENSOR_ENC_1_WEIGHT,        GGML_OP_IM2COL},
+    {VAD_TENSOR_ENC_1_BIAS,          GGML_OP_ADD},
+    {VAD_TENSOR_ENC_2_WEIGHT,        GGML_OP_IM2COL},
+    {VAD_TENSOR_ENC_2_BIAS,          GGML_OP_ADD},
+    {VAD_TENSOR_ENC_3_WEIGHT,        GGML_OP_IM2COL},
+    {VAD_TENSOR_ENC_3_BIAS,          GGML_OP_ADD},
+
+    {VAD_TENSOR_LSTM_WEIGHT_IH,      GGML_OP_MUL_MAT},
+    {VAD_TENSOR_LSTM_WEIGHT_HH,      GGML_OP_MUL_MAT},
+    {VAD_TENSOR_LSTM_BIAS_IH,        GGML_OP_ADD},
+    {VAD_TENSOR_LSTM_BIAS_HH,        GGML_OP_ADD},
+
+    {VAD_TENSOR_FINAL_CONV_WEIGHT,   GGML_OP_IM2COL},
+    {VAD_TENSOR_FINAL_CONV_BIAS,     GGML_OP_ADD}
+};
+
+static const std::map<vad_tensor, const char *> VAD_TENSOR_NAMES = {
+    {VAD_TENSOR_STFT_BASIS,          "_model.stft.forward_basis_buffer"},
+    {VAD_TENSOR_ENC_0_WEIGHT,        "_model.encoder.0.reparam_conv.weight"},
+    {VAD_TENSOR_ENC_0_BIAS,          "_model.encoder.0.reparam_conv.bias"},
+    {VAD_TENSOR_ENC_1_WEIGHT,        "_model.encoder.1.reparam_conv.weight"},
+    {VAD_TENSOR_ENC_1_BIAS,          "_model.encoder.1.reparam_conv.bias"},
+    {VAD_TENSOR_ENC_2_WEIGHT,        "_model.encoder.2.reparam_conv.weight"},
+    {VAD_TENSOR_ENC_2_BIAS,          "_model.encoder.2.reparam_conv.bias"},
+    {VAD_TENSOR_ENC_3_WEIGHT,        "_model.encoder.3.reparam_conv.weight"},
+    {VAD_TENSOR_ENC_3_BIAS,          "_model.encoder.3.reparam_conv.bias"},
+    {VAD_TENSOR_LSTM_WEIGHT_IH,      "_model.decoder.rnn.weight_ih"},
+    {VAD_TENSOR_LSTM_WEIGHT_HH,      "_model.decoder.rnn.weight_hh"},
+    {VAD_TENSOR_LSTM_BIAS_IH,        "_model.decoder.rnn.bias_ih"},
+    {VAD_TENSOR_LSTM_BIAS_HH,        "_model.decoder.rnn.bias_hh"},
+    {VAD_TENSOR_FINAL_CONV_WEIGHT,   "_model.decoder.decoder.2.weight"},
+    {VAD_TENSOR_FINAL_CONV_BIAS,     "_model.decoder.decoder.2.bias"}
+};
index e103e29b8beacdefeaab14647995a205e6c967f4..bba91f200bf1f8b3589298d2ac975b2eb2f4e7d1 100644 (file)
@@ -17,6 +17,7 @@
 #include <atomic>
 #include <algorithm>
 #include <cassert>
+#include <cfloat>
 #define _USE_MATH_DEFINES
 #include <cmath>
 #include <climits>
@@ -163,7 +164,6 @@ static bool ggml_graph_compute_helper(
                          int   n_threads,
          ggml_abort_callback   abort_callback,
                         void * abort_callback_data) {
-
     ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
 
     auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
@@ -184,8 +184,8 @@ static bool ggml_graph_compute_helper(
 static bool ggml_graph_compute_helper(
       ggml_backend_sched_t   sched,
         struct ggml_cgraph * graph,
-                       int   n_threads) {
-
+                       int   n_threads,
+                      bool   sched_reset = true) {
     for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
         ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
         ggml_backend_dev_t dev = ggml_backend_get_device(backend);
@@ -197,8 +197,12 @@ static bool ggml_graph_compute_helper(
         }
     }
 
-    bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
-    ggml_backend_sched_reset(sched);
+    const bool t = (ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS);
+
+    if (!t || sched_reset) {
+        ggml_backend_sched_reset(sched);
+    }
+
     return t;
 }
 
@@ -949,6 +953,15 @@ struct whisper_state {
 
     // [EXPERIMENTAL] speed-up techniques
     int32_t exp_n_audio_ctx = 0; // 0 - use default
+
+    struct vad_segment_info {
+        float orig_start;
+        float orig_end;
+        float vad_start;
+        float vad_end;
+    };
+    std::vector<vad_segment_info> vad_segments;
+    bool has_vad_segments = false;
 };
 
 struct whisper_context {
@@ -4341,225 +4354,1337 @@ const char * whisper_print_system_info(void) {
 }
 
 //////////////////////////////////
-// Grammar - ported from llama.cpp
+// Voice Activity Detection (VAD)
 //////////////////////////////////
 
-// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
-// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
-static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
-        const char         * src,
-        whisper_partial_utf8   partial_start) {
-    static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
-    const char          * pos      = src;
-    std::vector<uint32_t> code_points;
-    uint32_t              value    = partial_start.value;
-    int                   n_remain = partial_start.n_remain;
+struct whisper_vad_hparams {
+    int32_t   n_encoder_layers;
+    int32_t * encoder_in_channels;
+    int32_t * encoder_out_channels;
+    int32_t * kernel_sizes;
+    int32_t   lstm_input_size;
+    int32_t   lstm_hidden_size;
+    int32_t   final_conv_in;
+    int32_t   final_conv_out;
+};
 
-    // continue previous decode, if applicable
-    while (*pos != 0 && n_remain > 0) {
-        uint8_t next_byte = static_cast<uint8_t>(*pos);
-        if ((next_byte >> 6) != 2) {
-            // invalid sequence, abort
-            code_points.push_back(0);
-            return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
-        }
-        value = (value << 6) + (next_byte & 0x3F);
-        ++pos;
-        --n_remain;
-    }
+struct whisper_vad_model {
+    std::string type;
+    std::string version;
+    whisper_vad_hparams hparams;
 
-    if (partial_start.n_remain > 0 && n_remain == 0) {
-        code_points.push_back(value);
-    }
+    struct ggml_tensor * stft_forward_basis; // [256, 1, 258]
 
-    // decode any subsequent utf-8 sequences, which may end in an incomplete one
-    while (*pos != 0) {
-        uint8_t  first_byte = static_cast<uint8_t>(*pos);
-        uint8_t  highbits   = first_byte >> 4;
-                 n_remain   = lookup[highbits] - 1;
+    // Encoder tensors - 4 convolutional layers
+    struct ggml_tensor * encoder_0_weight;  // [3, 129, 128]
+    struct ggml_tensor * encoder_0_bias;    // [128]
 
-        if (n_remain < 0) {
-            // invalid sequence, abort
-            code_points.clear();
-            code_points.push_back(0);
-            return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
-        }
+    // Second encoder layer
+    struct ggml_tensor * encoder_1_weight;  // [3, 128, 64]
+    struct ggml_tensor * encoder_1_bias;    // [64]
 
-        uint8_t  mask       = (1 << (7 - n_remain)) - 1;
-                 value      = first_byte & mask;
-        ++pos;
-        while (*pos != 0 && n_remain > 0) {
-            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
-            ++pos;
-            --n_remain;
-        }
-        if (n_remain == 0) {
-            code_points.push_back(value);
-        }
-    }
-    code_points.push_back(0);
+    // Third encoder layer
+    struct ggml_tensor * encoder_2_weight;  // [3, 64, 64]
+    struct ggml_tensor * encoder_2_bias;    // [64]
 
-    return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
-}
+    // Fourth encoder layer
+    struct ggml_tensor * encoder_3_weight;  // [3, 64, 128]
+    struct ggml_tensor * encoder_3_bias;    // [128]
 
-// returns true iff pos points to the end of one of the definitions of a rule
-static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
-    switch (pos->type) {
-        case WHISPER_GRETYPE_END: return true;  // NOLINT
-        case WHISPER_GRETYPE_ALT: return true;  // NOLINT
-        default:                return false;
-    }
-}
+    // LSTM decoder tensors
+    struct ggml_tensor * lstm_ih_weight;    // [128, 512] input-to-hidden
+    struct ggml_tensor * lstm_ih_bias;      // [512]
+    struct ggml_tensor * lstm_hh_weight;    // [128, 512] hidden-to-hidden
+    struct ggml_tensor * lstm_hh_bias;      // [512]
 
-// returns true iff chr satisfies the char range at pos (regular or inverse range)
-// asserts that pos is pointing to a char range element
-static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
-        const whisper_grammar_element * pos,
-        const uint32_t                chr) {
+    // Final conv layer
+    struct ggml_tensor * final_conv_weight; // [128]
+    struct ggml_tensor * final_conv_bias;   // [1]
 
-    bool found            = false;
-    bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
+    // ggml contexts
+    std::vector<ggml_context *> ctxs;
 
-    WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
+    // buffer for the model tensors
+    std::vector<ggml_backend_buffer_t> buffers;
 
-    do {
-        if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
-            // inclusive range, e.g. [a-z]
-            found = found || (pos->value <= chr && chr <= pos[1].value);
-            pos += 2;
-        } else {
-            // exact char match, e.g. [a] or "a"
-            found = found || pos->value == chr;
-            pos += 1;
-        }
-    } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+    // tensors
+    int n_loaded;
+    std::map<std::string, struct ggml_tensor *> tensors;
+};
 
-    return std::make_pair(found == is_positive_char, pos);
+struct whisper_vad_segment {
+    float start; // Start time in seconds
+    float end;   // End time in seconds
+};
+
+struct whisper_vad_segments {
+    std::vector<whisper_vad_segment> data;
+};
+
+struct whisper_vad_context {
+    int64_t t_vad_us = 0;
+
+    int     n_window;
+    int     n_context;
+    int     n_threads;
+
+    std::vector<ggml_backend_t> backends;
+    ggml_backend_buffer_t       buffer = nullptr;
+    whisper_context_params      params;
+    std::vector<uint8_t>        ctx_buf;
+    whisper_sched               sched;
+
+    whisper_vad_model    model;
+    std::string          path_model;
+    struct ggml_tensor * h_state;
+    struct ggml_tensor * c_state;
+    std::vector<float>   probs;
+};
+
+struct whisper_vad_context_params whisper_vad_default_context_params(void) {
+    whisper_vad_context_params result = {
+        /*.n_thread                = */ 4,
+        /*.use_gpu                 = */ false,
+        /*.gpu_device              = */ 0,
+    };
+    return result;
 }
 
-// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
-// range at pos (regular or inverse range)
-// asserts that pos is pointing to a char range element
-static bool whisper_grammar_match_partial_char(
-        const whisper_grammar_element * pos,
-        const whisper_partial_utf8      partial_utf8) {
+struct whisper_vad_params whisper_vad_default_params(void) {
+    whisper_vad_params result = {
+        /* threshold               = */ 0.5f,
+        /* min_speech_duration_ms  = */ 250,
+        /* min_silence_duration_ms = */ 100,
+        /* max_speech_duration_s   = */ FLT_MAX,
+        /* speech_pad_ms           = */ 30,
+        /* samples_overlap         = */ 0.1,
+    };
+    return result;
+}
 
-    bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
-    WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT);
+static bool weight_buft_supported(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
+    bool op_supported = true;
 
-    uint32_t partial_value = partial_utf8.value;
-    int      n_remain      = partial_utf8.n_remain;
+    if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU ||
+        (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) {
+        // GPU and default CPU backend support all operators
+        op_supported = true;
+    } else {
+        switch (op) {
+            // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
+            case GGML_OP_MUL_MAT: {
+                ggml_init_params params = {
+                    /*.mem_size   =*/ 2 * ggml_tensor_overhead(),
+                    /*.mem_buffer =*/ nullptr,
+                    /*.no_alloc   =*/ true,
+                };
 
-    // invalid sequence or 7-bit char split across 2 bytes (overlong)
-    if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
-        return false;
-    }
+                ggml_context_ptr ctx_ptr { ggml_init(params) };
+                if (!ctx_ptr) {
+                    throw std::runtime_error("failed to create ggml context");
+                }
+                ggml_context * ctx = ctx_ptr.get();
 
-    // range of possible code points this partial UTF-8 sequence could complete to
-    uint32_t low  = partial_value << (n_remain * 6);
-    uint32_t high = low | ((1 << (n_remain * 6)) - 1);
+                ggml_tensor * op_tensor = nullptr;
 
-    if (low == 0) {
-        if (n_remain == 2) {
-            low = 1 << 11;
-        } else if (n_remain == 3) {
-            low = 1 << 16;
-        }
-    }
+                int64_t n_ctx = hparams.lstm_hidden_size;
+                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
+                op_tensor = ggml_mul_mat(ctx, w, b);
 
-    do {
-        if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
-            // inclusive range, e.g. [a-z]
-            if (pos->value <= high && low <= pos[1].value) {
-                return is_positive_char;
+                // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
+                GGML_ASSERT(w->buffer == nullptr);
+                w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
+                op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
+                ggml_backend_buffer_free(w->buffer);
+                w->buffer = nullptr;
+                break;
             }
-            pos += 2;
-        } else {
-            // exact char match, e.g. [a] or "a"
-            if (low <= pos->value && pos->value <= high) {
-                return is_positive_char;
+            default: {
+                op_supported = false;
+                break;
             }
-            pos += 1;
+        };
+    }
+    return op_supported;
+}
+
+static ggml_backend_buffer_type_t select_weight_buft(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
+    GGML_ASSERT(!buft_list.empty());
+    for (const auto & p : buft_list) {
+        ggml_backend_dev_t dev = p.first;
+        ggml_backend_buffer_type_t buft = p.second;
+        if (weight_buft_supported(hparams, w, op, buft, dev)) {
+            return buft;
         }
-    } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+    }
 
-    return !is_positive_char;
+    return nullptr;
 }
 
+static ggml_tensor * whisper_vad_build_stft_layer(ggml_context * ctx0,
+        const whisper_vad_model & model, ggml_tensor * cur) {
+    // Apply reflective padding to the input tensor
+    ggml_tensor * padded = ggml_pad_reflect_1d(ctx0, cur, 64, 64);
 
-// transforms a grammar pushdown stack into N possible stacks, all ending
-// at a character range (terminal element)
-static void whisper_grammar_advance_stack(
-        const std::vector<std::vector<whisper_grammar_element>>   & rules,
-        const std::vector<const whisper_grammar_element *>        & stack,
-        std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
+    struct ggml_tensor * stft = ggml_conv_1d(ctx0, model.stft_forward_basis, padded, model.hparams.lstm_input_size, 0, 1);
 
-    if (stack.empty()) {
-        new_stacks.emplace_back();
-        return;
-    }
+    // Calculate cutoff for real/imaginary parts
+    int cutoff = model.stft_forward_basis->ne[2] / 2;
 
-    const whisper_grammar_element * pos = stack.back();
+    // Extract real part (first half of the STFT output).
+    struct ggml_tensor * real_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], 0);
+    // Extract imaginary part (second half of the STFT output).
+    struct ggml_tensor * img_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], cutoff * stft->nb[1]);
 
-    switch (pos->type) {
-        case WHISPER_GRETYPE_RULE_REF: {
-            const size_t                  rule_id = static_cast<size_t>(pos->value);
-            const whisper_grammar_element * subpos  = rules[rule_id].data();
-            do {
-                // init new stack without the top (pos)
-                std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
-                if (!whisper_grammar_is_end_of_sequence(pos + 1)) {
-                    // if this rule ref is followed by another element, add that to stack
-                    new_stack.push_back(pos + 1);
-                }
-                if (!whisper_grammar_is_end_of_sequence(subpos)) {
-                    // if alternate is nonempty, add to stack
-                    new_stack.push_back(subpos);
-                }
-                whisper_grammar_advance_stack(rules, new_stack, new_stacks);
-                while (!whisper_grammar_is_end_of_sequence(subpos)) {
-                    // scan to end of alternate def
-                    subpos++;
-                }
-                if (subpos->type == WHISPER_GRETYPE_ALT) {
-                    // there's another alternate def of this rule to process
-                    subpos++;
-                } else {
-                    break;
-                }
-            } while (true);
-            break;
-        }
-        case WHISPER_GRETYPE_CHAR:
-        case WHISPER_GRETYPE_CHAR_NOT:
-            new_stacks.push_back(stack);
-            break;
-        default:
-            // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range
-            // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
-            // those
-            WHISPER_ASSERT(false);
-    }
+    // Calculate magnitude: sqrt(real^2 + imag^2)
+    struct ggml_tensor * real_squared = ggml_mul(ctx0, real_part, real_part);
+    struct ggml_tensor * img_squared  = ggml_mul(ctx0, img_part, img_part);
+    struct ggml_tensor * sum_squares  = ggml_add(ctx0, real_squared, img_squared);
+    struct ggml_tensor * magnitude    = ggml_sqrt(ctx0, sum_squares);
+    return magnitude;
 }
 
-// takes a set of possible pushdown stacks on a grammar, which are required to
-// be positioned at a character range (see `whisper_grammar_advance_stack`), and
-// produces the N possible stacks if the given char is accepted at those
-// positions
-static std::vector<std::vector<const whisper_grammar_element *>> whisper_grammar_accept(
-        const std::vector<std::vector<whisper_grammar_element>>         & rules,
-        const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
-        const uint32_t                                                  chr) {
+static ggml_tensor * whisper_vad_build_encoder_layer(ggml_context * ctx0,
+        const whisper_vad_model & model, ggml_tensor * cur) {
+    // First Conv1D: expands to 128 channels.
+    cur = ggml_conv_1d(ctx0, model.encoder_0_weight, cur, 1, 1, 1);
+    cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_0_bias, 1, 128, 1));
+    cur = ggml_relu(ctx0, cur);
 
-    std::vector<std::vector<const whisper_grammar_element *>> new_stacks;
+    // Second Conv1D: reduces to 64 channels.
+    cur = ggml_conv_1d(ctx0, model.encoder_1_weight, cur, 2, 1, 1);
+    cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_1_bias, 1, 64, 1));
+    cur = ggml_relu(ctx0, cur);
 
-    for (const auto & stack : stacks) {
-        if (stack.empty()) {
-            continue;
-        }
+    // Third Conv1D: maintains 64 channels
+    cur = ggml_conv_1d(ctx0, model.encoder_2_weight, cur, 2, 1, 1);
+    cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_2_bias, 1, 64, 1));
+    cur = ggml_relu(ctx0, cur);
 
-        auto match = whisper_grammar_match_char(stack.back(), chr);
-        if (match.first) {
+    // Fourth Conv1D: expands to 128 channels
+    cur = ggml_conv_1d(ctx0, model.encoder_3_weight, cur, 1, 1, 1);
+    cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_3_bias, 1, 128, 1));
+    cur = ggml_relu(ctx0, cur);
+
+    return cur;
+}
+
+static ggml_tensor * whisper_vad_build_lstm_layer(ggml_context * ctx0,
+        const whisper_vad_context & vctx, ggml_tensor * cur, ggml_cgraph * gf) {
+    const whisper_vad_model & model = vctx.model;
+    const int hdim = model.hparams.lstm_hidden_size;
+
+    struct ggml_tensor * x_t = ggml_transpose(ctx0, cur);
+
+    // Create operations using the input-to-hidden weights.
+    struct ggml_tensor * inp_gate = ggml_mul_mat(ctx0, model.lstm_ih_weight, x_t);
+    inp_gate = ggml_add(ctx0, inp_gate, model.lstm_ih_bias);
+
+    // Create operations using the hidden-to-hidden weights.
+    struct ggml_tensor * hid_gate = ggml_mul_mat(ctx0, model.lstm_hh_weight, vctx.h_state);
+    hid_gate = ggml_add(ctx0, hid_gate, model.lstm_hh_bias);
+
+    // Create add operation to get preactivations for all gates.
+    struct ggml_tensor * out_gate = ggml_add(ctx0, inp_gate, hid_gate);
+
+    const size_t hdim_size = ggml_row_size(out_gate->type, hdim);
+
+    // Create sigmoid for input gate (using the first 128 bytes from the preactivations).
+    struct ggml_tensor * i_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 0 * hdim_size));
+
+    // Create sigmoid for the forget gate (using the second 128 bytes from the preactivations).
+    struct ggml_tensor * f_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 1 * hdim_size));
+
+    // Create sigmoid for the cell gate (using the third 128 bytes from the preactivations).
+    struct ggml_tensor * g_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 2 * hdim_size));
+
+    // Create sigmoid for the output gate (using the fourth 128 bytes from the preactivations).
+    struct ggml_tensor * o_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 3 * hdim_size));
+
+    // Update cell state
+    struct ggml_tensor * c_out = ggml_add(ctx0,
+        ggml_mul(ctx0, f_t, vctx.c_state),
+        ggml_mul(ctx0, i_t, g_t));
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_out, vctx.c_state));
+
+    // Update hidden state
+    struct ggml_tensor * out = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_out));
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, out,   vctx.h_state));
+
+    return out;
+}
+
+static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx) {
+    const auto & model = vctx.model;
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ vctx.sched.meta.size(),
+        /*.mem_buffer =*/ vctx.sched.meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+
+    ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    struct ggml_tensor * frame = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, vctx.n_window, 1);
+    ggml_set_name(frame, "frame");
+    ggml_set_input(frame);
+
+    struct ggml_tensor * cur = nullptr;
+    {
+        cur = whisper_vad_build_stft_layer(ctx0, model, frame);
+
+        cur = whisper_vad_build_encoder_layer(ctx0, model, cur);
+
+        // Extract the first element of the first dimension
+        // (equivalent to pytorch's [:, :, 0])
+        cur = ggml_view_2d(ctx0, cur, 1, 128, cur->nb[1], 0);
+
+        cur = whisper_vad_build_lstm_layer(ctx0, vctx, cur, gf);
+        cur = ggml_relu(ctx0, cur);
+        cur = ggml_conv_1d(ctx0, model.final_conv_weight, cur, 1, 0, 1);
+        cur = ggml_add(ctx0, cur, model.final_conv_bias);
+        cur = ggml_sigmoid(ctx0, cur);
+        ggml_set_name(cur, "prob");
+        ggml_set_output(cur);
+    }
+
+    ggml_build_forward_expand(gf, cur);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
+static bool whisper_vad_init_context(whisper_vad_context * vctx) {
+
+    auto whisper_context_params = whisper_context_default_params();
+    // TODO: GPU VAD is forced disabled until the performance is improved
+    //whisper_context_params.use_gpu    = vctx->params.use_gpu;
+    whisper_context_params.use_gpu    = false;
+    whisper_context_params.gpu_device = vctx->params.gpu_device;
+
+    vctx->backends = whisper_backend_init(whisper_context_params);
+    if (vctx->backends.empty()) {
+        WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
+        return false;
+    }
+
+    const int32_t lstm_hidden_size = vctx->model.hparams.lstm_hidden_size;
+
+    vctx->ctx_buf.resize(2u*ggml_tensor_overhead());
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ vctx->ctx_buf.size(),
+        /*.mem_buffer =*/ vctx->ctx_buf.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    ggml_context * ctx = ggml_init(params);
+    if (!ctx) {
+        WHISPER_LOG_ERROR("%s: failed to init LSTM state ggml context\n", __func__);
+        return false;
+    }
+
+    // LSTM Hidden state
+    vctx->h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
+    ggml_set_name(vctx->h_state, "h_state");
+
+    // LSTM Cell state
+    vctx->c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
+    ggml_set_name(vctx->c_state, "c_state");
+
+    vctx->buffer = ggml_backend_alloc_ctx_tensors(ctx, vctx->backends[0]);
+    if (!vctx->buffer) {
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for the VAD state\n", __func__);
+        return false;
+    }
+
+    {
+        bool ok = whisper_sched_graph_init(vctx->sched, vctx->backends,
+                [&]() {
+                    return whisper_vad_build_graph(*vctx);
+                });
+
+        if (!ok) {
+            WHISPER_LOG_ERROR("%s: failed to init VAD allocator\n", __func__);
+            return false;
+        }
+
+        WHISPER_LOG_INFO("%s: compute buffer (VAD)   = %7.2f MB\n", __func__, whisper_sched_size(vctx->sched) / 1e6);
+    }
+
+    return true;
+}
+
+struct whisper_vad_context * whisper_vad_init_from_file_with_params(
+        const char * path_model,
+        struct whisper_vad_context_params params) {
+    WHISPER_LOG_INFO("%s: loading VAD model from '%s'\n", __func__, path_model);
+#ifdef _MSC_VER
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+    std::wstring path_model_wide = converter.from_bytes(path_model);
+    auto fin = std::ifstream(path_model_wide, std::ios::binary);
+#else
+    auto fin = std::ifstream(path_model, std::ios::binary);
+#endif
+    if (!fin) {
+        WHISPER_LOG_ERROR("%s: failed to open VAD model '%s'\n", __func__, path_model);
+        return nullptr;
+    }
+
+    whisper_model_loader loader = {};
+    loader.context = &fin;
+
+    loader.read = [](void * ctx, void * output, size_t read_size) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        fin->read((char *)output, read_size);
+        return read_size;
+    };
+
+    loader.eof = [](void * ctx) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        return fin->eof();
+    };
+
+    loader.close = [](void * ctx) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        fin->close();
+    };
+
+    auto ctx = whisper_vad_init_with_params(&loader, params);
+    if (!ctx) {
+        whisper_vad_free(ctx);
+        return nullptr;
+    }
+    ctx->path_model = path_model;
+    return ctx;
+}
+
+struct whisper_vad_context * whisper_vad_init_with_params(
+            struct whisper_model_loader * loader,
+            struct whisper_vad_context_params params) {
+    // Read the VAD model
+    {
+        uint32_t magic;
+        read_safe(loader, magic);
+        if (magic != GGML_FILE_MAGIC) {
+            WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
+            return nullptr;
+        }
+    }
+
+    whisper_vad_context * vctx = new whisper_vad_context;
+    vctx->n_threads = params.n_threads;
+    vctx->params.use_gpu = params.use_gpu;
+    vctx->params.gpu_device = params.gpu_device;
+
+    auto & model = vctx->model;
+    auto & hparams = model.hparams;
+
+    // load model context params.
+    {
+        int32_t str_len;
+        read_safe(loader, str_len);
+        std::vector<char> buffer(str_len + 1, 0);
+        loader->read(loader->context, buffer.data(), str_len);
+        std::string model_type(buffer.data(), str_len);
+        model.type = model_type;
+        WHISPER_LOG_INFO("%s: model type: %s\n", __func__, model.type.c_str());
+
+        int32_t major, minor, patch;
+        read_safe(loader, major);
+        read_safe(loader, minor);
+        read_safe(loader, patch);
+        std::string version_str = std::to_string(major) + "." +
+                                  std::to_string(minor) + "." +
+                                  std::to_string(patch);
+        model.version = version_str;
+        WHISPER_LOG_INFO("%s: model version: %s\n", __func__, model.version.c_str());
+
+        read_safe(loader, vctx->n_window);
+        read_safe(loader, vctx->n_context);
+    }
+
+    // load model hyper params (hparams).
+    {
+        read_safe(loader, hparams.n_encoder_layers);
+
+        hparams.encoder_in_channels = new int32_t[hparams.n_encoder_layers];
+        hparams.encoder_out_channels = new int32_t[hparams.n_encoder_layers];
+        hparams.kernel_sizes = new int32_t[hparams.n_encoder_layers];
+
+        for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
+            read_safe(loader, hparams.encoder_in_channels[i]);
+            read_safe(loader, hparams.encoder_out_channels[i]);
+            read_safe(loader, hparams.kernel_sizes[i]);
+        }
+
+        read_safe(loader, hparams.lstm_input_size);
+        read_safe(loader, hparams.lstm_hidden_size);
+        read_safe(loader, hparams.final_conv_in);
+        read_safe(loader, hparams.final_conv_out);
+
+        WHISPER_LOG_INFO("%s: n_encoder_layers = %d\n", __func__, hparams.n_encoder_layers);
+        for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
+            WHISPER_LOG_INFO("%s: encoder_in_channels[%d] = %d\n", __func__, i, hparams.encoder_in_channels[i]);
+        }
+        for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
+            WHISPER_LOG_INFO("%s: encoder_out_channels[%d] = %d\n", __func__, i, hparams.encoder_out_channels[i]);
+        }
+        WHISPER_LOG_INFO("%s: lstm_input_size = %d\n", __func__, hparams.lstm_input_size);
+        WHISPER_LOG_INFO("%s: lstm_hidden_size = %d\n", __func__, hparams.lstm_hidden_size);
+        WHISPER_LOG_INFO("%s: final_conv_in = %d\n", __func__, hparams.final_conv_in);
+        WHISPER_LOG_INFO("%s: final_conv_out = %d\n", __func__, hparams.final_conv_out);
+    }
+
+    // 1 STFT tensor, 4*2 encoder tensors, 4 LSTM tensors, 2 final output tensors
+    const size_t n_tensors = hparams.n_encoder_layers * 2 + 4 + 2 + 1;
+
+    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+    auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
+        auto it = ctx_map.find(buft);
+        if (it == ctx_map.end()) {
+            ggml_init_params params = {
+                /*.mem_size   =*/ n_tensors * ggml_tensor_overhead(),
+                /*.mem_buffer =*/ nullptr,
+                /*.no_alloc   =*/ true,
+            };
+
+            ggml_context * ctx = ggml_init(params);
+            if (!ctx) {
+                throw std::runtime_error("failed to create ggml context");
+            }
+
+            ctx_map[buft] = ctx;
+            model.ctxs.emplace_back(ctx);
+
+            return ctx;
+        }
+
+        return it->second;
+    };
+
+    whisper_context_params wparams = whisper_context_default_params();
+    wparams.use_gpu = params.use_gpu;
+    wparams.gpu_device = params.gpu_device;
+    buft_list_t buft_list = make_buft_list(wparams);
+
+    auto create_tensor = [&](vad_tensor type, ggml_tensor * meta) -> ggml_tensor * {
+        ggml_op op = VAD_TENSOR_OPS.at(type);
+        ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
+        if (!buft) {
+            throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", VAD_TENSOR_NAMES.at(type)));
+        }
+        ggml_context * ctx = get_ctx(buft);
+        ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
+        model.tensors[VAD_TENSOR_NAMES.at(type)] = tensor;
+
+        return tensor;
+    };
+
+    // create tensors
+    {
+        ggml_init_params params = {
+            /*.mem_size   =*/ n_tensors * ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+
+        ggml_context * ctx = ggml_init(params);
+        const auto & hparams = model.hparams;
+
+        // SFTF precomputed basis matrix
+        model.stft_forward_basis = create_tensor(VAD_TENSOR_STFT_BASIS,
+            ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 256, 1, 258));
+
+        model.encoder_0_weight = create_tensor(VAD_TENSOR_ENC_0_WEIGHT,
+            ggml_new_tensor_3d(
+                ctx,
+                GGML_TYPE_F16,
+                hparams.kernel_sizes[0],
+                hparams.encoder_in_channels[0],
+                hparams.encoder_out_channels[0]
+        ));
+        model.encoder_0_bias = create_tensor(VAD_TENSOR_ENC_0_BIAS,
+            ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[0]));
+
+        model.encoder_1_weight = create_tensor(VAD_TENSOR_ENC_1_WEIGHT,
+            ggml_new_tensor_3d(
+                ctx,
+                GGML_TYPE_F16,
+                hparams.kernel_sizes[1],
+                hparams.encoder_in_channels[1],
+                hparams.encoder_out_channels[1]
+        ));
+        model.encoder_1_bias = create_tensor(VAD_TENSOR_ENC_1_BIAS,
+            ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[1]));
+
+        model.encoder_2_weight = create_tensor(VAD_TENSOR_ENC_2_WEIGHT,
+            ggml_new_tensor_3d(
+                ctx,
+                GGML_TYPE_F16,
+                hparams.kernel_sizes[2],
+                hparams.encoder_in_channels[2],
+                hparams.encoder_out_channels[2]
+        ));
+        model.encoder_2_bias = create_tensor(VAD_TENSOR_ENC_2_BIAS,
+            ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[2]));
+
+        model.encoder_3_weight = create_tensor(VAD_TENSOR_ENC_3_WEIGHT,
+            ggml_new_tensor_3d(
+                ctx,
+                GGML_TYPE_F16,
+                hparams.kernel_sizes[3],
+                hparams.encoder_in_channels[3],
+                hparams.encoder_out_channels[3]
+        ));
+        model.encoder_3_bias = create_tensor(VAD_TENSOR_ENC_3_BIAS,
+                ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[3]));
+
+        // Hidden State dimension (input gate, forget gate, cell gate, output gate)
+        const int hstate_dim = hparams.lstm_hidden_size * 4;
+
+        // LSTM weights - input to hidden
+        model.lstm_ih_weight = create_tensor(
+            VAD_TENSOR_LSTM_WEIGHT_IH,
+            ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
+        );
+        model.lstm_ih_bias = create_tensor(
+            VAD_TENSOR_LSTM_BIAS_IH,
+            ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
+        );
+
+        // LSTM weights - hidden to hidden
+        model.lstm_hh_weight = create_tensor(
+            VAD_TENSOR_LSTM_WEIGHT_HH,
+            ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
+        );
+        model.lstm_hh_bias = create_tensor(
+            VAD_TENSOR_LSTM_BIAS_HH,
+            ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
+        );
+
+        // Final conv layer weight
+        model.final_conv_weight = create_tensor(
+            VAD_TENSOR_FINAL_CONV_WEIGHT,
+            ggml_new_tensor_2d(ctx, GGML_TYPE_F16, hparams.final_conv_in, 1)
+        );
+        model.final_conv_bias = create_tensor(
+            VAD_TENSOR_FINAL_CONV_BIAS,
+            ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1)
+        );
+
+        ggml_free(ctx);
+    }
+
+    // allocate tensors in the backend buffers
+    for (auto & p : ctx_map) {
+        ggml_backend_buffer_type_t buft = p.first;
+        ggml_context * ctx = p.second;
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+        if (buf) {
+            model.buffers.emplace_back(buf);
+
+            size_t size_main = ggml_backend_buffer_get_size(buf);
+            WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
+        }
+    }
+
+    // load weights
+    {
+        size_t total_size = 0;
+        model.n_loaded = 0;
+        std::vector<char> read_buf;
+
+        while (true) {
+            int32_t n_dims;
+            int32_t length;
+            int32_t ttype;
+
+            read_safe(loader, n_dims);
+            read_safe(loader, length);
+            read_safe(loader, ttype);
+
+            if (loader->eof(loader->context)) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[4] = { 1, 1, 1, 1 };
+            for (int i = 0; i < n_dims; ++i) {
+                read_safe(loader, ne[i]);
+                nelements *= ne[i];
+            }
+
+            std::string name;
+            std::vector<char> tmp(length);
+            loader->read(loader->context, &tmp[0], tmp.size());
+            name.assign(&tmp[0], tmp.size());
+
+            if (model.tensors.find(name) == model.tensors.end()) {
+                WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                return nullptr;
+            }
+
+            auto tensor = model.tensors[name.data()];
+
+            if (ggml_nelements(tensor) != nelements) {
+                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
+                        __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
+                return nullptr;
+            }
+
+            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
+                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
+                        __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
+                return nullptr;
+            }
+
+            const size_t bpe = ggml_type_size(ggml_type(ttype));
+
+            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                        __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+                return nullptr;
+            }
+
+            if (ggml_backend_buffer_is_host(tensor->buffer)) {
+                // for the CPU and Metal backend, we can read directly into the tensor
+                loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
+                BYTESWAP_TENSOR(tensor);
+            } else {
+                // read into a temporary buffer first, then copy to device memory
+                read_buf.resize(ggml_nbytes(tensor));
+
+                loader->read(loader->context, read_buf.data(), read_buf.size());
+
+                ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
+            }
+
+            total_size += ggml_nbytes(tensor);
+            model.n_loaded++;
+        }
+
+        WHISPER_LOG_INFO("%s: model size    = %7.2f MB\n", __func__, total_size/1e6);
+
+        if (model.n_loaded == 0) {
+            WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
+        } else if (model.n_loaded != (int) model.tensors.size()) {
+            WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
+            return nullptr;
+        }
+
+    }
+
+    if (!whisper_vad_init_context(vctx)) {
+        whisper_vad_free(vctx);
+        return nullptr;
+    }
+
+    return vctx;
+}
+
+bool whisper_vad_detect_speech(
+        struct whisper_vad_context * vctx,
+        const float * samples,
+        int n_samples) {
+    int n_chunks = n_samples / vctx->n_window;
+    if (n_samples % vctx->n_window != 0) {
+        n_chunks += 1;  // Add one more chunk for remaining samples.
+    }
+
+    WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples);
+    WHISPER_LOG_INFO("%s: n_chunks: %d\n", __func__, n_chunks);
+
+    // Reset LSTM hidden/cell states
+    ggml_backend_buffer_clear(vctx->buffer, 0);
+
+    vctx->probs.resize(n_chunks);
+    WHISPER_LOG_INFO("%s: props size: %u\n", __func__, n_chunks);
+
+    std::vector<float> window(vctx->n_window, 0.0f);
+
+    auto & sched = vctx->sched.sched;
+
+    ggml_cgraph * gf = whisper_vad_build_graph(*vctx);
+
+    if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+        WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
+        return false;
+    }
+
+    struct ggml_tensor * frame = ggml_graph_get_tensor(gf, "frame");
+    struct ggml_tensor * prob  = ggml_graph_get_tensor(gf, "prob");
+
+    // we are going to reuse the graph multiple times for each chunk
+    const int64_t t_start_vad_us = ggml_time_us();
+
+    for (int i = 0; i < n_chunks; i++) {
+        const int idx_start = i * vctx->n_window;
+        const int idx_end = std::min(idx_start + vctx->n_window, n_samples);
+
+        const int chunk_len = idx_end - idx_start;
+
+        if (chunk_len < vctx->n_window) {
+            WHISPER_LOG_INFO("%s: chunk_len: %d < n_window: %d\n", __func__, chunk_len, vctx->n_window);
+            std::vector<float> partial_chunk(vctx->n_window, 0.0f);
+            std::copy(samples + idx_start, samples + idx_end, partial_chunk.begin());
+
+            // Copy the zero-padded chunk to the window.
+            const int samples_to_copy_max = vctx->n_window;
+            const int samples_to_copy_cur = std::min(samples_to_copy_max, (int)partial_chunk.size());
+            std::copy(partial_chunk.begin(), partial_chunk.begin() + samples_to_copy_cur, window.begin());
+            if (samples_to_copy_cur < samples_to_copy_max) {
+                std::fill(window.begin() + samples_to_copy_cur, window.end(), 0.0f);
+            }
+        } else {
+            // Copy current frame samples to the window.
+            const int samples_to_copy = std::min(idx_end - idx_start, vctx->n_window);
+            std::copy(samples + idx_start, samples + idx_start + samples_to_copy, window.begin());
+        }
+
+        // Set the frame tensor data with the samples.
+        ggml_backend_tensor_set(frame, window.data(), 0, ggml_nelements(frame) * sizeof(float));
+
+        // do not reset the scheduler - we will reuse the graph in the next chunk
+        if (!ggml_graph_compute_helper(sched, gf, vctx->n_threads, false)) {
+            WHISPER_LOG_ERROR("%s: failed to compute VAD graph\n", __func__);
+            break;
+        }
+
+        // Get the probability for this chunk.
+        ggml_backend_tensor_get(prob, &vctx->probs[i], 0, sizeof(float));
+
+        //WHISPER_LOG_DEBUG("chunk %d: p = %7.3f\n", i, probs[i]);
+    }
+
+    vctx->t_vad_us += ggml_time_us() - t_start_vad_us;
+    WHISPER_LOG_INFO("%s: vad time = %.2f ms processing %d samples\n", __func__, 1e-3f * vctx->t_vad_us, n_samples);
+
+    ggml_backend_sched_reset(sched);
+
+    return true;
+}
+
+int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments) {
+    return segments->data.size();
+}
+
+float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments * segments, int i_segment) {
+    return segments->data[i_segment].start;
+}
+
+float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments * segments, int i_segment) {
+    return segments->data[i_segment].end;
+}
+
+int whisper_vad_n_probs(struct whisper_vad_context * vctx) {
+    return vctx->probs.size();
+}
+
+float * whisper_vad_probs(struct whisper_vad_context * vctx) {
+    return vctx->probs.data();
+}
+
+struct whisper_vad_segments * whisper_vad_segments_from_probs(
+        struct whisper_vad_context *  vctx,
+                whisper_vad_params    params) {
+    WHISPER_LOG_INFO("%s: detecting speech timestamps using %d probabilities\n", __func__, whisper_vad_n_probs(vctx));
+
+    int     n_probs                 = whisper_vad_n_probs(vctx);
+    float * probs                   = whisper_vad_probs(vctx);
+    float   threshold               = params.threshold;
+    int     min_speech_duration_ms  = params.min_speech_duration_ms;
+    int     min_silence_duration_ms = params.min_silence_duration_ms;
+    float   max_speech_duration_s   = params.max_speech_duration_s;
+    int     speech_pad_ms           = params.speech_pad_ms;
+    int     n_window                = vctx->n_window;
+    int     sample_rate             = WHISPER_SAMPLE_RATE;
+    int     min_silence_samples     = sample_rate * min_silence_duration_ms / 1000;
+    int     audio_length_samples    = n_probs * n_window;
+
+    // Min number of samples to be considered valid speech.
+    int     min_speech_samples      = sample_rate * min_speech_duration_ms / 1000;
+    int     speech_pad_samples      = sample_rate * speech_pad_ms / 1000;
+
+    // Max number of samples that a speech segment can contain before it is
+    // split into multiple segments.
+    int max_speech_samples;
+    if (max_speech_duration_s > 100000.0f) {
+        max_speech_samples = INT_MAX / 2;
+    } else {
+        int64_t temp = (int64_t)sample_rate * (int64_t)(max_speech_duration_s) - n_window - 2 * speech_pad_samples;
+        max_speech_samples = (temp > INT_MAX) ? INT_MAX / 2 : (int)temp;
+        if (max_speech_samples < 0) {
+            max_speech_samples = INT_MAX / 2;
+        }
+    }
+    // Detect silence period that exceeds this value, then that location (sample)
+    // is marked as a potential place where the segment could be split if
+    // max_speech_samples is reached. The value 98 was taken from the original
+    // silaro-vad python implementation:
+    //https://github.com/snakers4/silero-vad/blob/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/utils_vad.py#L291
+    int min_silence_samples_at_max_speech = sample_rate * 98 / 1000;
+
+    // Calculate lower threshold for detecting end of speech segments.
+    float neg_threshold = threshold - 0.15f;
+    if (neg_threshold < 0.01f) {
+        neg_threshold = 0.01f;
+    }
+
+    struct speech_segment_t {
+        int start;
+        int end;
+    };
+
+    std::vector<speech_segment_t> speeches;
+    speeches.reserve(256);
+
+    bool is_speech_segment = false;
+    int  temp_end          = 0;
+    int  prev_end          = 0;
+    int  next_start        = 0;
+    int  curr_speech_start = 0;
+    bool has_curr_speech   = false;
+
+    for (int i = 0; i < n_probs; i++) {
+        float curr_prob   = probs[i];
+        int   curr_sample = n_window * i;
+
+        // Reset temp_end when we get back to speech
+        if ((curr_prob >= threshold) && temp_end) {
+            temp_end = 0;
+            if (next_start < prev_end) {
+                next_start = curr_sample;
+            }
+        }
+
+        // Start a new speech segment when probability exceeds threshold and not already in speech
+        if ((curr_prob >= threshold) && !is_speech_segment) {
+            is_speech_segment = true;
+            curr_speech_start = curr_sample;
+            has_curr_speech = true;
+            continue;
+        }
+
+        // Handle maximum speech duration
+        if (is_speech_segment && (curr_sample - curr_speech_start) > max_speech_samples) {
+            if (prev_end) {
+                speeches.push_back({ curr_speech_start, prev_end });
+                has_curr_speech = true;
+
+                if (next_start < prev_end) {  // Previously reached silence and is still not speech
+                    is_speech_segment = false;
+                    has_curr_speech = false;
+                } else {
+                    curr_speech_start = next_start;
+                }
+                prev_end = next_start = temp_end = 0;
+            } else {
+                speeches.push_back({ curr_speech_start, curr_sample });
+
+                prev_end = next_start = temp_end = 0;
+                is_speech_segment = false;
+                has_curr_speech = false;
+                continue;
+            }
+        }
+
+        // Handle silence after speech
+        if ((curr_prob < neg_threshold) && is_speech_segment) {
+            if (!temp_end) {
+                temp_end = curr_sample;
+            }
+
+            // Track potential segment ends for max_speech handling
+            if ((curr_sample - temp_end) > min_silence_samples_at_max_speech) {
+                prev_end = temp_end;
+            }
+
+            // Check if silence is long enough to end the segment
+            if ((curr_sample - temp_end) < min_silence_samples) {
+                continue;
+            } else {
+                // End the segment if it's long enough
+                if ((temp_end - curr_speech_start) > min_speech_samples) {
+                    speeches.push_back({ curr_speech_start, temp_end });
+                }
+
+                prev_end = next_start = temp_end = 0;
+                is_speech_segment = false;
+                has_curr_speech = false;
+                continue;
+            }
+        }
+    }
+
+    // Handle the case if we're still in a speech segment at the end
+    if (has_curr_speech && (audio_length_samples - curr_speech_start) > min_speech_samples) {
+        speeches.push_back({ curr_speech_start, audio_length_samples });
+    }
+
+    // Merge adjacent segments with small gaps in between (post-processing)
+    if (speeches.size() > 1) {
+        int merged_count = 0;
+        for (int i = 0; i < (int) speeches.size() - 1; i++) {
+            // Define maximum gap allowed for merging (e.g., 200ms converted to samples)
+            int max_merge_gap_samples = sample_rate * 200 / 1000;
+
+            // If the gap between this segment and the next is small enough
+            if (speeches[i+1].start - speeches[i].end < max_merge_gap_samples) {
+                // Merge by extending current segment to the end of next segment
+                speeches[i].end = speeches[i+1].end;
+                speeches.erase(speeches.begin() + i + 1);
+
+                i--;
+                merged_count++;
+            }
+        }
+        WHISPER_LOG_INFO("%s: Merged %d adjacent segments, now have %d segments\n",
+                         __func__, merged_count, (int) speeches.size());
+    }
+
+    // Double-check for minimum speech duration
+    for (int i = 0; i < (int) speeches.size(); i++) {
+        if (speeches[i].end - speeches[i].start < min_speech_samples) {
+            WHISPER_LOG_INFO("%s: Removing segment %d (too short: %d samples)\n",
+                            __func__, i, speeches[i].end - speeches[i].start);
+
+            speeches.erase(speeches.begin() + i);
+            i--;
+        }
+    }
+
+    WHISPER_LOG_INFO("%s: Final speech segments after filtering: %d\n", __func__, (int) speeches.size());
+
+    // Allocate final segments
+    std::vector<whisper_vad_segment> segments;
+    if (speeches.size() > 0) {
+        try {
+            segments.resize(speeches.size());
+        } catch (const std::bad_alloc &) {
+            WHISPER_LOG_ERROR("%s: failed to allocate memory for final segments\n", __func__);
+            return nullptr;
+        }
+    }
+
+    // Apply padding to segments and copy to final segments
+    for (int i = 0; i < (int) speeches.size(); i++) {
+        // Apply padding to the start of the first segment
+        if (i == 0) {
+            speeches[i].start =
+                (speeches[i].start > speech_pad_samples) ?
+                (speeches[i].start - speech_pad_samples) : 0;
+        }
+
+        // Handle spacing between segments
+        if (i < (int) speeches.size() - 1) {
+            int silence_duration = speeches[i+1].start - speeches[i].end;
+
+            if (silence_duration < 2 * speech_pad_samples) {
+                // If segments are close, split the difference
+                speeches[i].end += silence_duration / 2;
+                speeches[i+1].start =
+                    (speeches[i+1].start > silence_duration / 2) ?
+                    (speeches[i+1].start - silence_duration / 2) : 0;
+            } else {
+                // Otherwise, apply full padding to both
+                speeches[i].end =
+                    (speeches[i].end + speech_pad_samples < audio_length_samples) ?
+                    (speeches[i].end + speech_pad_samples) : audio_length_samples;
+                speeches[i+1].start =
+                    (speeches[i+1].start > speech_pad_samples) ?
+                    (speeches[i+1].start - speech_pad_samples) : 0;
+            }
+        } else {
+            // Apply padding to the end of the last segment
+            speeches[i].end =
+                (speeches[i].end + speech_pad_samples < audio_length_samples) ?
+                (speeches[i].end + speech_pad_samples) : audio_length_samples;
+        }
+
+        // Convert from samples to seconds and copy to final segments
+        segments[i].start = (float)speeches[i].start / sample_rate;
+        segments[i].end   = (float)speeches[i].end   / sample_rate;
+
+        WHISPER_LOG_INFO("%s: VAD segment %d: start = %.2f, end = %.2f (duration: %.2f)\n",
+                        __func__, i, segments[i].start, segments[i].end, segments[i].end - segments[i].start);
+    }
+
+    whisper_vad_segments * vad_segments = new whisper_vad_segments;
+    if (vad_segments == NULL) {
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for whisper_vad_segments\n", __func__);
+        return nullptr;
+    }
+
+    vad_segments->data = std::move(segments);
+
+    return vad_segments;
+}
+
+struct whisper_vad_segments * whisper_vad_segments_from_samples(
+        whisper_vad_context * vctx,
+        whisper_vad_params params,
+        const float * samples,
+        int n_samples) {
+    WHISPER_LOG_INFO("%s: detecting speech timestamps in %d samples\n", __func__, n_samples);
+    if (!whisper_vad_detect_speech(vctx, samples, n_samples)) {
+        WHISPER_LOG_ERROR("%s: failed to detect speech\n", __func__);
+        return nullptr;
+    }
+    return whisper_vad_segments_from_probs(vctx, params);
+}
+
+void whisper_vad_free(whisper_vad_context * ctx) {
+    if (ctx) {
+        for (ggml_context * context : ctx->model.ctxs) {
+            ggml_free(context);
+        }
+
+        for (ggml_backend_buffer_t buf : ctx->model.buffers) {
+            ggml_backend_buffer_free(buf);
+        }
+
+        ggml_backend_sched_free(ctx->sched.sched);
+
+        for (auto & backend : ctx->backends) {
+            ggml_backend_free(backend);
+        }
+
+
+        delete ctx;
+    }
+}
+
+void whisper_vad_free_segments(whisper_vad_segments * segments) {
+    if (segments) {
+        delete segments;
+    }
+}
+
+//////////////////////////////////
+// Grammar - ported from llama.cpp
+//////////////////////////////////
+
+// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
+// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
+static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
+        const char         * src,
+        whisper_partial_utf8   partial_start) {
+    static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
+    const char          * pos      = src;
+    std::vector<uint32_t> code_points;
+    uint32_t              value    = partial_start.value;
+    int                   n_remain = partial_start.n_remain;
+
+    // continue previous decode, if applicable
+    while (*pos != 0 && n_remain > 0) {
+        uint8_t next_byte = static_cast<uint8_t>(*pos);
+        if ((next_byte >> 6) != 2) {
+            // invalid sequence, abort
+            code_points.push_back(0);
+            return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
+        }
+        value = (value << 6) + (next_byte & 0x3F);
+        ++pos;
+        --n_remain;
+    }
+
+    if (partial_start.n_remain > 0 && n_remain == 0) {
+        code_points.push_back(value);
+    }
+
+    // decode any subsequent utf-8 sequences, which may end in an incomplete one
+    while (*pos != 0) {
+        uint8_t  first_byte = static_cast<uint8_t>(*pos);
+        uint8_t  highbits   = first_byte >> 4;
+                 n_remain   = lookup[highbits] - 1;
+
+        if (n_remain < 0) {
+            // invalid sequence, abort
+            code_points.clear();
+            code_points.push_back(0);
+            return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
+        }
+
+        uint8_t  mask       = (1 << (7 - n_remain)) - 1;
+                 value      = first_byte & mask;
+        ++pos;
+        while (*pos != 0 && n_remain > 0) {
+            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+            ++pos;
+            --n_remain;
+        }
+        if (n_remain == 0) {
+            code_points.push_back(value);
+        }
+    }
+    code_points.push_back(0);
+
+    return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
+}
+
+// returns true iff pos points to the end of one of the definitions of a rule
+static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
+    switch (pos->type) {
+        case WHISPER_GRETYPE_END: return true;  // NOLINT
+        case WHISPER_GRETYPE_ALT: return true;  // NOLINT
+        default:                return false;
+    }
+}
+
+// returns true iff chr satisfies the char range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
+        const whisper_grammar_element * pos,
+        const uint32_t                chr) {
+
+    bool found            = false;
+    bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
+
+    WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
+
+    do {
+        if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
+            // inclusive range, e.g. [a-z]
+            found = found || (pos->value <= chr && chr <= pos[1].value);
+            pos += 2;
+        } else {
+            // exact char match, e.g. [a] or "a"
+            found = found || pos->value == chr;
+            pos += 1;
+        }
+    } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+
+    return std::make_pair(found == is_positive_char, pos);
+}
+
+// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
+// range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static bool whisper_grammar_match_partial_char(
+        const whisper_grammar_element * pos,
+        const whisper_partial_utf8      partial_utf8) {
+
+    bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
+    WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT);
+
+    uint32_t partial_value = partial_utf8.value;
+    int      n_remain      = partial_utf8.n_remain;
+
+    // invalid sequence or 7-bit char split across 2 bytes (overlong)
+    if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
+        return false;
+    }
+
+    // range of possible code points this partial UTF-8 sequence could complete to
+    uint32_t low  = partial_value << (n_remain * 6);
+    uint32_t high = low | ((1 << (n_remain * 6)) - 1);
+
+    if (low == 0) {
+        if (n_remain == 2) {
+            low = 1 << 11;
+        } else if (n_remain == 3) {
+            low = 1 << 16;
+        }
+    }
+
+    do {
+        if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
+            // inclusive range, e.g. [a-z]
+            if (pos->value <= high && low <= pos[1].value) {
+                return is_positive_char;
+            }
+            pos += 2;
+        } else {
+            // exact char match, e.g. [a] or "a"
+            if (low <= pos->value && pos->value <= high) {
+                return is_positive_char;
+            }
+            pos += 1;
+        }
+    } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+
+    return !is_positive_char;
+}
+
+
+// transforms a grammar pushdown stack into N possible stacks, all ending
+// at a character range (terminal element)
+static void whisper_grammar_advance_stack(
+        const std::vector<std::vector<whisper_grammar_element>>   & rules,
+        const std::vector<const whisper_grammar_element *>        & stack,
+        std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
+
+    if (stack.empty()) {
+        new_stacks.emplace_back();
+        return;
+    }
+
+    const whisper_grammar_element * pos = stack.back();
+
+    switch (pos->type) {
+        case WHISPER_GRETYPE_RULE_REF: {
+            const size_t                  rule_id = static_cast<size_t>(pos->value);
+            const whisper_grammar_element * subpos  = rules[rule_id].data();
+            do {
+                // init new stack without the top (pos)
+                std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
+                if (!whisper_grammar_is_end_of_sequence(pos + 1)) {
+                    // if this rule ref is followed by another element, add that to stack
+                    new_stack.push_back(pos + 1);
+                }
+                if (!whisper_grammar_is_end_of_sequence(subpos)) {
+                    // if alternate is nonempty, add to stack
+                    new_stack.push_back(subpos);
+                }
+                whisper_grammar_advance_stack(rules, new_stack, new_stacks);
+                while (!whisper_grammar_is_end_of_sequence(subpos)) {
+                    // scan to end of alternate def
+                    subpos++;
+                }
+                if (subpos->type == WHISPER_GRETYPE_ALT) {
+                    // there's another alternate def of this rule to process
+                    subpos++;
+                } else {
+                    break;
+                }
+            } while (true);
+            break;
+        }
+        case WHISPER_GRETYPE_CHAR:
+        case WHISPER_GRETYPE_CHAR_NOT:
+            new_stacks.push_back(stack);
+            break;
+        default:
+            // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range
+            // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
+            // those
+            WHISPER_ASSERT(false);
+    }
+}
+
+// takes a set of possible pushdown stacks on a grammar, which are required to
+// be positioned at a character range (see `whisper_grammar_advance_stack`), and
+// produces the N possible stacks if the given char is accepted at those
+// positions
+static std::vector<std::vector<const whisper_grammar_element *>> whisper_grammar_accept(
+        const std::vector<std::vector<whisper_grammar_element>>         & rules,
+        const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+        const uint32_t                                                  chr) {
+
+    std::vector<std::vector<const whisper_grammar_element *>> new_stacks;
+
+    for (const auto & stack : stacks) {
+        if (stack.empty()) {
+            continue;
+        }
+
+        auto match = whisper_grammar_match_char(stack.back(), chr);
+        if (match.first) {
             const whisper_grammar_element * pos = match.second;
 
             // update top of stack to next element, if any
@@ -4856,6 +5981,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.n_grammar_rules =*/ 0,
         /*.i_start_rule    =*/ 0,
         /*.grammar_penalty =*/ 100.0f,
+
+        /*.vad                         =*/ false,
+        /*.vad_model_path              =*/ nullptr,
+
+        /* vad_params =*/ whisper_vad_default_params(),
     };
 
     switch (strategy) {
@@ -5472,6 +6602,117 @@ static void whisper_sequence_score(
     }
 }
 
+static bool whisper_vad(
+        struct whisper_context * ctx,
+          struct whisper_state * state,
+    struct whisper_full_params   params,
+                   const float * samples,
+                           int   n_samples,
+            std::vector<float> & filtered_samples,
+                           int & filtered_n_samples) {
+    WHISPER_LOG_INFO("%s: VAD is enabled, processing speach segments only\n", __func__);
+    filtered_n_samples = 0;
+
+    struct whisper_vad_context_params vad_ctx_params = whisper_vad_default_context_params();
+    struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(params.vad_model_path, vad_ctx_params);
+    if (vctx == nullptr) {
+        WHISPER_LOG_ERROR("%s: failed to initialize VAD context\n", __func__);
+        return false;
+    }
+
+    const whisper_vad_params & vad_params = params.vad_params;
+
+    whisper_vad_segments * vad_segments = whisper_vad_segments_from_samples(vctx, vad_params, samples, n_samples);
+
+    if (vad_segments->data.size() > 0) {
+        state->has_vad_segments = true;
+        ctx->state->vad_segments.clear();
+        ctx->state->vad_segments.reserve(vad_segments->data.size());
+
+        WHISPER_LOG_INFO("%s: detected %d speech segments\n", __func__, (int)vad_segments->data.size());
+        float overlap_seconds = vad_params.samples_overlap;
+        int overlap_samples = overlap_seconds * WHISPER_SAMPLE_RATE;
+
+        for (int i = 0; i < (int)vad_segments->data.size(); i++) {
+            int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE;
+            int segment_end_samples   = vad_segments->data[i].end   * WHISPER_SAMPLE_RATE;
+
+            if (i < (int)vad_segments->data.size() - 1) {
+                segment_end_samples += overlap_samples;
+            }
+            segment_end_samples = std::min(segment_end_samples, n_samples - 1);
+            filtered_n_samples  += (segment_end_samples - segment_start_samples);
+
+            WHISPER_LOG_INFO("%s: Including segment %d: %.2f - %.2f (duration: %.2f)\n",
+                __func__, i, vad_segments->data[i].start,
+                vad_segments->data[i].end + (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0),
+                (vad_segments->data[i].end - vad_segments->data[i].start) +
+                (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0));
+        }
+
+        int silence_samples = 0.1 * WHISPER_SAMPLE_RATE;
+        int total_silence_samples = (vad_segments->data.size() > 1) ? (vad_segments->data.size() - 1) * silence_samples : 0;
+        int total_samples_needed = filtered_n_samples + total_silence_samples;
+
+        WHISPER_LOG_INFO("%s: total duration of speech segments: %.2f seconds\n",
+                        __func__, (float)filtered_n_samples / WHISPER_SAMPLE_RATE);
+
+        try {
+            filtered_samples.resize(total_samples_needed);
+        } catch (const std::bad_alloc & /* e */) {
+            WHISPER_LOG_ERROR("%s: failed to allocate memory for filtered samples\n", __func__);
+            whisper_vad_free_segments(vad_segments);
+            whisper_vad_free(vctx);
+            return false;
+        }
+
+        int offset = 0;
+        for (int i = 0; i < (int)vad_segments->data.size(); i++) {
+            int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE;
+            int segment_end_samples   = vad_segments->data[i].end   * WHISPER_SAMPLE_RATE;
+
+            if (i < (int)vad_segments->data.size() - 1) {
+                segment_end_samples += overlap_samples;
+            }
+
+            segment_start_samples = std::min(segment_start_samples, n_samples - 1);
+            segment_end_samples = std::min(segment_end_samples, n_samples);
+            int segment_length = segment_end_samples - segment_start_samples;
+
+            if (segment_length > 0) {
+                whisper_state::vad_segment_info segment;
+
+                segment.orig_start = vad_segments->data[i].start;
+                segment.orig_end   = vad_segments->data[i].end;
+
+                segment.vad_start = offset / (float)WHISPER_SAMPLE_RATE;
+                segment.vad_end   = (offset + segment_length) / (float)WHISPER_SAMPLE_RATE;
+
+                WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n",
+                    __func__, segment.orig_start, segment.orig_end, segment.vad_start, segment.vad_end);
+                ctx->state->vad_segments.push_back(segment);
+
+                // Copy this speech segment
+                memcpy(filtered_samples.data() + offset, samples + segment_start_samples, segment_length * sizeof(float));
+                offset += segment_length;
+
+                // Add silence after this segment (except after the last segment)
+                if (i < (int)vad_segments->data.size() - 1) {
+                    // Fill with zeros (silence)
+                    memset(filtered_samples.data() + offset, 0, silence_samples * sizeof(float));
+                    offset += silence_samples;
+                }
+            }
+        }
+
+        filtered_n_samples = offset;
+        WHISPER_LOG_INFO("%s: Reduced audio from %d to %d samples (%.1f%% reduction)\n",
+                        __func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples));
+    }
+
+    return true;
+}
+
 int whisper_full_with_state(
         struct whisper_context * ctx,
           struct whisper_state * state,
@@ -5483,9 +6724,24 @@ int whisper_full_with_state(
 
     result_all.clear();
 
-    if (n_samples > 0) {
+    const float * process_samples = samples;
+    int n_process_samples = n_samples;
+    std::vector<float> vad_samples;
+
+    if (params.vad) {
+        WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
+        int vad_n_samples;
+        if (!whisper_vad(ctx, state, params, samples, n_samples, vad_samples, vad_n_samples)) {
+            WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
+            return -1;
+        }
+        process_samples = vad_samples.data();
+        n_process_samples = vad_n_samples;
+    }
+
+    if (n_process_samples > 0) {
         // compute log mel spectrogram
-        if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
+        if (whisper_pcm_to_mel_with_state(ctx, state, process_samples, n_process_samples, params.n_threads) != 0) {
             WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
             return -2;
         }
@@ -6530,19 +7786,133 @@ int whisper_full_lang_id(struct whisper_context * ctx) {
 }
 
 int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
-    return state->result_all[i_segment].t0;
+    // If VAD wasn't used, return the original timestamp
+    if (!state->has_vad_segments || state->vad_segments.empty()) {
+        return state->result_all[i_segment].t0;
+    }
+
+    // Get the start timestamp produced by whisper_full. whisper_full processes
+    // only the speech segments in this case so we need to map these timestamps
+    // back to the original audio.
+    float t0 = state->result_all[i_segment].t0 / 100.0f;
+
+    // Find which VAD segment this timestamp belongs.
+    // TODO(danbev) This could be optimized by using a binary search if the number
+    // of segments exceed a certain limit. Also we might be able to assume that
+    // the access pattern is sequential and optimized for that too.
+    for (size_t i = 0; i < state->vad_segments.size(); i++) {
+        const auto & segment = state->vad_segments[i];
+
+        // Check if the timestamp falls within this segment.
+        if (t0 >= segment.vad_start && t0 <= segment.vad_end) {
+            float proportion = 0.0f;
+            if (segment.vad_end > segment.vad_start) {
+                proportion = (t0 - segment.vad_start) / (segment.vad_end - segment.vad_start);
+            }
+            float orig_t0 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
+            return (int64_t)(orig_t0 * 100);
+        }
+    }
+
+    // Check if the timestamp falls between two segments.
+    for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
+        const auto & curr = state->vad_segments[i];
+        const auto & next = state->vad_segments[i + 1];
+
+        if (t0 > curr.vad_end && t0 < next.vad_start) {
+            // Calculate how far we are through the gap as a proportion
+            float gap_proportion = 0.0f;
+            if (next.vad_start > curr.vad_end) {
+                gap_proportion = (t0 - curr.vad_end) / (next.vad_start - curr.vad_end);
+            }
+            float orig_t0 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
+            return (int64_t)(orig_t0 * 100);
+        }
+    }
+
+    // Handle the case where the timestamp is after the last segment.
+    if (t0 > state->vad_segments.back().vad_end) {
+        // For timestamps after the last segment, add the extra time to the end of the last segment
+        const auto& last = state->vad_segments.back();
+        // Calculate how far beyond the last segment
+        float extra_time = t0 - last.vad_end;
+        // Add this extra time to the original end time
+        float orig_t0 = last.orig_end + extra_time;
+        return (int64_t)(orig_t0 * 100);
+    }
+
+    WHISPER_LOG_WARN("%s: Could not map t0 = %f to a VAD segment\n", __func__, t0);
+    return t0;
 }
 
 int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
-    return ctx->state->result_all[i_segment].t0;
+    return whisper_full_get_segment_t0_from_state(ctx->state, i_segment);
 }
 
 int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
-    return state->result_all[i_segment].t1;
+    // If VAD wasn't used, return the original timestamp
+    if (!state->has_vad_segments || state->vad_segments.empty()) {
+        return state->result_all[i_segment].t1;
+    }
+
+    // Get the end timestamp produced by whisper_full. whisper_full processes
+    // only the speech segments in this case so we need to map these timestamps
+    // back to the original audio.
+    float t1 = state->result_all[i_segment].t1 / 100.0f;
+
+    // Find which VAD segment this timestamp belongs.
+    // TODO(danbev) This could be optimized by using a binary search if the number
+    // of segments exceed a certain limit. Also we might be able to assume that
+    // the access pattern is sequential and optimized for that too.
+    for (size_t i = 0; i < state->vad_segments.size(); i++) {
+        const auto& segment = state->vad_segments[i];
+
+        // Check if the timestamp falls within this segment.
+        if (t1 >= segment.vad_start && t1 <= segment.vad_end) {
+            // Calculate the proportion through the filtered segment.
+            float proportion = 0.0f;
+            if (segment.vad_end > segment.vad_start) {
+                proportion = (t1 - segment.vad_start) / (segment.vad_end - segment.vad_start);
+            }
+            float orig_t1 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
+            return (int64_t)(orig_t1 * 100);
+        }
+    }
+
+    // Check if the timestamp falls between two segments.
+    for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
+        const auto & curr = state->vad_segments[i];
+        const auto & next = state->vad_segments[i + 1];
+
+        if (t1 > curr.vad_end && t1 < next.vad_start) {
+            // Calculate how far we are through the gap as a proportion
+            float gap_proportion = 0.0f;
+            if (next.vad_start > curr.vad_end) {
+                gap_proportion = (t1 - curr.vad_end) / (next.vad_start - curr.vad_end);
+            }
+            // Map to the corresponding position in the original gap
+            float orig_t1 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
+            return (int64_t)(orig_t1 * 100);
+        }
+    }
+
+    // Handle the case where the timestamp is after the last segment
+    if (t1 > state->vad_segments.back().vad_end) {
+        // For the last segment, use the end of the last VAD segment
+        const auto& last = state->vad_segments.back();
+        // Calculate how far beyond the last segment
+        float extra_time = t1 - last.vad_end;
+        // Add this extra time to the original end time
+        float orig_t1 = last.orig_end + extra_time;
+        return (int64_t)(orig_t1 * 100);
+    }
+
+    WHISPER_LOG_WARN("%s: Could not map t1 = %f to a VAD segment\n", __func__, t1);
+    return t1;
 }
 
 int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
-    return ctx->state->result_all[i_segment].t1;
+    return whisper_full_get_segment_t1_from_state(ctx->state, i_segment);
 }
 
 bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
index 7cdfed8228280a1257ad0b77a244b07c87ef197e..efa1bbe3fc8447c6881725ba24de4ed709b52c10 100644 (file)
@@ -1,3 +1,6 @@
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
 if (EMSCRIPTEN)
     #
     # test-whisper-js
@@ -85,3 +88,18 @@ if (WHISPER_FFMPEG)
     set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "tiny;mp3")
 endif()
 
+# VAD test tests VAD in isolation
+set(VAD_TEST test-vad)
+add_executable(${VAD_TEST} ${VAD_TEST}.cpp)
+target_include_directories(${VAD_TEST} PRIVATE ../include ../ggml/include ../examples)
+target_link_libraries(${VAD_TEST} PRIVATE common)
+add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST})
+set_tests_properties(${VAD_TEST} PROPERTIES LABELS "unit")
+
+# VAD test full uses whisper_full with VAD enabled
+set(VAD_TEST test-vad-full)
+add_executable(${VAD_TEST} ${VAD_TEST}.cpp)
+target_include_directories(${VAD_TEST} PRIVATE ../include ../ggml/include ../examples)
+target_link_libraries(${VAD_TEST} PRIVATE common)
+add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST})
+set_tests_properties(${VAD_TARGET} PROPERTIES LABELS "base;en")
diff --git a/tests/test-vad-full.cpp b/tests/test-vad-full.cpp
new file mode 100644 (file)
index 0000000..9eac11e
--- /dev/null
@@ -0,0 +1,54 @@
+#include "whisper.h"
+#include "common-whisper.h"
+
+#include <cstdio>
+#include <cfloat>
+#include <string>
+#include <cstring>
+
+#ifdef NDEBUG
+#undef NDEBUG
+#endif
+
+#include <cassert>
+
+int main() {
+    std::string whisper_model_path = "../../models/ggml-base.en.bin";
+    std::string vad_model_path     = "../../models/for-tests-silero-v5.1.2-ggml.bin";
+    std::string sample_path        = "../../samples/jfk.wav";
+
+    // Load the sample audio file
+    std::vector<float> pcmf32;
+    std::vector<std::vector<float>> pcmf32s;
+    assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false));
+
+    struct whisper_context_params cparams = whisper_context_default_params();
+    struct whisper_context * wctx = whisper_init_from_file_with_params(
+            whisper_model_path.c_str(),
+            cparams);
+
+    struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
+    wparams.vad            = true;
+    wparams.vad_model_path = vad_model_path.c_str();
+
+    wparams.vad_params.threshold               = 0.5f;
+    wparams.vad_params.min_speech_duration_ms  = 250;
+    wparams.vad_params.min_silence_duration_ms = 100;
+    wparams.vad_params.max_speech_duration_s   = FLT_MAX;
+    wparams.vad_params.speech_pad_ms           = 30;
+
+    assert(whisper_full_parallel(wctx, wparams, pcmf32.data(), pcmf32.size(), 1) == 0);
+
+    const int n_segments = whisper_full_n_segments(wctx);
+    assert(n_segments == 1);
+
+    assert(strcmp(" And so my fellow Americans, ask not what your country can do for you,"
+                  " ask what you can do for your country.",
+           whisper_full_get_segment_text(wctx, 0)) == 0);
+    assert(whisper_full_get_segment_t0(wctx, 0) == 29);
+    assert(whisper_full_get_segment_t1(wctx, 0) == 1049);
+
+    whisper_free(wctx);
+
+    return 0;
+}
diff --git a/tests/test-vad.cpp b/tests/test-vad.cpp
new file mode 100644 (file)
index 0000000..e6886e3
--- /dev/null
@@ -0,0 +1,83 @@
+#include "whisper.h"
+#include "common-whisper.h"
+
+#include <cstdio>
+#include <string>
+
+#ifdef NDEBUG
+#undef NDEBUG
+#endif
+#include <cassert>
+
+void assert_default_params(const struct whisper_vad_params & params) {
+    assert(params.threshold == 0.5);
+    assert(params.min_speech_duration_ms == 250);
+    assert(params.min_silence_duration_ms == 100);
+    assert(params.samples_overlap == 0.1f);
+}
+
+void assert_default_context_params(const struct whisper_vad_context_params & params) {
+    assert(params.n_threads == 4);
+    assert(params.use_gpu == false);
+    assert(params.gpu_device == 0);
+}
+
+void test_detect_speech(
+        struct whisper_vad_context * vctx,
+        struct whisper_vad_params params,
+        const float * pcmf32,
+        int n_samples) {
+    assert(whisper_vad_detect_speech(vctx, pcmf32, n_samples));
+    assert(whisper_vad_n_probs(vctx) == 344);
+    assert(whisper_vad_probs(vctx) != nullptr);
+}
+
+struct whisper_vad_segments * test_detect_timestamps(
+        struct whisper_vad_context * vctx,
+        struct whisper_vad_params params) {
+    struct whisper_vad_segments * timestamps = whisper_vad_segments_from_probs(vctx, params);
+    assert(whisper_vad_segments_n_segments(timestamps) == 5);
+
+    for (int i = 0; i < whisper_vad_segments_n_segments(timestamps); ++i) {
+        printf("VAD segment %d: start = %.2f, end = %.2f\n", i,
+               whisper_vad_segments_get_segment_t0(timestamps, i),
+               whisper_vad_segments_get_segment_t1(timestamps, i));
+    }
+
+    return timestamps;
+}
+
+int main() {
+    std::string vad_model_path = "../../models/for-tests-silero-v5.1.2-ggml.bin";
+    std::string sample_path    = "../../samples/jfk.wav";
+
+    // Load the sample audio file
+    std::vector<float> pcmf32;
+    std::vector<std::vector<float>> pcmf32s;
+    assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false));
+    assert(pcmf32.size() > 0);
+    assert(pcmf32s.size() == 0); // no stereo vector
+
+    // Load the VAD model
+    struct whisper_vad_context_params ctx_params = whisper_vad_default_context_params();
+    assert_default_context_params(ctx_params);
+
+    struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(
+            vad_model_path.c_str(),
+            ctx_params);
+    assert(vctx != nullptr);
+
+    struct whisper_vad_params params = whisper_vad_default_params();
+    assert_default_params(params);
+
+    // Test speech probabilites
+    test_detect_speech(vctx, params, pcmf32.data(), pcmf32.size());
+
+    // Test speech timestamps (uses speech probabilities from above)
+    struct whisper_vad_segments * timestamps = test_detect_timestamps(vctx, params);
+
+    whisper_vad_free_segments(timestamps);
+    whisper_vad_free(vctx);
+
+    return 0;
+}