#include <atomic>
#include <algorithm>
#include <cassert>
+#include <cfloat>
#define _USE_MATH_DEFINES
#include <cmath>
#include <climits>
int n_threads,
ggml_abort_callback abort_callback,
void * abort_callback_data) {
-
ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
static bool ggml_graph_compute_helper(
ggml_backend_sched_t sched,
struct ggml_cgraph * graph,
- int n_threads) {
-
+ int n_threads,
+ bool sched_reset = true) {
for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
ggml_backend_dev_t dev = ggml_backend_get_device(backend);
}
}
- bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
- ggml_backend_sched_reset(sched);
+ const bool t = (ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS);
+
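+ // Resetting the scheduler releases its graph allocation. Callers that reuse
+ // the same graph across computations (e.g. the VAD chunk loop below) pass
+ // sched_reset = false and reset the scheduler themselves once they are done.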
+ if (!t || sched_reset) {
+ ggml_backend_sched_reset(sched);
+ }
+
return t;
}
// [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx = 0; // 0 - use default
+
+ struct vad_segment_info {
+ float orig_start;
+ float orig_end;
+ float vad_start;
+ float vad_end;
+ };
+ std::vector<vad_segment_info> vad_segments;
+ bool has_vad_segments = false;
};
struct whisper_context {
}
//////////////////////////////////
-// Grammar - ported from llama.cpp
+// Voice Activity Detection (VAD)
//////////////////////////////////
-// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
-// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
-static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
- const char * src,
- whisper_partial_utf8 partial_start) {
- static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
- const char * pos = src;
- std::vector<uint32_t> code_points;
- uint32_t value = partial_start.value;
- int n_remain = partial_start.n_remain;
+struct whisper_vad_hparams {
+ int32_t n_encoder_layers;
+ int32_t * encoder_in_channels;
+ int32_t * encoder_out_channels;
+ int32_t * kernel_sizes;
+ int32_t lstm_input_size;
+ int32_t lstm_hidden_size;
+ int32_t final_conv_in;
+ int32_t final_conv_out;
+};
- // continue previous decode, if applicable
- while (*pos != 0 && n_remain > 0) {
- uint8_t next_byte = static_cast<uint8_t>(*pos);
- if ((next_byte >> 6) != 2) {
- // invalid sequence, abort
- code_points.push_back(0);
- return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
- }
- value = (value << 6) + (next_byte & 0x3F);
- ++pos;
- --n_remain;
- }
+struct whisper_vad_model {
+ std::string type;
+ std::string version;
+ whisper_vad_hparams hparams;
- if (partial_start.n_remain > 0 && n_remain == 0) {
- code_points.push_back(value);
- }
+ struct ggml_tensor * stft_forward_basis; // [256, 1, 258]
- // decode any subsequent utf-8 sequences, which may end in an incomplete one
- while (*pos != 0) {
- uint8_t first_byte = static_cast<uint8_t>(*pos);
- uint8_t highbits = first_byte >> 4;
- n_remain = lookup[highbits] - 1;
+ // Encoder tensors - 4 convolutional layers
+ struct ggml_tensor * encoder_0_weight; // [3, 129, 128]
+ struct ggml_tensor * encoder_0_bias; // [128]
- if (n_remain < 0) {
- // invalid sequence, abort
- code_points.clear();
- code_points.push_back(0);
- return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
- }
+ // Second encoder layer
+ struct ggml_tensor * encoder_1_weight; // [3, 128, 64]
+ struct ggml_tensor * encoder_1_bias; // [64]
- uint8_t mask = (1 << (7 - n_remain)) - 1;
- value = first_byte & mask;
- ++pos;
- while (*pos != 0 && n_remain > 0) {
- value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
- ++pos;
- --n_remain;
- }
- if (n_remain == 0) {
- code_points.push_back(value);
- }
- }
- code_points.push_back(0);
+ // Third encoder layer
+ struct ggml_tensor * encoder_2_weight; // [3, 64, 64]
+ struct ggml_tensor * encoder_2_bias; // [64]
- return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
-}
+ // Fourth encoder layer
+ struct ggml_tensor * encoder_3_weight; // [3, 64, 128]
+ struct ggml_tensor * encoder_3_bias; // [128]
-// returns true iff pos points to the end of one of the definitions of a rule
-static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
- switch (pos->type) {
- case WHISPER_GRETYPE_END: return true; // NOLINT
- case WHISPER_GRETYPE_ALT: return true; // NOLINT
- default: return false;
- }
-}
+ // LSTM decoder tensors
+ struct ggml_tensor * lstm_ih_weight; // [128, 512] input-to-hidden
+ struct ggml_tensor * lstm_ih_bias; // [512]
+ struct ggml_tensor * lstm_hh_weight; // [128, 512] hidden-to-hidden
+ struct ggml_tensor * lstm_hh_bias; // [512]
-// returns true iff chr satisfies the char range at pos (regular or inverse range)
-// asserts that pos is pointing to a char range element
-static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
- const whisper_grammar_element * pos,
- const uint32_t chr) {
+ // Final conv layer
+ struct ggml_tensor * final_conv_weight; // [128, 1]
+ struct ggml_tensor * final_conv_bias; // [1]
- bool found = false;
- bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
+ // ggml contexts
+ std::vector<ggml_context *> ctxs;
- WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
+ // buffer for the model tensors
+ std::vector<ggml_backend_buffer_t> buffers;
- do {
- if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
- // inclusive range, e.g. [a-z]
- found = found || (pos->value <= chr && chr <= pos[1].value);
- pos += 2;
- } else {
- // exact char match, e.g. [a] or "a"
- found = found || pos->value == chr;
- pos += 1;
- }
- } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+ // tensors
+ int n_loaded;
+ std::map<std::string, struct ggml_tensor *> tensors;
+};
- return std::make_pair(found == is_positive_char, pos);
+struct whisper_vad_segment {
+ float start; // Start time in seconds
+ float end; // End time in seconds
+};
+
+struct whisper_vad_segments {
+ std::vector<whisper_vad_segment> data;
+};
+
+struct whisper_vad_context {
+ int64_t t_vad_us = 0;
+
+ int n_window;
+ int n_context;
+ int n_threads;
+
+ std::vector<ggml_backend_t> backends;
+ ggml_backend_buffer_t buffer = nullptr;
+ whisper_context_params params;
+ std::vector<uint8_t> ctx_buf;
+ whisper_sched sched;
+
+ whisper_vad_model model;
+ std::string path_model;
+ struct ggml_tensor * h_state;
+ struct ggml_tensor * c_state;
+ std::vector<float> probs;
+};
+
+struct whisper_vad_context_params whisper_vad_default_context_params(void) {
+ whisper_vad_context_params result = {
+ /*.n_threads = */ 4,
+ /*.use_gpu = */ false,
+ /*.gpu_device = */ 0,
+ };
+ return result;
}
-// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
-// range at pos (regular or inverse range)
-// asserts that pos is pointing to a char range element
-static bool whisper_grammar_match_partial_char(
- const whisper_grammar_element * pos,
- const whisper_partial_utf8 partial_utf8) {
+struct whisper_vad_params whisper_vad_default_params(void) {
+ whisper_vad_params result = {
+ /* threshold = */ 0.5f,
+ /* min_speech_duration_ms = */ 250,
+ /* min_silence_duration_ms = */ 100,
+ /* max_speech_duration_s = */ FLT_MAX,
+ /* speech_pad_ms = */ 30,
+ /* samples_overlap = */ 0.1f,
+ };
+ return result;
+}
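+
+// Illustrative usage (not part of this patch): callers can start from the
+// defaults and tighten detection before running VAD:
+//
+//     whisper_vad_params vparams = whisper_vad_default_params();
+//     vparams.threshold               = 0.6f; // demand a higher speech probability
+//     vparams.min_silence_duration_ms = 200;  // bridge shorter pauses within a segment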
- bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
- WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT);
+static bool weight_buft_supported(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
+ bool op_supported = true;
- uint32_t partial_value = partial_utf8.value;
- int n_remain = partial_utf8.n_remain;
+ if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU ||
+ (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) {
+ // GPU and default CPU backend support all operators
+ op_supported = true;
+ } else {
+ switch (op) {
+ // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
+ case GGML_OP_MUL_MAT: {
+ ggml_init_params params = {
+ /*.mem_size =*/ 2 * ggml_tensor_overhead(),
+ /*.mem_buffer =*/ nullptr,
+ /*.no_alloc =*/ true,
+ };
- // invalid sequence or 7-bit char split across 2 bytes (overlong)
- if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
- return false;
- }
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
+ if (!ctx_ptr) {
+ throw std::runtime_error("failed to create ggml context");
+ }
+ ggml_context * ctx = ctx_ptr.get();
- // range of possible code points this partial UTF-8 sequence could complete to
- uint32_t low = partial_value << (n_remain * 6);
- uint32_t high = low | ((1 << (n_remain * 6)) - 1);
+ ggml_tensor * op_tensor = nullptr;
- if (low == 0) {
- if (n_remain == 2) {
- low = 1 << 11;
- } else if (n_remain == 3) {
- low = 1 << 16;
- }
- }
+ int64_t n_ctx = hparams.lstm_hidden_size;
+ ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
+ op_tensor = ggml_mul_mat(ctx, w, b);
- do {
- if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
- // inclusive range, e.g. [a-z]
- if (pos->value <= high && low <= pos[1].value) {
- return is_positive_char;
+ // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
+ GGML_ASSERT(w->buffer == nullptr);
+ w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
+ op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
+ ggml_backend_buffer_free(w->buffer);
+ w->buffer = nullptr;
+ break;
}
- pos += 2;
- } else {
- // exact char match, e.g. [a] or "a"
- if (low <= pos->value && pos->value <= high) {
- return is_positive_char;
+ default: {
+ op_supported = false;
+ break;
}
- pos += 1;
+ }
+ }
+ return op_supported;
+}
+
+static ggml_backend_buffer_type_t select_weight_buft(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
+ GGML_ASSERT(!buft_list.empty());
+ for (const auto & p : buft_list) {
+ ggml_backend_dev_t dev = p.first;
+ ggml_backend_buffer_type_t buft = p.second;
+ if (weight_buft_supported(hparams, w, op, buft, dev)) {
+ return buft;
}
- } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+ }
- return !is_positive_char;
+ return nullptr;
}
+static ggml_tensor * whisper_vad_build_stft_layer(ggml_context * ctx0,
+ const whisper_vad_model & model, ggml_tensor * cur) {
+ // Apply reflective padding to the input tensor
+ ggml_tensor * padded = ggml_pad_reflect_1d(ctx0, cur, 64, 64);
-// transforms a grammar pushdown stack into N possible stacks, all ending
-// at a character range (terminal element)
-static void whisper_grammar_advance_stack(
- const std::vector<std::vector<whisper_grammar_element>> & rules,
- const std::vector<const whisper_grammar_element *> & stack,
- std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
+ struct ggml_tensor * stft = ggml_conv_1d(ctx0, model.stft_forward_basis, padded, model.hparams.lstm_input_size, 0, 1);
- if (stack.empty()) {
- new_stacks.emplace_back();
- return;
- }
+ // Calculate cutoff for real/imaginary parts
+ int cutoff = model.stft_forward_basis->ne[2] / 2;
- const whisper_grammar_element * pos = stack.back();
+ // Extract real part (first half of the STFT output).
+ struct ggml_tensor * real_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], 0);
+ // Extract imaginary part (second half of the STFT output).
+ struct ggml_tensor * img_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], cutoff * stft->nb[1]);
- switch (pos->type) {
- case WHISPER_GRETYPE_RULE_REF: {
- const size_t rule_id = static_cast<size_t>(pos->value);
- const whisper_grammar_element * subpos = rules[rule_id].data();
- do {
- // init new stack without the top (pos)
- std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
- if (!whisper_grammar_is_end_of_sequence(pos + 1)) {
- // if this rule ref is followed by another element, add that to stack
- new_stack.push_back(pos + 1);
- }
- if (!whisper_grammar_is_end_of_sequence(subpos)) {
- // if alternate is nonempty, add to stack
- new_stack.push_back(subpos);
- }
- whisper_grammar_advance_stack(rules, new_stack, new_stacks);
- while (!whisper_grammar_is_end_of_sequence(subpos)) {
- // scan to end of alternate def
- subpos++;
- }
- if (subpos->type == WHISPER_GRETYPE_ALT) {
- // there's another alternate def of this rule to process
- subpos++;
- } else {
- break;
- }
- } while (true);
- break;
- }
- case WHISPER_GRETYPE_CHAR:
- case WHISPER_GRETYPE_CHAR_NOT:
- new_stacks.push_back(stack);
- break;
- default:
- // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range
- // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
- // those
- WHISPER_ASSERT(false);
- }
+ // Calculate magnitude: sqrt(real^2 + imag^2)
+ struct ggml_tensor * real_squared = ggml_mul(ctx0, real_part, real_part);
+ struct ggml_tensor * img_squared = ggml_mul(ctx0, img_part, img_part);
+ struct ggml_tensor * sum_squares = ggml_add(ctx0, real_squared, img_squared);
+ struct ggml_tensor * magnitude = ggml_sqrt(ctx0, sum_squares);
+ return magnitude;
}
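+
+// Shape sketch for the STFT stage above, assuming the shipped Silero model
+// (n_window = 512 samples at 16 kHz, stride lstm_input_size = 128): reflective
+// padding adds 64 samples per side (512 -> 640); the conv with 258 basis
+// filters of length 256 then yields (640 - 256) / 128 + 1 = 4 frames, and
+// cutoff = 258 / 2 = 129 splits them into 129 real and 129 imaginary rows,
+// which is why the two views above are 4 x 129.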
-// takes a set of possible pushdown stacks on a grammar, which are required to
-// be positioned at a character range (see `whisper_grammar_advance_stack`), and
-// produces the N possible stacks if the given char is accepted at those
-// positions
-static std::vector<std::vector<const whisper_grammar_element *>> whisper_grammar_accept(
- const std::vector<std::vector<whisper_grammar_element>> & rules,
- const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
- const uint32_t chr) {
+static ggml_tensor * whisper_vad_build_encoder_layer(ggml_context * ctx0,
+ const whisper_vad_model & model, ggml_tensor * cur) {
+ // First Conv1D: expands to 128 channels.
+ cur = ggml_conv_1d(ctx0, model.encoder_0_weight, cur, 1, 1, 1);
+ cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_0_bias, 1, 128, 1));
+ cur = ggml_relu(ctx0, cur);
- std::vector<std::vector<const whisper_grammar_element *>> new_stacks;
+ // Second Conv1D: reduces to 64 channels.
+ cur = ggml_conv_1d(ctx0, model.encoder_1_weight, cur, 2, 1, 1);
+ cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_1_bias, 1, 64, 1));
+ cur = ggml_relu(ctx0, cur);
- for (const auto & stack : stacks) {
- if (stack.empty()) {
- continue;
- }
+ // Third Conv1D: maintains 64 channels
+ cur = ggml_conv_1d(ctx0, model.encoder_2_weight, cur, 2, 1, 1);
+ cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_2_bias, 1, 64, 1));
+ cur = ggml_relu(ctx0, cur);
- auto match = whisper_grammar_match_char(stack.back(), chr);
- if (match.first) {
+ // Fourth Conv1D: expands to 128 channels
+ cur = ggml_conv_1d(ctx0, model.encoder_3_weight, cur, 1, 1, 1);
+ cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_3_bias, 1, 128, 1));
+ cur = ggml_relu(ctx0, cur);
+
+ return cur;
+}
+
+static ggml_tensor * whisper_vad_build_lstm_layer(ggml_context * ctx0,
+ const whisper_vad_context & vctx, ggml_tensor * cur, ggml_cgraph * gf) {
+ const whisper_vad_model & model = vctx.model;
+ const int hdim = model.hparams.lstm_hidden_size;
+
+ struct ggml_tensor * x_t = ggml_transpose(ctx0, cur);
+
+ // Create operations using the input-to-hidden weights.
+ struct ggml_tensor * inp_gate = ggml_mul_mat(ctx0, model.lstm_ih_weight, x_t);
+ inp_gate = ggml_add(ctx0, inp_gate, model.lstm_ih_bias);
+
+ // Create operations using the hidden-to-hidden weights.
+ struct ggml_tensor * hid_gate = ggml_mul_mat(ctx0, model.lstm_hh_weight, vctx.h_state);
+ hid_gate = ggml_add(ctx0, hid_gate, model.lstm_hh_bias);
+
+ // Create add operation to get preactivations for all gates.
+ struct ggml_tensor * out_gate = ggml_add(ctx0, inp_gate, hid_gate);
+
+ const size_t hdim_size = ggml_row_size(out_gate->type, hdim);
+
+ // Input gate: sigmoid over the first hdim elements of the preactivations.
+ struct ggml_tensor * i_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 0 * hdim_size));
+
+ // Forget gate: sigmoid over the second hdim elements of the preactivations.
+ struct ggml_tensor * f_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 1 * hdim_size));
+
+ // Cell gate: tanh over the third hdim elements of the preactivations.
+ struct ggml_tensor * g_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 2 * hdim_size));
+
+ // Output gate: sigmoid over the fourth hdim elements of the preactivations.
+ struct ggml_tensor * o_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 3 * hdim_size));
+
+ // Update cell state
+ struct ggml_tensor * c_out = ggml_add(ctx0,
+ ggml_mul(ctx0, f_t, vctx.c_state),
+ ggml_mul(ctx0, i_t, g_t));
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_out, vctx.c_state));
+
+ // Update hidden state
+ struct ggml_tensor * out = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_out));
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, out, vctx.h_state));
+
+ return out;
+}
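+
+// The block above is one step of a standard LSTM cell with PyTorch's gate
+// ordering (input, forget, cell, output) packed into a single preactivation
+// vector of size 4 * hdim; W_*[k] below denotes the k-th hdim-row block:
+//
+//     i_t = sigmoid(W_ih[0] x_t + b_ih[0] + W_hh[0] h_{t-1} + b_hh[0])
+//     f_t = sigmoid(W_ih[1] x_t + b_ih[1] + W_hh[1] h_{t-1} + b_hh[1])
+//     g_t =    tanh(W_ih[2] x_t + b_ih[2] + W_hh[2] h_{t-1} + b_hh[2])
+//     o_t = sigmoid(W_ih[3] x_t + b_ih[3] + W_hh[3] h_{t-1} + b_hh[3])
+//     c_t = f_t * c_{t-1} + i_t * g_t
+//     h_t = o_t * tanh(c_t)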
+
+static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx) {
+ const auto & model = vctx.model;
+
+ struct ggml_init_params params = {
+ /*.mem_size =*/ vctx.sched.meta.size(),
+ /*.mem_buffer =*/ vctx.sched.meta.data(),
+ /*.no_alloc =*/ true,
+ };
+
+ struct ggml_context * ctx0 = ggml_init(params);
+
+ ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+ struct ggml_tensor * frame = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, vctx.n_window, 1);
+ ggml_set_name(frame, "frame");
+ ggml_set_input(frame);
+
+ struct ggml_tensor * cur = nullptr;
+ {
+ cur = whisper_vad_build_stft_layer(ctx0, model, frame);
+
+ cur = whisper_vad_build_encoder_layer(ctx0, model, cur);
+
+ // Extract the first element of the first dimension
+ // (equivalent to pytorch's [:, :, 0])
+ cur = ggml_view_2d(ctx0, cur, 1, 128, cur->nb[1], 0);
+
+ cur = whisper_vad_build_lstm_layer(ctx0, vctx, cur, gf);
+ cur = ggml_relu(ctx0, cur);
+ cur = ggml_conv_1d(ctx0, model.final_conv_weight, cur, 1, 0, 1);
+ cur = ggml_add(ctx0, cur, model.final_conv_bias);
+ cur = ggml_sigmoid(ctx0, cur);
+ ggml_set_name(cur, "prob");
+ ggml_set_output(cur);
+ }
+
+ ggml_build_forward_expand(gf, cur);
+
+ ggml_free(ctx0);
+
+ return gf;
+}
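+
+// Note: the graph above evaluates a single window and is re-run once per chunk
+// in whisper_vad_detect_speech(). h_state/c_state live in a separate backend
+// buffer (see whisper_vad_init_context), so the LSTM state carries over between
+// chunks and is only cleared at the start of a detection run.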
+
+static bool whisper_vad_init_context(whisper_vad_context * vctx) {
+
+ auto whisper_context_params = whisper_context_default_params();
+ // TODO: GPU VAD is forced disabled until the performance is improved
+ //whisper_context_params.use_gpu = vctx->params.use_gpu;
+ whisper_context_params.use_gpu = false;
+ whisper_context_params.gpu_device = vctx->params.gpu_device;
+
+ vctx->backends = whisper_backend_init(whisper_context_params);
+ if (vctx->backends.empty()) {
+ WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
+ return false;
+ }
+
+ const int32_t lstm_hidden_size = vctx->model.hparams.lstm_hidden_size;
+
+ vctx->ctx_buf.resize(2u*ggml_tensor_overhead());
+
+ struct ggml_init_params params = {
+ /*.mem_size =*/ vctx->ctx_buf.size(),
+ /*.mem_buffer =*/ vctx->ctx_buf.data(),
+ /*.no_alloc =*/ true,
+ };
+
+ ggml_context * ctx = ggml_init(params);
+ if (!ctx) {
+ WHISPER_LOG_ERROR("%s: failed to init LSTM state ggml context\n", __func__);
+ return false;
+ }
+
+ // LSTM Hidden state
+ vctx->h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
+ ggml_set_name(vctx->h_state, "h_state");
+
+ // LSTM Cell state
+ vctx->c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
+ ggml_set_name(vctx->c_state, "c_state");
+
+ vctx->buffer = ggml_backend_alloc_ctx_tensors(ctx, vctx->backends[0]);
+ if (!vctx->buffer) {
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for the VAD state\n", __func__);
+ return false;
+ }
+
+ {
+ bool ok = whisper_sched_graph_init(vctx->sched, vctx->backends,
+ [&]() {
+ return whisper_vad_build_graph(*vctx);
+ });
+
+ if (!ok) {
+ WHISPER_LOG_ERROR("%s: failed to init VAD allocator\n", __func__);
+ return false;
+ }
+
+ WHISPER_LOG_INFO("%s: compute buffer (VAD) = %7.2f MB\n", __func__, whisper_sched_size(vctx->sched) / 1e6);
+ }
+
+ return true;
+}
+
+struct whisper_vad_context * whisper_vad_init_from_file_with_params(
+ const char * path_model,
+ struct whisper_vad_context_params params) {
+ WHISPER_LOG_INFO("%s: loading VAD model from '%s'\n", __func__, path_model);
+#ifdef _MSC_VER
+ std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+ std::wstring path_model_wide = converter.from_bytes(path_model);
+ auto fin = std::ifstream(path_model_wide, std::ios::binary);
+#else
+ auto fin = std::ifstream(path_model, std::ios::binary);
+#endif
+ if (!fin) {
+ WHISPER_LOG_ERROR("%s: failed to open VAD model '%s'\n", __func__, path_model);
+ return nullptr;
+ }
+
+ whisper_model_loader loader = {};
+ loader.context = &fin;
+
+ loader.read = [](void * ctx, void * output, size_t read_size) {
+ std::ifstream * fin = (std::ifstream*)ctx;
+ fin->read((char *)output, read_size);
+ return read_size;
+ };
+
+ loader.eof = [](void * ctx) {
+ std::ifstream * fin = (std::ifstream*)ctx;
+ return fin->eof();
+ };
+
+ loader.close = [](void * ctx) {
+ std::ifstream * fin = (std::ifstream*)ctx;
+ fin->close();
+ };
+
+ auto ctx = whisper_vad_init_with_params(&loader, params);
+ if (!ctx) {
+ return nullptr;
+ }
+ ctx->path_model = path_model;
+ return ctx;
+}
+
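+// The VAD model file layout, as parsed below: GGML_FILE_MAGIC, a model type
+// string (int32 length + bytes), three int32 version fields (major.minor.patch),
+// n_window and n_context, the hyperparameters, and finally one record per
+// tensor: n_dims, name length, tensor type, the dims, the name, and raw data.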
+struct whisper_vad_context * whisper_vad_init_with_params(
+ struct whisper_model_loader * loader,
+ struct whisper_vad_context_params params) {
+ // Read the VAD model
+ {
+ uint32_t magic;
+ read_safe(loader, magic);
+ if (magic != GGML_FILE_MAGIC) {
+ WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
+ return nullptr;
+ }
+ }
+
+ whisper_vad_context * vctx = new whisper_vad_context;
+ vctx->n_threads = params.n_threads;
+ vctx->params.use_gpu = params.use_gpu;
+ vctx->params.gpu_device = params.gpu_device;
+
+ auto & model = vctx->model;
+ auto & hparams = model.hparams;
+
+ // load model context params.
+ {
+ int32_t str_len;
+ read_safe(loader, str_len);
+ std::vector<char> buffer(str_len + 1, 0);
+ loader->read(loader->context, buffer.data(), str_len);
+ std::string model_type(buffer.data(), str_len);
+ model.type = model_type;
+ WHISPER_LOG_INFO("%s: model type: %s\n", __func__, model.type.c_str());
+
+ int32_t major, minor, patch;
+ read_safe(loader, major);
+ read_safe(loader, minor);
+ read_safe(loader, patch);
+ std::string version_str = std::to_string(major) + "." +
+ std::to_string(minor) + "." +
+ std::to_string(patch);
+ model.version = version_str;
+ WHISPER_LOG_INFO("%s: model version: %s\n", __func__, model.version.c_str());
+
+ read_safe(loader, vctx->n_window);
+ read_safe(loader, vctx->n_context);
+ }
+
+ // load model hyper params (hparams).
+ {
+ read_safe(loader, hparams.n_encoder_layers);
+
+ hparams.encoder_in_channels = new int32_t[hparams.n_encoder_layers];
+ hparams.encoder_out_channels = new int32_t[hparams.n_encoder_layers];
+ hparams.kernel_sizes = new int32_t[hparams.n_encoder_layers];
+
+ for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
+ read_safe(loader, hparams.encoder_in_channels[i]);
+ read_safe(loader, hparams.encoder_out_channels[i]);
+ read_safe(loader, hparams.kernel_sizes[i]);
+ }
+
+ read_safe(loader, hparams.lstm_input_size);
+ read_safe(loader, hparams.lstm_hidden_size);
+ read_safe(loader, hparams.final_conv_in);
+ read_safe(loader, hparams.final_conv_out);
+
+ WHISPER_LOG_INFO("%s: n_encoder_layers = %d\n", __func__, hparams.n_encoder_layers);
+ for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
+ WHISPER_LOG_INFO("%s: encoder_in_channels[%d] = %d\n", __func__, i, hparams.encoder_in_channels[i]);
+ }
+ for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
+ WHISPER_LOG_INFO("%s: encoder_out_channels[%d] = %d\n", __func__, i, hparams.encoder_out_channels[i]);
+ }
+ WHISPER_LOG_INFO("%s: lstm_input_size = %d\n", __func__, hparams.lstm_input_size);
+ WHISPER_LOG_INFO("%s: lstm_hidden_size = %d\n", __func__, hparams.lstm_hidden_size);
+ WHISPER_LOG_INFO("%s: final_conv_in = %d\n", __func__, hparams.final_conv_in);
+ WHISPER_LOG_INFO("%s: final_conv_out = %d\n", __func__, hparams.final_conv_out);
+ }
+
+ // 1 STFT tensor, 4*2 encoder tensors, 4 LSTM tensors, 2 final output tensors
+ const size_t n_tensors = hparams.n_encoder_layers * 2 + 4 + 2 + 1;
+
+ std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+ auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
+ auto it = ctx_map.find(buft);
+ if (it == ctx_map.end()) {
+ ggml_init_params params = {
+ /*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
+ /*.mem_buffer =*/ nullptr,
+ /*.no_alloc =*/ true,
+ };
+
+ ggml_context * ctx = ggml_init(params);
+ if (!ctx) {
+ throw std::runtime_error("failed to create ggml context");
+ }
+
+ ctx_map[buft] = ctx;
+ model.ctxs.emplace_back(ctx);
+
+ return ctx;
+ }
+
+ return it->second;
+ };
+
+ whisper_context_params wparams = whisper_context_default_params();
+ wparams.use_gpu = params.use_gpu;
+ wparams.gpu_device = params.gpu_device;
+ buft_list_t buft_list = make_buft_list(wparams);
+
+ auto create_tensor = [&](vad_tensor type, ggml_tensor * meta) -> ggml_tensor * {
+ ggml_op op = VAD_TENSOR_OPS.at(type);
+ ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
+ if (!buft) {
+ throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", VAD_TENSOR_NAMES.at(type)));
+ }
+ ggml_context * ctx = get_ctx(buft);
+ ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
+ model.tensors[VAD_TENSOR_NAMES.at(type)] = tensor;
+
+ return tensor;
+ };
+
+ // create tensors
+ {
+ ggml_init_params params = {
+ /*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
+ /*.mem_buffer =*/ nullptr,
+ /*.no_alloc =*/ true,
+ };
+
+ ggml_context * ctx = ggml_init(params);
+ const auto & hparams = model.hparams;
+
+ // STFT precomputed basis matrix
+ model.stft_forward_basis = create_tensor(VAD_TENSOR_STFT_BASIS,
+ ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 256, 1, 258));
+
+ model.encoder_0_weight = create_tensor(VAD_TENSOR_ENC_0_WEIGHT,
+ ggml_new_tensor_3d(
+ ctx,
+ GGML_TYPE_F16,
+ hparams.kernel_sizes[0],
+ hparams.encoder_in_channels[0],
+ hparams.encoder_out_channels[0]
+ ));
+ model.encoder_0_bias = create_tensor(VAD_TENSOR_ENC_0_BIAS,
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[0]));
+
+ model.encoder_1_weight = create_tensor(VAD_TENSOR_ENC_1_WEIGHT,
+ ggml_new_tensor_3d(
+ ctx,
+ GGML_TYPE_F16,
+ hparams.kernel_sizes[1],
+ hparams.encoder_in_channels[1],
+ hparams.encoder_out_channels[1]
+ ));
+ model.encoder_1_bias = create_tensor(VAD_TENSOR_ENC_1_BIAS,
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[1]));
+
+ model.encoder_2_weight = create_tensor(VAD_TENSOR_ENC_2_WEIGHT,
+ ggml_new_tensor_3d(
+ ctx,
+ GGML_TYPE_F16,
+ hparams.kernel_sizes[2],
+ hparams.encoder_in_channels[2],
+ hparams.encoder_out_channels[2]
+ ));
+ model.encoder_2_bias = create_tensor(VAD_TENSOR_ENC_2_BIAS,
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[2]));
+
+ model.encoder_3_weight = create_tensor(VAD_TENSOR_ENC_3_WEIGHT,
+ ggml_new_tensor_3d(
+ ctx,
+ GGML_TYPE_F16,
+ hparams.kernel_sizes[3],
+ hparams.encoder_in_channels[3],
+ hparams.encoder_out_channels[3]
+ ));
+ model.encoder_3_bias = create_tensor(VAD_TENSOR_ENC_3_BIAS,
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[3]));
+
+ // Hidden State dimension (input gate, forget gate, cell gate, output gate)
+ const int hstate_dim = hparams.lstm_hidden_size * 4;
+
+ // LSTM weights - input to hidden
+ model.lstm_ih_weight = create_tensor(
+ VAD_TENSOR_LSTM_WEIGHT_IH,
+ ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
+ );
+ model.lstm_ih_bias = create_tensor(
+ VAD_TENSOR_LSTM_BIAS_IH,
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
+ );
+
+ // LSTM weights - hidden to hidden
+ model.lstm_hh_weight = create_tensor(
+ VAD_TENSOR_LSTM_WEIGHT_HH,
+ ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
+ );
+ model.lstm_hh_bias = create_tensor(
+ VAD_TENSOR_LSTM_BIAS_HH,
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
+ );
+
+ // Final conv layer weight
+ model.final_conv_weight = create_tensor(
+ VAD_TENSOR_FINAL_CONV_WEIGHT,
+ ggml_new_tensor_2d(ctx, GGML_TYPE_F16, hparams.final_conv_in, 1)
+ );
+ model.final_conv_bias = create_tensor(
+ VAD_TENSOR_FINAL_CONV_BIAS,
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1)
+ );
+
+ ggml_free(ctx);
+ }
+
+ // allocate tensors in the backend buffers
+ for (auto & p : ctx_map) {
+ ggml_backend_buffer_type_t buft = p.first;
+ ggml_context * ctx = p.second;
+ ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+ if (buf) {
+ model.buffers.emplace_back(buf);
+
+ size_t size_main = ggml_backend_buffer_get_size(buf);
+ WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
+ }
+ }
+
+ // load weights
+ {
+ size_t total_size = 0;
+ model.n_loaded = 0;
+ std::vector<char> read_buf;
+
+ while (true) {
+ int32_t n_dims;
+ int32_t length;
+ int32_t ttype;
+
+ read_safe(loader, n_dims);
+ read_safe(loader, length);
+ read_safe(loader, ttype);
+
+ if (loader->eof(loader->context)) {
+ break;
+ }
+
+ int32_t nelements = 1;
+ int32_t ne[4] = { 1, 1, 1, 1 };
+ for (int i = 0; i < n_dims; ++i) {
+ read_safe(loader, ne[i]);
+ nelements *= ne[i];
+ }
+
+ std::string name;
+ std::vector<char> tmp(length);
+ loader->read(loader->context, &tmp[0], tmp.size());
+ name.assign(&tmp[0], tmp.size());
+
+ if (model.tensors.find(name) == model.tensors.end()) {
+ WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
+ return nullptr;
+ }
+
+ auto tensor = model.tensors[name.data()];
+
+ if (ggml_nelements(tensor) != nelements) {
+ WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+ WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
+ __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
+ return nullptr;
+ }
+
+ if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
+ WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
+ __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
+ return nullptr;
+ }
+
+ const size_t bpe = ggml_type_size(ggml_type(ttype));
+
+ if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+ WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+ __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+ return nullptr;
+ }
+
+ if (ggml_backend_buffer_is_host(tensor->buffer)) {
+ // for the CPU and Metal backend, we can read directly into the tensor
+ loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
+ BYTESWAP_TENSOR(tensor);
+ } else {
+ // read into a temporary buffer first, then copy to device memory
+ read_buf.resize(ggml_nbytes(tensor));
+
+ loader->read(loader->context, read_buf.data(), read_buf.size());
+
+ ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
+ }
+
+ total_size += ggml_nbytes(tensor);
+ model.n_loaded++;
+ }
+
+ WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6);
+
+ if (model.n_loaded == 0) {
+ WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
+ } else if (model.n_loaded != (int) model.tensors.size()) {
+ WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
+ return nullptr;
+ }
+
+ }
+
+ if (!whisper_vad_init_context(vctx)) {
+ whisper_vad_free(vctx);
+ return nullptr;
+ }
+
+ return vctx;
+}
+
+bool whisper_vad_detect_speech(
+ struct whisper_vad_context * vctx,
+ const float * samples,
+ int n_samples) {
+ int n_chunks = n_samples / vctx->n_window;
+ if (n_samples % vctx->n_window != 0) {
+ n_chunks += 1; // Add one more chunk for remaining samples.
+ }
+
+ WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples);
+ WHISPER_LOG_INFO("%s: n_chunks: %d\n", __func__, n_chunks);
+
+ // Reset LSTM hidden/cell states
+ ggml_backend_buffer_clear(vctx->buffer, 0);
+
+ vctx->probs.resize(n_chunks);
+ WHISPER_LOG_INFO("%s: props size: %u\n", __func__, n_chunks);
+
+ std::vector<float> window(vctx->n_window, 0.0f);
+
+ auto & sched = vctx->sched.sched;
+
+ ggml_cgraph * gf = whisper_vad_build_graph(*vctx);
+
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+ WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
+ return false;
+ }
+
+ struct ggml_tensor * frame = ggml_graph_get_tensor(gf, "frame");
+ struct ggml_tensor * prob = ggml_graph_get_tensor(gf, "prob");
+
+ // we are going to reuse the graph multiple times for each chunk
+ const int64_t t_start_vad_us = ggml_time_us();
+
+ for (int i = 0; i < n_chunks; i++) {
+ const int idx_start = i * vctx->n_window;
+ const int idx_end = std::min(idx_start + vctx->n_window, n_samples);
+
+ const int chunk_len = idx_end - idx_start;
+
+ if (chunk_len < vctx->n_window) {
+ WHISPER_LOG_INFO("%s: chunk_len: %d < n_window: %d\n", __func__, chunk_len, vctx->n_window);
+
+ // Zero-pad the final, partial chunk out to the full window size.
+ std::copy(samples + idx_start, samples + idx_end, window.begin());
+ std::fill(window.begin() + chunk_len, window.end(), 0.0f);
+ } else {
+ // Copy current frame samples to the window.
+ const int samples_to_copy = std::min(idx_end - idx_start, vctx->n_window);
+ std::copy(samples + idx_start, samples + idx_start + samples_to_copy, window.begin());
+ }
+
+ // Set the frame tensor data with the samples.
+ ggml_backend_tensor_set(frame, window.data(), 0, ggml_nelements(frame) * sizeof(float));
+
+ // do not reset the scheduler - we will reuse the graph in the next chunk
+ if (!ggml_graph_compute_helper(sched, gf, vctx->n_threads, false)) {
+ WHISPER_LOG_ERROR("%s: failed to compute VAD graph\n", __func__);
+ break;
+ }
+
+ // Get the probability for this chunk.
+ ggml_backend_tensor_get(prob, &vctx->probs[i], 0, sizeof(float));
+
+ //WHISPER_LOG_DEBUG("chunk %d: p = %7.3f\n", i, probs[i]);
+ }
+
+ vctx->t_vad_us += ggml_time_us() - t_start_vad_us;
+ WHISPER_LOG_INFO("%s: vad time = %.2f ms processing %d samples\n", __func__, 1e-3f * vctx->t_vad_us, n_samples);
+
+ ggml_backend_sched_reset(sched);
+
+ return true;
+}
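+
+// Illustrative usage (hypothetical model path and buffer name): run detection
+// over 16 kHz mono float PCM and inspect the per-chunk speech probabilities:
+//
+//     whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(
+//         "models/silero-v5.1.2-ggml.bin", whisper_vad_default_context_params());
+//     if (vctx && whisper_vad_detect_speech(vctx, pcmf32.data(), pcmf32.size())) {
+//         const int     n = whisper_vad_n_probs(vctx);
+//         const float * p = whisper_vad_probs(vctx);
+//         // p[i] is the speech probability of chunk i (n_window samples each)
+//     }
+//     whisper_vad_free(vctx);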
+
+int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments) {
+ return segments->data.size();
+}
+
+float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments * segments, int i_segment) {
+ return segments->data[i_segment].start;
+}
+
+float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments * segments, int i_segment) {
+ return segments->data[i_segment].end;
+}
+
+int whisper_vad_n_probs(struct whisper_vad_context * vctx) {
+ return vctx->probs.size();
+}
+
+float * whisper_vad_probs(struct whisper_vad_context * vctx) {
+ return vctx->probs.data();
+}
+
+struct whisper_vad_segments * whisper_vad_segments_from_probs(
+ struct whisper_vad_context * vctx,
+ whisper_vad_params params) {
+ WHISPER_LOG_INFO("%s: detecting speech timestamps using %d probabilities\n", __func__, whisper_vad_n_probs(vctx));
+
+ int n_probs = whisper_vad_n_probs(vctx);
+ float * probs = whisper_vad_probs(vctx);
+ float threshold = params.threshold;
+ int min_speech_duration_ms = params.min_speech_duration_ms;
+ int min_silence_duration_ms = params.min_silence_duration_ms;
+ float max_speech_duration_s = params.max_speech_duration_s;
+ int speech_pad_ms = params.speech_pad_ms;
+ int n_window = vctx->n_window;
+ int sample_rate = WHISPER_SAMPLE_RATE;
+ int min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
+ int audio_length_samples = n_probs * n_window;
+
+ // Min number of samples to be considered valid speech.
+ int min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
+ int speech_pad_samples = sample_rate * speech_pad_ms / 1000;
+
+ // Max number of samples that a speech segment can contain before it is
+ // split into multiple segments.
+ int max_speech_samples;
+ if (max_speech_duration_s > 100000.0f) {
+ max_speech_samples = INT_MAX / 2;
+ } else {
+ int64_t temp = (int64_t)sample_rate * (int64_t)(max_speech_duration_s) - n_window - 2 * speech_pad_samples;
+ max_speech_samples = (temp > INT_MAX) ? INT_MAX / 2 : (int)temp;
+ if (max_speech_samples < 0) {
+ max_speech_samples = INT_MAX / 2;
+ }
+ }
+ // Detect silence period that exceeds this value, then that location (sample)
+ // is marked as a potential place where the segment could be split if
+ // max_speech_samples is reached. The value 98 was taken from the original
+ // silero-vad python implementation:
+ // https://github.com/snakers4/silero-vad/blob/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/utils_vad.py#L291
+ int min_silence_samples_at_max_speech = sample_rate * 98 / 1000;
+
+ // Calculate lower threshold for detecting end of speech segments.
+ float neg_threshold = threshold - 0.15f;
+ if (neg_threshold < 0.01f) {
+ neg_threshold = 0.01f;
+ }
+
+ struct speech_segment_t {
+ int start;
+ int end;
+ };
+
+ std::vector<speech_segment_t> speeches;
+ speeches.reserve(256);
+
+ bool is_speech_segment = false;
+ int temp_end = 0;
+ int prev_end = 0;
+ int next_start = 0;
+ int curr_speech_start = 0;
+ bool has_curr_speech = false;
+
+ for (int i = 0; i < n_probs; i++) {
+ float curr_prob = probs[i];
+ int curr_sample = n_window * i;
+
+ // Reset temp_end when we get back to speech
+ if ((curr_prob >= threshold) && temp_end) {
+ temp_end = 0;
+ if (next_start < prev_end) {
+ next_start = curr_sample;
+ }
+ }
+
+ // Start a new speech segment when probability exceeds threshold and not already in speech
+ if ((curr_prob >= threshold) && !is_speech_segment) {
+ is_speech_segment = true;
+ curr_speech_start = curr_sample;
+ has_curr_speech = true;
+ continue;
+ }
+
+ // Handle maximum speech duration
+ if (is_speech_segment && (curr_sample - curr_speech_start) > max_speech_samples) {
+ if (prev_end) {
+ speeches.push_back({ curr_speech_start, prev_end });
+ has_curr_speech = true;
+
+ if (next_start < prev_end) { // Previously reached silence and is still not speech
+ is_speech_segment = false;
+ has_curr_speech = false;
+ } else {
+ curr_speech_start = next_start;
+ }
+ prev_end = next_start = temp_end = 0;
+ } else {
+ speeches.push_back({ curr_speech_start, curr_sample });
+
+ prev_end = next_start = temp_end = 0;
+ is_speech_segment = false;
+ has_curr_speech = false;
+ continue;
+ }
+ }
+
+ // Handle silence after speech
+ if ((curr_prob < neg_threshold) && is_speech_segment) {
+ if (!temp_end) {
+ temp_end = curr_sample;
+ }
+
+ // Track potential segment ends for max_speech handling
+ if ((curr_sample - temp_end) > min_silence_samples_at_max_speech) {
+ prev_end = temp_end;
+ }
+
+ // Check if silence is long enough to end the segment
+ if ((curr_sample - temp_end) < min_silence_samples) {
+ continue;
+ } else {
+ // End the segment if it's long enough
+ if ((temp_end - curr_speech_start) > min_speech_samples) {
+ speeches.push_back({ curr_speech_start, temp_end });
+ }
+
+ prev_end = next_start = temp_end = 0;
+ is_speech_segment = false;
+ has_curr_speech = false;
+ continue;
+ }
+ }
+ }
+
+ // Handle the case if we're still in a speech segment at the end
+ if (has_curr_speech && (audio_length_samples - curr_speech_start) > min_speech_samples) {
+ speeches.push_back({ curr_speech_start, audio_length_samples });
+ }
+
+ // Merge adjacent segments with small gaps in between (post-processing)
+ if (speeches.size() > 1) {
+ int merged_count = 0;
+ for (int i = 0; i < (int) speeches.size() - 1; i++) {
+ // Define maximum gap allowed for merging (e.g., 200ms converted to samples)
+ int max_merge_gap_samples = sample_rate * 200 / 1000;
+
+ // If the gap between this segment and the next is small enough
+ if (speeches[i+1].start - speeches[i].end < max_merge_gap_samples) {
+ // Merge by extending current segment to the end of next segment
+ speeches[i].end = speeches[i+1].end;
+ speeches.erase(speeches.begin() + i + 1);
+
+ i--;
+ merged_count++;
+ }
+ }
+ WHISPER_LOG_INFO("%s: Merged %d adjacent segments, now have %d segments\n",
+ __func__, merged_count, (int) speeches.size());
+ }
+
+ // Double-check for minimum speech duration
+ for (int i = 0; i < (int) speeches.size(); i++) {
+ if (speeches[i].end - speeches[i].start < min_speech_samples) {
+ WHISPER_LOG_INFO("%s: Removing segment %d (too short: %d samples)\n",
+ __func__, i, speeches[i].end - speeches[i].start);
+
+ speeches.erase(speeches.begin() + i);
+ i--;
+ }
+ }
+
+ WHISPER_LOG_INFO("%s: Final speech segments after filtering: %d\n", __func__, (int) speeches.size());
+
+ // Allocate final segments
+ std::vector<whisper_vad_segment> segments;
+ if (speeches.size() > 0) {
+ try {
+ segments.resize(speeches.size());
+ } catch (const std::bad_alloc &) {
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for final segments\n", __func__);
+ return nullptr;
+ }
+ }
+
+ // Apply padding to segments and copy to final segments
+ for (int i = 0; i < (int) speeches.size(); i++) {
+ // Apply padding to the start of the first segment
+ if (i == 0) {
+ speeches[i].start =
+ (speeches[i].start > speech_pad_samples) ?
+ (speeches[i].start - speech_pad_samples) : 0;
+ }
+
+ // Handle spacing between segments
+ if (i < (int) speeches.size() - 1) {
+ int silence_duration = speeches[i+1].start - speeches[i].end;
+
+ if (silence_duration < 2 * speech_pad_samples) {
+ // If segments are close, split the difference
+ speeches[i].end += silence_duration / 2;
+ speeches[i+1].start =
+ (speeches[i+1].start > silence_duration / 2) ?
+ (speeches[i+1].start - silence_duration / 2) : 0;
+ } else {
+ // Otherwise, apply full padding to both
+ speeches[i].end =
+ (speeches[i].end + speech_pad_samples < audio_length_samples) ?
+ (speeches[i].end + speech_pad_samples) : audio_length_samples;
+ speeches[i+1].start =
+ (speeches[i+1].start > speech_pad_samples) ?
+ (speeches[i+1].start - speech_pad_samples) : 0;
+ }
+ } else {
+ // Apply padding to the end of the last segment
+ speeches[i].end =
+ (speeches[i].end + speech_pad_samples < audio_length_samples) ?
+ (speeches[i].end + speech_pad_samples) : audio_length_samples;
+ }
+
+ // Convert from samples to seconds and copy to final segments
+ segments[i].start = (float)speeches[i].start / sample_rate;
+ segments[i].end = (float)speeches[i].end / sample_rate;
+
+ WHISPER_LOG_INFO("%s: VAD segment %d: start = %.2f, end = %.2f (duration: %.2f)\n",
+ __func__, i, segments[i].start, segments[i].end, segments[i].end - segments[i].start);
+ }
+
+ whisper_vad_segments * vad_segments = new whisper_vad_segments;
+ if (vad_segments == NULL) {
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for whisper_vad_segments\n", __func__);
+ return nullptr;
+ }
+
+ vad_segments->data = std::move(segments);
+
+ return vad_segments;
+}
+
+struct whisper_vad_segments * whisper_vad_segments_from_samples(
+ whisper_vad_context * vctx,
+ whisper_vad_params params,
+ const float * samples,
+ int n_samples) {
+ WHISPER_LOG_INFO("%s: detecting speech timestamps in %d samples\n", __func__, n_samples);
+ if (!whisper_vad_detect_speech(vctx, samples, n_samples)) {
+ WHISPER_LOG_ERROR("%s: failed to detect speech\n", __func__);
+ return nullptr;
+ }
+ return whisper_vad_segments_from_probs(vctx, params);
+}
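+
+// Illustrative end-to-end usage (pcmf32 is assumed to hold 16 kHz mono floats):
+//
+//     whisper_vad_segments * segs = whisper_vad_segments_from_samples(
+//         vctx, whisper_vad_default_params(), pcmf32.data(), pcmf32.size());
+//     if (segs) {
+//         for (int i = 0; i < whisper_vad_segments_n_segments(segs); ++i) {
+//             printf("speech: %.2f s -> %.2f s\n",
+//                    whisper_vad_segments_get_segment_t0(segs, i),
+//                    whisper_vad_segments_get_segment_t1(segs, i));
+//         }
+//         whisper_vad_free_segments(segs);
+//     }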
+
+void whisper_vad_free(whisper_vad_context * ctx) {
+ if (ctx) {
+ for (ggml_context * context : ctx->model.ctxs) {
+ ggml_free(context);
+ }
+
+ for (ggml_backend_buffer_t buf : ctx->model.buffers) {
+ ggml_backend_buffer_free(buf);
+ }
+
+ ggml_backend_sched_free(ctx->sched.sched);
+
+ for (auto & backend : ctx->backends) {
+ ggml_backend_free(backend);
+ }
+
+
+ delete ctx;
+ }
+}
+
+void whisper_vad_free_segments(whisper_vad_segments * segments) {
+ if (segments) {
+ delete segments;
+ }
+}
+
+//////////////////////////////////
+// Grammar - ported from llama.cpp
+//////////////////////////////////
+
+// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
+// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
+static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
+ const char * src,
+ whisper_partial_utf8 partial_start) {
+ static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
+ const char * pos = src;
+ std::vector<uint32_t> code_points;
+ uint32_t value = partial_start.value;
+ int n_remain = partial_start.n_remain;
+
+ // continue previous decode, if applicable
+ while (*pos != 0 && n_remain > 0) {
+ uint8_t next_byte = static_cast<uint8_t>(*pos);
+ if ((next_byte >> 6) != 2) {
+ // invalid sequence, abort
+ code_points.push_back(0);
+ return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
+ }
+ value = (value << 6) + (next_byte & 0x3F);
+ ++pos;
+ --n_remain;
+ }
+
+ if (partial_start.n_remain > 0 && n_remain == 0) {
+ code_points.push_back(value);
+ }
+
+ // decode any subsequent utf-8 sequences, which may end in an incomplete one
+ while (*pos != 0) {
+ uint8_t first_byte = static_cast<uint8_t>(*pos);
+ uint8_t highbits = first_byte >> 4;
+ n_remain = lookup[highbits] - 1;
+
+ if (n_remain < 0) {
+ // invalid sequence, abort
+ code_points.clear();
+ code_points.push_back(0);
+ return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
+ }
+
+ uint8_t mask = (1 << (7 - n_remain)) - 1;
+ value = first_byte & mask;
+ ++pos;
+ while (*pos != 0 && n_remain > 0) {
+ value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+ ++pos;
+ --n_remain;
+ }
+ if (n_remain == 0) {
+ code_points.push_back(value);
+ }
+ }
+ code_points.push_back(0);
+
+ return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
+}
+
+// returns true iff pos points to the end of one of the definitions of a rule
+static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
+ switch (pos->type) {
+ case WHISPER_GRETYPE_END: return true; // NOLINT
+ case WHISPER_GRETYPE_ALT: return true; // NOLINT
+ default: return false;
+ }
+}
+
+// returns true iff chr satisfies the char range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
+ const whisper_grammar_element * pos,
+ const uint32_t chr) {
+
+ bool found = false;
+ bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
+
+ WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
+
+ do {
+ if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
+ // inclusive range, e.g. [a-z]
+ found = found || (pos->value <= chr && chr <= pos[1].value);
+ pos += 2;
+ } else {
+ // exact char match, e.g. [a] or "a"
+ found = found || pos->value == chr;
+ pos += 1;
+ }
+ } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+
+ return std::make_pair(found == is_positive_char, pos);
+}
+
+// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
+// range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static bool whisper_grammar_match_partial_char(
+ const whisper_grammar_element * pos,
+ const whisper_partial_utf8 partial_utf8) {
+
+ bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
+ WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT);
+
+ uint32_t partial_value = partial_utf8.value;
+ int n_remain = partial_utf8.n_remain;
+
+ // invalid sequence or 7-bit char split across 2 bytes (overlong)
+ if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
+ return false;
+ }
+
+ // range of possible code points this partial UTF-8 sequence could complete to
+ uint32_t low = partial_value << (n_remain * 6);
+ uint32_t high = low | ((1 << (n_remain * 6)) - 1);
+
+ if (low == 0) {
+ if (n_remain == 2) {
+ low = 1 << 11;
+ } else if (n_remain == 3) {
+ low = 1 << 16;
+ }
+ }
+
+ do {
+ if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
+ // inclusive range, e.g. [a-z]
+ if (pos->value <= high && low <= pos[1].value) {
+ return is_positive_char;
+ }
+ pos += 2;
+ } else {
+ // exact char match, e.g. [a] or "a"
+ if (low <= pos->value && pos->value <= high) {
+ return is_positive_char;
+ }
+ pos += 1;
+ }
+ } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+
+ return !is_positive_char;
+}
+
+
+// transforms a grammar pushdown stack into N possible stacks, all ending
+// at a character range (terminal element)
+static void whisper_grammar_advance_stack(
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
+ const std::vector<const whisper_grammar_element *> & stack,
+ std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
+
+ if (stack.empty()) {
+ new_stacks.emplace_back();
+ return;
+ }
+
+ const whisper_grammar_element * pos = stack.back();
+
+ switch (pos->type) {
+ case WHISPER_GRETYPE_RULE_REF: {
+ const size_t rule_id = static_cast<size_t>(pos->value);
+ const whisper_grammar_element * subpos = rules[rule_id].data();
+ do {
+ // init new stack without the top (pos)
+ std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
+ if (!whisper_grammar_is_end_of_sequence(pos + 1)) {
+ // if this rule ref is followed by another element, add that to stack
+ new_stack.push_back(pos + 1);
+ }
+ if (!whisper_grammar_is_end_of_sequence(subpos)) {
+ // if alternate is nonempty, add to stack
+ new_stack.push_back(subpos);
+ }
+ whisper_grammar_advance_stack(rules, new_stack, new_stacks);
+ while (!whisper_grammar_is_end_of_sequence(subpos)) {
+ // scan to end of alternate def
+ subpos++;
+ }
+ if (subpos->type == WHISPER_GRETYPE_ALT) {
+ // there's another alternate def of this rule to process
+ subpos++;
+ } else {
+ break;
+ }
+ } while (true);
+ break;
+ }
+ case WHISPER_GRETYPE_CHAR:
+ case WHISPER_GRETYPE_CHAR_NOT:
+ new_stacks.push_back(stack);
+ break;
+ default:
+ // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range
+ // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
+ // those
+ WHISPER_ASSERT(false);
+ }
+}
+
+// takes a set of possible pushdown stacks on a grammar, which are required to
+// be positioned at a character range (see `whisper_grammar_advance_stack`), and
+// produces the N possible stacks if the given char is accepted at those
+// positions
+static std::vector<std::vector<const whisper_grammar_element *>> whisper_grammar_accept(
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
+ const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+ const uint32_t chr) {
+
+ std::vector<std::vector<const whisper_grammar_element *>> new_stacks;
+
+ for (const auto & stack : stacks) {
+ if (stack.empty()) {
+ continue;
+ }
+
+ auto match = whisper_grammar_match_char(stack.back(), chr);
+ if (match.first) {
const whisper_grammar_element * pos = match.second;
// update top of stack to next element, if any
/*.n_grammar_rules =*/ 0,
/*.i_start_rule =*/ 0,
/*.grammar_penalty =*/ 100.0f,
+
+ /*.vad =*/ false,
+ /*.vad_model_path =*/ nullptr,
+
+ /*.vad_params =*/ whisper_vad_default_params(),
};
switch (strategy) {
}
}
+static bool whisper_vad(
+ struct whisper_context * ctx,
+ struct whisper_state * state,
+ struct whisper_full_params params,
+ const float * samples,
+ int n_samples,
+ std::vector<float> & filtered_samples,
+ int & filtered_n_samples) {
+ WHISPER_LOG_INFO("%s: VAD is enabled, processing speach segments only\n", __func__);
+ filtered_n_samples = 0;
+
+ struct whisper_vad_context_params vad_ctx_params = whisper_vad_default_context_params();
+ struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(params.vad_model_path, vad_ctx_params);
+ if (vctx == nullptr) {
+ WHISPER_LOG_ERROR("%s: failed to initialize VAD context\n", __func__);
+ return false;
+ }
+
+ const whisper_vad_params & vad_params = params.vad_params;
+
+ whisper_vad_segments * vad_segments = whisper_vad_segments_from_samples(vctx, vad_params, samples, n_samples);
+ if (vad_segments == nullptr) {
+ WHISPER_LOG_ERROR("%s: failed to detect speech segments\n", __func__);
+ whisper_vad_free(vctx);
+ return false;
+ }
+
+ if (vad_segments->data.size() > 0) {
+ state->has_vad_segments = true;
+ state->vad_segments.clear();
+ state->vad_segments.reserve(vad_segments->data.size());
+
+ WHISPER_LOG_INFO("%s: detected %d speech segments\n", __func__, (int)vad_segments->data.size());
+ float overlap_seconds = vad_params.samples_overlap;
+ int overlap_samples = overlap_seconds * WHISPER_SAMPLE_RATE;
+
+ for (int i = 0; i < (int)vad_segments->data.size(); i++) {
+ int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE;
+ int segment_end_samples = vad_segments->data[i].end * WHISPER_SAMPLE_RATE;
+
+ if (i < (int)vad_segments->data.size() - 1) {
+ segment_end_samples += overlap_samples;
+ }
+ segment_end_samples = std::min(segment_end_samples, n_samples - 1);
+ filtered_n_samples += (segment_end_samples - segment_start_samples);
+
+ WHISPER_LOG_INFO("%s: Including segment %d: %.2f - %.2f (duration: %.2f)\n",
+ __func__, i, vad_segments->data[i].start,
+ vad_segments->data[i].end + (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0),
+ (vad_segments->data[i].end - vad_segments->data[i].start) +
+ (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0));
+ }
+
+ int silence_samples = 0.1 * WHISPER_SAMPLE_RATE;
+ int total_silence_samples = (vad_segments->data.size() > 1) ? (vad_segments->data.size() - 1) * silence_samples : 0;
+ int total_samples_needed = filtered_n_samples + total_silence_samples;
+
+ WHISPER_LOG_INFO("%s: total duration of speech segments: %.2f seconds\n",
+ __func__, (float)filtered_n_samples / WHISPER_SAMPLE_RATE);
+
+ try {
+ filtered_samples.resize(total_samples_needed);
+ } catch (const std::bad_alloc & /* e */) {
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for filtered samples\n", __func__);
+ whisper_vad_free_segments(vad_segments);
+ whisper_vad_free(vctx);
+ return false;
+ }
+
+ int offset = 0;
+ for (int i = 0; i < (int)vad_segments->data.size(); i++) {
+ int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE;
+ int segment_end_samples = vad_segments->data[i].end * WHISPER_SAMPLE_RATE;
+
+ if (i < (int)vad_segments->data.size() - 1) {
+ segment_end_samples += overlap_samples;
+ }
+
+ segment_start_samples = std::min(segment_start_samples, n_samples - 1);
+ segment_end_samples = std::min(segment_end_samples, n_samples);
+ int segment_length = segment_end_samples - segment_start_samples;
+
+ if (segment_length > 0) {
+ whisper_state::vad_segment_info segment;
+
+ segment.orig_start = vad_segments->data[i].start;
+ segment.orig_end = vad_segments->data[i].end;
+
+ segment.vad_start = offset / (float)WHISPER_SAMPLE_RATE;
+ segment.vad_end = (offset + segment_length) / (float)WHISPER_SAMPLE_RATE;
+
+ WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n",
+ __func__, segment.orig_start, segment.orig_end, segment.vad_start, segment.vad_end);
+ state->vad_segments.push_back(segment);
+
+ // Copy this speech segment
+ memcpy(filtered_samples.data() + offset, samples + segment_start_samples, segment_length * sizeof(float));
+ offset += segment_length;
+
+ // Add silence after this segment (except after the last segment)
+ if (i < (int)vad_segments->data.size() - 1) {
+ // Fill with zeros (silence)
+ memset(filtered_samples.data() + offset, 0, silence_samples * sizeof(float));
+ offset += silence_samples;
+ }
+ }
+ }
+
+ filtered_n_samples = offset;
+ WHISPER_LOG_INFO("%s: Reduced audio from %d to %d samples (%.1f%% reduction)\n",
+ __func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples));
+ }
+
+ whisper_vad_free_segments(vad_segments);
+ whisper_vad_free(vctx);
+
+ return true;
+}
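+
+// Illustrative: enabling VAD through the public API (the model path is an
+// assumption; any converted Silero GGML model works):
+//
+//     whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+//     wparams.vad            = true;
+//     wparams.vad_model_path = "models/silero-v5.1.2-ggml.bin";
+//     wparams.vad_params     = whisper_vad_default_params();
+//     whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size());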
+
int whisper_full_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
result_all.clear();
- if (n_samples > 0) {
+ const float * process_samples = samples;
+ int n_process_samples = n_samples;
+ std::vector<float> vad_samples;
+
+ if (params.vad) {
+ WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
+ int vad_n_samples;
+ if (!whisper_vad(ctx, state, params, samples, n_samples, vad_samples, vad_n_samples)) {
+ WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
+ return -1;
+ }
+ process_samples = vad_samples.data();
+ n_process_samples = vad_n_samples;
+ }
+
+ if (n_process_samples > 0) {
// compute log mel spectrogram
- if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
+ if (whisper_pcm_to_mel_with_state(ctx, state, process_samples, n_process_samples, params.n_threads) != 0) {
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
return -2;
}
}
int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
- return state->result_all[i_segment].t0;
+ // If VAD wasn't used, return the original timestamp
+ if (!state->has_vad_segments || state->vad_segments.empty()) {
+ return state->result_all[i_segment].t0;
+ }
+
+ // Get the start timestamp produced by whisper_full. whisper_full processes
+ // only the speech segments in this case so we need to map these timestamps
+ // back to the original audio.
+ float t0 = state->result_all[i_segment].t0 / 100.0f;
+
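+ // Worked example (illustrative numbers): a VAD segment with
+ // orig [5.00 s, 7.00 s] that landed at vad [1.00 s, 2.00 s] in the filtered
+ // audio maps t0 = 1.50 s (proportion 0.5) back to 5.00 + 0.5 * 2.00 = 6.00 s,
+ // returned as 600 (centiseconds).
+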
+ // Find which VAD segment this timestamp belongs to.
+ // TODO(danbev) This could be optimized by using a binary search if the number
+ // of segments exceeds a certain limit. Also we might be able to assume that
+ // the access pattern is sequential and optimize for that too.
+ for (size_t i = 0; i < state->vad_segments.size(); i++) {
+ const auto & segment = state->vad_segments[i];
+
+ // Check if the timestamp falls within this segment.
+ if (t0 >= segment.vad_start && t0 <= segment.vad_end) {
+ float proportion = 0.0f;
+ if (segment.vad_end > segment.vad_start) {
+ proportion = (t0 - segment.vad_start) / (segment.vad_end - segment.vad_start);
+ }
+ float orig_t0 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
+ return (int64_t)(orig_t0 * 100);
+ }
+ }
+
+ // Check if the timestamp falls between two segments.
+ for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
+ const auto & curr = state->vad_segments[i];
+ const auto & next = state->vad_segments[i + 1];
+
+ if (t0 > curr.vad_end && t0 < next.vad_start) {
+ // Calculate how far we are through the gap as a proportion
+ float gap_proportion = 0.0f;
+ if (next.vad_start > curr.vad_end) {
+ gap_proportion = (t0 - curr.vad_end) / (next.vad_start - curr.vad_end);
+ }
+ float orig_t0 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
+ return (int64_t)(orig_t0 * 100);
+ }
+ }
+
+ // Handle the case where the timestamp is after the last segment.
+ if (t0 > state->vad_segments.back().vad_end) {
+ // For timestamps after the last segment, add the extra time to the end of the last segment
+ const auto& last = state->vad_segments.back();
+ // Calculate how far beyond the last segment
+ float extra_time = t0 - last.vad_end;
+ // Add this extra time to the original end time
+ float orig_t0 = last.orig_end + extra_time;
+ return (int64_t)(orig_t0 * 100);
+ }
+
+ WHISPER_LOG_WARN("%s: Could not map t0 = %f to a VAD segment\n", __func__, t0);
+ return (int64_t)(t0 * 100);
}
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
- return ctx->state->result_all[i_segment].t0;
+ return whisper_full_get_segment_t0_from_state(ctx->state, i_segment);
}
int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
- return state->result_all[i_segment].t1;
+ // If VAD wasn't used, return the original timestamp
+ if (!state->has_vad_segments || state->vad_segments.empty()) {
+ return state->result_all[i_segment].t1;
+ }
+
+ // Get the end timestamp produced by whisper_full. whisper_full processes
+ // only the speech segments in this case so we need to map these timestamps
+ // back to the original audio.
+ float t1 = state->result_all[i_segment].t1 / 100.0f;
+
+ // Find which VAD segment this timestamp belongs to.
+ // TODO(danbev) This could be optimized by using a binary search if the number
+ // of segments exceeds a certain limit. Also we might be able to assume that
+ // the access pattern is sequential and optimize for that too.
+ for (size_t i = 0; i < state->vad_segments.size(); i++) {
+ const auto& segment = state->vad_segments[i];
+
+ // Check if the timestamp falls within this segment.
+ if (t1 >= segment.vad_start && t1 <= segment.vad_end) {
+ // Calculate the proportion through the filtered segment.
+ float proportion = 0.0f;
+ if (segment.vad_end > segment.vad_start) {
+ proportion = (t1 - segment.vad_start) / (segment.vad_end - segment.vad_start);
+ }
+ float orig_t1 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
+ return (int64_t)(orig_t1 * 100);
+ }
+ }
+
+ // Check if the timestamp falls between two segments.
+ for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
+ const auto & curr = state->vad_segments[i];
+ const auto & next = state->vad_segments[i + 1];
+
+ if (t1 > curr.vad_end && t1 < next.vad_start) {
+ // Calculate how far we are through the gap as a proportion
+ float gap_proportion = 0.0f;
+ if (next.vad_start > curr.vad_end) {
+ gap_proportion = (t1 - curr.vad_end) / (next.vad_start - curr.vad_end);
+ }
+ // Map to the corresponding position in the original gap
+ float orig_t1 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
+ return (int64_t)(orig_t1 * 100);
+ }
+ }
+
+ // Handle the case where the timestamp is after the last segment
+ if (t1 > state->vad_segments.back().vad_end) {
+ // For the last segment, use the end of the last VAD segment
+ const auto& last = state->vad_segments.back();
+ // Calculate how far beyond the last segment
+ float extra_time = t1 - last.vad_end;
+ // Add this extra time to the original end time
+ float orig_t1 = last.orig_end + extra_time;
+ return (int64_t)(orig_t1 * 100);
+ }
+
+ WHISPER_LOG_WARN("%s: Could not map t1 = %f to a VAD segment\n", __func__, t1);
+ return (int64_t)(t1 * 100);
}
int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
- return ctx->state->result_all[i_segment].t1;
+ return whisper_full_get_segment_t1_from_state(ctx->state, i_segment);
}
bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {