#include <thread>
#include <vector>
#include <regex>
+#include <random>
+
+#define WHISPER_ASSERT(x) \
+ do { \
+ if (!(x)) { \
+ fprintf(stderr, "WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+ abort(); \
+ } \
+ } while (0)
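+// e.g. WHISPER_ASSERT(n > 0) prints "WHISPER_ASSERT: <file>:<line>: n > 0" and aborts if the condition fails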
+
+// define this to enable verbose trace logging - useful for debugging purposes
+//#define WHISPER_DEBUG
+
+#if defined(WHISPER_DEBUG)
+#define WHISPER_PRINT_DEBUG(...) \
+ do { \
+ fprintf(stderr, __VA_ARGS__); \
+ } while (0)
+#else
+#define WHISPER_PRINT_DEBUG(...)
+#endif
-#define USE_FLASH_ATTN
-//#define USE_FLASH_FF
+#define WHISPER_USE_FLASH_ATTN
+//#define WHISPER_USE_FLASH_FF
+#define WHISPER_MAX_DECODERS 16
// available whisper models
enum e_model {
{ MODEL_LARGE, 2952ull*MB },
};
-static const std::map<e_model, size_t> MEM_REQ_MEMORY = {
- { MODEL_TINY, 12ull*MB },
- { MODEL_BASE, 24ull*MB },
- { MODEL_SMALL, 70ull*MB },
- { MODEL_MEDIUM, 184ull*MB },
- { MODEL_LARGE, 306ull*MB },
+static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
+ { MODEL_TINY, 3ull*MB },
+ { MODEL_BASE, 6ull*MB },
+ { MODEL_SMALL, 16ull*MB },
+ { MODEL_MEDIUM, 43ull*MB },
+ { MODEL_LARGE, 71ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
+ { MODEL_TINY, 9ull*MB },
+ { MODEL_BASE, 18ull*MB },
+ { MODEL_SMALL, 53ull*MB },
+ { MODEL_MEDIUM, 141ull*MB },
+ { MODEL_LARGE, 235ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
- // used to avoid memory allocations during sampling
- // TODO: move to whisper_context in the future
- std::vector<std::pair<double, whisper_vocab::id>> probs_id;
-
id token_eot = 50256;
id token_sot = 50257;
id token_prev = 50360;
struct ggml_tensor * mlp_1_b;
};
+struct whisper_kv_cache {
+ struct ggml_tensor * k;
+ struct ggml_tensor * v;
+
+ struct ggml_context * ctx;
+
+ std::vector<uint8_t> buf;
+
+ int n; // number of tokens currently in the cache
+};
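+// note: k and v are flat 1-D tensors holding n_layer*n_ctx*n_state elements each;
+// per-layer, per-position slices are addressed with ggml_view_1d() during encode/decode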
+
struct whisper_model {
e_model type = MODEL_UNKNOWN;
struct ggml_tensor * e_ln_b;
// decoder.positional_embedding
- struct ggml_tensor * d_pe; // DD
+ struct ggml_tensor * d_pe;
// decoder.token_embedding
- struct ggml_tensor * d_te; // DD
+ struct ggml_tensor * d_te;
// decoder.ln
- struct ggml_tensor * d_ln_w; // DD
- struct ggml_tensor * d_ln_b; // DD
+ struct ggml_tensor * d_ln_w;
+ struct ggml_tensor * d_ln_b;
std::vector<whisper_layer_encoder> layers_encoder;
std::vector<whisper_layer_decoder> layers_decoder;
- // key + value memory
- struct ggml_tensor * memory_k;
- struct ggml_tensor * memory_v;
-
- struct ggml_tensor * memory_cross_k;
- struct ggml_tensor * memory_cross_v;
-
// context
struct ggml_context * ctx;
- struct ggml_context * ctx_mem;
+
+ // the model memory buffer is read-only and can be shared between processors
+ std::vector<uint8_t> * buf;
// tensors
int n_loaded;
std::map<std::string, struct ggml_tensor *> tensors;
};
+struct whisper_sequence {
+ std::vector<whisper_token_data> tokens;
+
+ // the accumulated transcription in the current iteration (used to truncate the tokens array)
+ int result_len;
+
+ double sum_logprobs_all; // the sum of the log probabilities of the tokens
+ double sum_logprobs; // the sum of the log probabilities of the tokens (first result_len tokens)
+ double avg_logprobs; // the average log probability of the tokens
+ double entropy; // the entropy of the tokens
+ double score; // likelihood rank score
+};
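+// note: with the default length_penalty < 0, score reduces to avg_logprobs (see whisper_sequence_score())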
+
+// TAGS: WHISPER_DECODER_INIT
+struct whisper_decoder {
+ // each decoder keeps its own KV-cache
+ whisper_kv_cache kv_self;
+
+ // the currently generated sequence of tokens
+ whisper_sequence sequence;
+
+ int seek_delta; // the window shift found so far based on the decoded timestamp tokens
+
+ bool failed; // has the current segment failed to decode?
+ bool completed; // has the decoder completed the current segment?
+ bool has_ts; // have we already sampled a non-beg timestamp token for the current segment?
+
+ // new token probs, logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab])
+ std::vector<float> probs;
+ std::vector<float> logits;
+ std::vector<float> logprobs;
+
+ std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
+};
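+// decoders[0] is fully initialized at model load time; the remaining decoders are initialized
+// lazily in whisper_full() by cloning decoder 0's KV cache via kv_cache_reinit()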
+
struct whisper_context {
int64_t t_load_us = 0;
int64_t t_mel_us = 0;
int64_t t_decode_us = 0;
int64_t t_start_us = 0;
- std::vector<uint8_t> * buf_model; // the model buffer is read-only and can be shared between processors
- std::vector<uint8_t> buf_memory;
- std::vector<uint8_t> buf_compute;
- std::vector<uint8_t> buf_compute_layer;
-
ggml_type wtype; // weight type (FP32 or FP16)
+ whisper_mel mel;
+
whisper_model model;
whisper_vocab vocab;
- whisper_mel mel;
+ // cross-attention KV cache for the decoders
+ // shared between all decoders
+ whisper_kv_cache kv_cross;
- std::vector<float> probs;
+ whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
+
+ // memory buffers used by encode / decode contexts
+ std::vector<uint8_t> buf_compute;
+ std::vector<uint8_t> buf_compute_layer;
+
+ // decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
std::vector<whisper_segment> result_all;
+ std::vector<whisper_token> prompt_past;
+
+ // work container used to avoid memory allocations
+ std::vector<std::pair<double, whisper_vocab::id>> logits_id;
- std::vector<whisper_token> prompt_past;
+ mutable std::mt19937 rng; // used for sampling at t > 0.0
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg;
loader->read(loader->context, &dest, sizeof(T));
}
+static bool kv_cache_init(
+ const struct whisper_hparams & hparams,
+ const size_t mem_bytes,
+ struct whisper_kv_cache & cache,
+ ggml_type wtype,
+ int n_ctx) {
+ cache.buf.resize(mem_bytes);
+
+ struct ggml_init_params params;
+ params.mem_size = cache.buf.size();
+ params.mem_buffer = cache.buf.data();
+
+ cache.ctx = ggml_init(params);
+
+ if (!cache.ctx) {
+ fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
+ return false;
+ }
+
+ const int n_text_state = hparams.n_text_state;
+ const int n_text_layer = hparams.n_text_layer;
+
+ const int n_mem = n_text_layer*n_ctx;
+ const int n_elements = n_text_state*n_mem;
+
+ cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+ cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+
+ return true;
+}
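+// sizing example (base model hparams: n_text_layer = 6, n_text_state = 512, n_text_ctx = 448):
+// n_elements = 512*6*448 = 1376256 per tensor, ~2.6 MB each in F16, so ~5.3 MB for k + v,
+// consistent with the 6 MB MEM_REQ_KV_SELF entry for MODEL_BASE above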
+
+static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
+ WHISPER_ASSERT(cache.ctx);
+
+ const int n_elements = ggml_nelements(cache.k);
+ WHISPER_ASSERT(n_elements == ggml_nelements(cache.v));
+
+ const ggml_type wtype = cache.k->type;
+ WHISPER_ASSERT(wtype == cache.v->type);
+
+ WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype));
+
+ struct ggml_init_params params;
+ params.mem_size = cache.buf.size();
+ params.mem_buffer = cache.buf.data();
+
+ cache.ctx = ggml_init(params);
+
+ if (!cache.ctx) {
+ fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
+ return false;
+ }
+
+ cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+ cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+
+ return true;
+}
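+// kv_cache_reinit() is meant to be called after copying an existing cache struct (e.g. when
+// lazily cloning decoder 0's cache in whisper_full()): the copied buf gets its own ggml
+// context and fresh k/v tensors with the same type and element count as the source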
+
+static void kv_cache_free(struct whisper_kv_cache & cache) {
+ if (cache.ctx) {
+ ggml_free(cache.ctx);
+ cache.ctx = nullptr;
+ }
+}
+
// load the model from a ggml file
//
// file format:
static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
fprintf(stderr, "%s: loading model\n", __func__);
+ const int64_t t_start_us = ggml_time_us();
+
+ wctx.t_start_us = t_start_us;
+
auto & model = wctx.model;
auto & vocab = wctx.vocab;
model.type = e_model::MODEL_LARGE;
}
+ // for the big tensors, we have the option to store the data in 16-bit floats
+ // in order to save memory and also to speed up the computation
+ wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+ const size_t scale = model.hparams.f16 ? 1 : 2;
+
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16);
fprintf(stderr, "%s: type = %d\n", __func__, model.type);
- wctx.buf_model = new std::vector<uint8_t>();
- wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
- wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
- wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
- wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
+ // print memory requirements
+ {
+ // this is the total memory required to run the inference
+ const size_t mem_required =
+ scale*MEM_REQ_MODEL.at (model.type) +
+ scale*MEM_REQ_KV_CROSS.at (model.type) +
+ scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)) +
+ scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type));
+
+ // this is the memory required by one decoder
+ const size_t mem_required_decoder =
+ scale*MEM_REQ_KV_SELF.at(model.type);
+
+ fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
+ mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
+ }
+
+ // initialize all memory buffers
+ // always have at least one decoder
+
+ wctx.model.buf = new std::vector<uint8_t>();
+ wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
+
+ if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
+ fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
+ return false;
+ }
+
+ {
+ const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v);
+ fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
+ }
+
+ if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) {
+ fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
+ return false;
+ }
+
+ {
+ const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
+ fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
+ }
+
+ wctx.buf_compute.resize (scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
+ wctx.buf_compute_layer.resize(scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
}
// load mel filters
}
wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
- wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
- vocab.probs_id.reserve(n_vocab);
- }
+ wctx.logits_id.reserve(n_vocab);
- {
- // this is the total memory required to run the inference
- const size_t mem_required =
- wctx.buf_model->size() +
- wctx.buf_memory.size() +
- wctx.buf_compute.size() +
- wctx.buf_compute_layer.size();
+ // TAGS: WHISPER_DECODER_INIT
+ wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
- fprintf(stderr, "%s: mem_required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
+ wctx.decoders[0].probs.reserve (vocab.n_vocab);
+ wctx.decoders[0].logits.reserve (vocab.n_vocab);
+ wctx.decoders[0].logprobs.reserve(vocab.n_vocab);
}
- // for the big tensors, we have the option to store the data in 16-bit floats
- // in order to save memory and also to speed up the computation
- wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+ size_t ctx_size = 0;
const ggml_type wtype = wctx.wtype;
- size_t ctx_size = 0;
-
{
const auto & hparams = model.hparams;
ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
- fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+ fprintf(stderr, "%s: model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
}
// create the ggml context
{
struct ggml_init_params params;
- params.mem_size = wctx.buf_model->size();
- params.mem_buffer = wctx.buf_model->data();
+ params.mem_size = wctx.model.buf->size();
+ params.mem_buffer = wctx.model.buf->data();
model.ctx = ggml_init(params);
if (!model.ctx) {
}
}
- // create the ggml memory context
- {
- struct ggml_init_params params;
- params.mem_size = wctx.buf_memory.size();
- params.mem_buffer = wctx.buf_memory.data();
-
- model.ctx_mem = ggml_init(params);
- if (!model.ctx_mem) {
- fprintf(stderr, "%s: ggml_init() failed\n", __func__);
- return false;
- }
- }
-
- // key + value memory
- {
- auto & ctx = model.ctx_mem;
-
- const auto & hparams = model.hparams;
-
- const int n_text_state = hparams.n_text_state;
- const int n_text_layer = hparams.n_text_layer;
- const int n_text_ctx = hparams.n_text_ctx;
-
- // key/value memory for the self-attention layer
- {
- const int n_mem = n_text_layer*n_text_ctx;
- const int n_elements = n_text_state*n_mem;
-
- model.memory_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
- model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
- }
-
- // key/value memory for the cross-attention layer
- {
- const int n_audio_ctx = hparams.n_audio_ctx;
-
- const int n_mem = n_text_layer*n_audio_ctx;
- const int n_elements = n_text_state*n_mem;
-
- model.memory_cross_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
- model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
- }
-
- const size_t memory_size =
- ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
- ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
-
- fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
- }
-
// load weights
{
size_t total_size = 0;
}
}
+ wctx.rng = std::mt19937(0);
+
+ wctx.t_load_us = ggml_time_us() - t_start_us;
+
return true;
}
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
//
static bool whisper_encode(
- whisper_context & wctx,
- const int n_threads,
- const int mel_offset) {
+ whisper_context & wctx,
+ const int mel_offset,
+ const int n_threads) {
+ const int64_t t_start_us = ggml_time_us();
+
const auto & model = wctx.model;
const auto & mel_inp = wctx.mel;
const auto & hparams = model.hparams;
// ------
-#ifdef USE_FLASH_ATTN
+#ifdef WHISPER_USE_FLASH_ATTN
struct ggml_tensor * Q =
ggml_permute(ctxL,
ggml_cpy(ctxL,
ggml_repeat(ctxL, layer.mlp_ln_b, cur));
}
-#ifdef USE_FLASH_FF
+#ifdef WHISPER_USE_FLASH_FF
cur = ggml_flash_ff(ctxL,
ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)),
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
Vcross),
Vcross);
- //struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
- //struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
- struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
- struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
+ //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
+ //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
+ struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx));
+ struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*n_ctx));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
ggml_free(ctx0);
+ wctx.t_encode_us += ggml_time_us() - t_start_us;
+
return true;
}
// - n_past: number of past tokens to prefix the prompt with
//
static bool whisper_decode(
- whisper_context & wctx,
- const int n_threads,
- const whisper_token * tokens,
- const int n_tokens,
- const int n_past) {
+ whisper_context & wctx,
+ whisper_decoder & decoder,
+ const whisper_token * tokens,
+ const int n_tokens,
+ const int n_past,
+ const int n_threads) {
+ const int64_t t_start_us = ggml_time_us();
+
const auto & model = wctx.model;
const auto & hparams = model.hparams;
+ auto & kv_self = decoder.kv_self;
+
+ WHISPER_ASSERT(!!kv_self.ctx);
+
auto & logits_out = wctx.logits;
- auto & probs_out = wctx.probs;
const int n_vocab = hparams.n_vocab;
const int N = n_tokens;
const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
+ //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
+
struct ggml_init_params params;
params.mem_size = wctx.buf_compute.size();
params.mem_buffer = wctx.buf_compute.data();
// store key and value to memory
{
- struct ggml_tensor * k = ggml_view_1d(ctxL, model.memory_k, N*n_state, (ggml_element_size(model.memory_k)*n_state)*(il*n_ctx + n_past));
- struct ggml_tensor * v = ggml_view_1d(ctxL, model.memory_v, N*n_state, (ggml_element_size(model.memory_v)*n_state)*(il*n_ctx + n_past));
+ struct ggml_tensor * k = ggml_view_1d(ctxL, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
+ struct ggml_tensor * v = ggml_view_1d(ctxL, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past));
ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v));
struct ggml_tensor * K =
ggml_permute(ctxL,
ggml_reshape_3d(ctxL,
- ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
+ ggml_view_1d(ctxL, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state),
n_state/n_head, n_head, n_past + N),
0, 2, 1, 3);
struct ggml_tensor * V_trans =
ggml_permute(ctxL,
ggml_reshape_3d(ctxL,
- ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
+ ggml_view_1d(ctxL, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
n_state/n_head, n_head, n_past + N),
1, 2, 0, 3);
// Kcross is already scaled
struct ggml_tensor * Kcross =
ggml_reshape_3d(ctxL,
- ggml_view_1d(ctxL, model.memory_cross_k, M*n_state, il*M*ggml_element_size(model.memory_cross_k)*n_state),
+ ggml_view_1d(ctxL, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state),
n_state/n_head, n_head, M);
struct ggml_tensor * Vcross =
ggml_reshape_3d(ctxL,
- ggml_view_1d(ctxL, model.memory_cross_v, M*n_state, il*M*ggml_element_size(model.memory_cross_v)*n_state),
+ ggml_view_1d(ctxL, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state),
n_state/n_head, n_head, M);
// ------
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
- // logits -> probs
- cur = ggml_dup(ctx0, logits);
- cur = ggml_soft_max(ctx0, cur); // in-place
-
// run the computation
{
struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
- ggml_build_forward_expand(&gf, cur);
+ ggml_build_forward_expand(&gf, logits);
ggml_graph_compute (ctx0, &gf);
}
logits_out.resize(N*n_vocab);
memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
- probs_out.resize(N*n_vocab);
- memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab);
-
if (N > 1) {
//const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
//printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
ggml_free(ctx0);
- return true;
-}
-
-// the most basic sampling scheme - select the top token
-static whisper_token_data whisper_sample_best(
- whisper_vocab & vocab,
- const float * probs,
- bool force_timestamp,
- bool is_initial) {
- whisper_token_data result = {
- 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
- };
-
- const int n_logits = vocab.n_vocab;
-
- auto & probs_id = vocab.probs_id;
-
- probs_id.clear();
- for (int i = 0; i < n_logits; i++) {
- probs_id.emplace_back(probs[i], i);
- }
-
- {
- double sum_ts = 0.0;
- double max_ts = -1.0;
- double max_tx = -1.0;
-
- for (int i = 0; i < vocab.token_beg; i++) {
- max_tx = std::max(max_tx, probs_id[i].first);
- }
-
- const auto i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg;
- const auto i1 = is_initial ? vocab.token_beg + 101 : n_logits;
-
- // the initial timestamp cannot be larger than 100
- // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
- if (is_initial) {
- for (int i = i0; i < n_logits; ++ i) {
- probs_id[i].first = -INFINITY;
- }
- }
-
- for (int i = vocab.token_beg; i < i1; i++) {
- sum_ts += probs_id[i].first;
- if (probs_id[i].first > max_ts) {
- max_ts = probs_id[i].first;
- result.tid = probs_id[i].second;
- }
- }
-
- // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
- // timestamp token
- if (sum_ts > max_tx || force_timestamp) {
- // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
- for (int i = 0; i < vocab.token_beg; i++) {
- probs_id[i].first = -INFINITY;
- }
- }
-
- result.pt = max_ts/(sum_ts + 1e-10);
- result.ptsum = sum_ts;
- }
-
- // find the top K tokens
- const int top_k = 4;
-
- std::partial_sort(
- probs_id.begin(),
- probs_id.begin() + top_k, probs_id.end(),
- [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
- return a.first > b.first;
- });
-
- probs_id.resize(top_k);
-
- //printf("\n");
- //for (int i = 0; i < (int) probs_id.size(); i++) {
- // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
- //}
-
- int res = 0;
- while ((probs_id[res].second == vocab.token_sot ||
- probs_id[res].second == vocab.token_solm ||
- probs_id[res].second == vocab.token_not) &&
- res < (int) probs_id.size() - 1) {
- res++;
- }
-
- result.id = probs_id[res].second;
- result.p = probs_id[res].first;
+ wctx.t_decode_us += ggml_time_us() - t_start_us;
- return result;
+ return true;
}
// 500 -> 00:05.000
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
static bool log_mel_spectrogram(
- const float * samples,
- const int n_samples,
- const int /*sample_rate*/,
- const int fft_size,
- const int fft_step,
- const int n_mel,
- const int n_threads,
- const whisper_filters & filters,
- const bool speed_up,
- whisper_mel & mel) {
+ whisper_context & wctx,
+ const float * samples,
+ const int n_samples,
+ const int /*sample_rate*/,
+ const int fft_size,
+ const int fft_step,
+ const int n_mel,
+ const int n_threads,
+ const whisper_filters & filters,
+ const bool speed_up,
+ whisper_mel & mel) {
+ const int64_t t_start_us = ggml_time_us();
// Hanning window
std::vector<float> hann;
mel.data[i] = (mel.data[i] + 4.0)/4.0;
}
+ wctx.t_mel_us += ggml_time_us() - t_start_us;
+
return true;
}
whisper_context * ctx = new whisper_context;
- const int64_t t_start_us = ggml_time_us();
-
- ctx->t_start_us = t_start_us;
-
if (!whisper_model_load(loader, *ctx)) {
loader->close(loader->context);
fprintf(stderr, "%s: failed to load model\n", __func__);
return nullptr;
}
- ctx->t_load_us = ggml_time_us() - t_start_us;
-
loader->close(loader->context);
return ctx;
if (ctx->model.ctx) {
ggml_free(ctx->model.ctx);
}
- if (ctx->model.ctx_mem) {
- ggml_free(ctx->model.ctx_mem);
+ if (ctx->model.buf) {
+ delete ctx->model.buf;
+ }
+ if (ctx->kv_cross.ctx) {
+ ggml_free(ctx->kv_cross.ctx);
}
- if (ctx->buf_model) {
- delete ctx->buf_model;
+ for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
+ if (ctx->decoders[i].kv_self.ctx) {
+ ggml_free(ctx->decoders[i].kv_self.ctx);
+ }
}
delete ctx;
}
}
int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
- const int64_t t_start_us = ggml_time_us();
-
- if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
+ if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
return -1;
}
- ctx->t_mel_us = ggml_time_us() - t_start_us;
-
return 0;
}
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
- const int64_t t_start_us = ggml_time_us();
-
- if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
+ if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
return -1;
}
- ctx->t_mel_us = ggml_time_us() - t_start_us;
-
return 0;
}
}
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
- const int64_t t_start_us = ggml_time_us();
-
- if (!whisper_encode(*ctx, n_threads, offset)) {
+ if (!whisper_encode(*ctx, offset, n_threads)) {
fprintf(stderr, "%s: failed to eval\n", __func__);
return -1;
}
- ctx->t_encode_us += ggml_time_us() - t_start_us;
-
return 0;
}
int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
- const int64_t t_start_us = ggml_time_us();
+ // TODO: add selected_decoder_id to context
+ const int selected_decoder_id = 0;
- if (!whisper_decode(*ctx, n_threads, tokens, n_tokens, n_past)) {
+ if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
fprintf(stderr, "%s: failed to eval\n", __func__);
return 1;
}
- ctx->t_decode_us += ggml_time_us() - t_start_us;
-
return 0;
}
-struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
- const int64_t t_start_sample_us = ggml_time_us();
-
- const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false);
-
- ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-
- return res;
-}
-
-struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) {
- const int64_t t_start_sample_us = ggml_time_us();
-
- const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial);
-
- ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-
- return res;
-}
-
int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
const auto res = tokenize(ctx->vocab, text);
return -7;
}
- std::vector<std::pair<float, int>> probs_id;
+ auto & logits_id = ctx->logits_id;
+ logits_id.clear();
+
for (const auto & kv : g_lang) {
const auto token_lang = whisper_token_lang(ctx, kv.second.first);
- probs_id.emplace_back(ctx->probs[token_lang], kv.second.first);
+ logits_id.emplace_back(ctx->logits[token_lang], kv.second.first);
}
// sort descending
{
- using pair_type = decltype(probs_id)::value_type;
- std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
+ using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
+ std::sort(logits_id.begin(), logits_id.end(), [](const pair_type & a, const pair_type & b) {
return a.first > b.first;
});
}
// softmax
{
- float sum = 0;
- for (const auto & kv : probs_id) {
- sum += exp(kv.first);
+ const auto max = logits_id[0].first;
+
+ double sum = 0.0f;
+ for (auto & kv : logits_id) {
+ kv.first = exp(kv.first - max);
+ sum += kv.first;
}
- for (auto & kv : probs_id) {
- kv.first = exp(kv.first) / sum;
+ for (auto & kv : logits_id) {
+ kv.first /= sum;
}
}
{
- for (const auto & prob : probs_id) {
+ for (const auto & prob : logits_id) {
if (lang_probs) {
lang_probs[prob.second] = prob.first;
}
}
}
- return probs_id[0].second;
+ return logits_id[0].second;
}
int whisper_n_len(struct whisper_context * ctx) {
return ctx->vocab.is_multilingual() ? 1 : 0;
}
-float * whisper_get_probs(struct whisper_context * ctx) {
- return ctx->probs.data();
+float * whisper_get_logits(struct whisper_context * ctx) {
+ return ctx->logits.data();
}
const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
////////////////////////////////////////////////////////////////////////////
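+// usage sketch (illustrative only; assumes a context created with one of the whisper_init*
+// functions and 16 kHz PCM float samples in `pcmf32`):
+//
+//   whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+//   wparams.n_threads = 4;
+//   if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
+//       fprintf(stderr, "failed to process audio\n");
+//   }
+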
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
- struct whisper_full_params result;
+ struct whisper_full_params result = {
+ /*.strategy =*/ WHISPER_SAMPLING_GREEDY,
+
+ /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
+ /*.n_max_text_ctx =*/ 16384,
+ /*.offset_ms =*/ 0,
+ /*.duration_ms =*/ 0,
+
+ /*.translate =*/ false,
+ /*.no_context =*/ false,
+ /*.single_segment =*/ false,
+ /*.print_special =*/ false,
+ /*.print_progress =*/ true,
+ /*.print_realtime =*/ false,
+ /*.print_timestamps =*/ true,
+
+ /*.token_timestamps =*/ false,
+ /*.thold_pt =*/ 0.01f,
+ /*.thold_ptsum =*/ 0.01f,
+ /*.max_len =*/ 0,
+ /*.max_tokens =*/ 0,
+
+ /*.speed_up =*/ false,
+ /*.audio_ctx =*/ 0,
+
+ /*.prompt_tokens =*/ nullptr,
+ /*.prompt_n_tokens =*/ 0,
+
+ /*.language =*/ "en",
+
+ /*.suppress_blank =*/ true,
+
+ /*.temperature =*/ 0.0f,
+ /*.max_initial_ts =*/ 1.0f,
+ /*.length_penalty =*/ -1.0f,
+
+ /*.temperature_inc =*/ 0.2f,
+ /*.entropy_thold =*/ 2.4f,
+ /*.logprob_thold =*/ -1.0f,
+ /*.no_speech_thold =*/ 0.6f,
+
+ /*.greedy =*/ {
+ /*.best_of =*/ -1,
+ },
+
+ /*.beam_search =*/ {
+ /*.beam_size =*/ -1,
+
+ /*.patience =*/ -1.0f,
+ },
+
+ /*.new_segment_callback =*/ nullptr,
+ /*.new_segment_callback_user_data =*/ nullptr,
+
+ /*.encoder_begin_callback =*/ nullptr,
+ /*.encoder_begin_callback_user_data =*/ nullptr,
+ };
switch (strategy) {
case WHISPER_SAMPLING_GREEDY:
{
- result = {
- /*.strategy =*/ WHISPER_SAMPLING_GREEDY,
-
- /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
- /*.n_max_text_ctx =*/ 16384,
- /*.offset_ms =*/ 0,
- /*.duration_ms =*/ 0,
-
- /*.translate =*/ false,
- /*.no_context =*/ false,
- /*.single_segment =*/ false,
- /*.print_special =*/ false,
- /*.print_progress =*/ true,
- /*.print_realtime =*/ false,
- /*.print_timestamps =*/ true,
-
- /*.token_timestamps =*/ false,
- /*.thold_pt =*/ 0.01f,
- /*.thold_ptsum =*/ 0.01f,
- /*.max_len =*/ 0,
- /*.max_tokens =*/ 0,
-
- /*.speed_up =*/ false,
- /*.audio_ctx =*/ 0,
-
- /*.prompt_tokens =*/ nullptr,
- /*.prompt_n_tokens =*/ 0,
-
- /*.language =*/ "en",
-
- /*.greedy =*/ {
- /*.n_past =*/ 0,
- },
-
- /*.beam_search =*/ {
- /*.n_past =*/ -1,
- /*.beam_width =*/ -1,
- /*.n_best =*/ -1,
- },
-
- /*.new_segment_callback =*/ nullptr,
- /*.new_segment_callback_user_data =*/ nullptr,
-
- /*.encoder_begin_callback =*/ nullptr,
- /*.encoder_begin_callback_user_data =*/ nullptr,
+ result.greedy = {
+ /*.best_of =*/ 1,
};
} break;
case WHISPER_SAMPLING_BEAM_SEARCH:
{
- result = {
- /*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH,
-
- /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
- /*.n_max_text_ctx =*/ 16384,
- /*.offset_ms =*/ 0,
- /*.duration_ms =*/ 0,
-
- /*.translate =*/ false,
- /*.no_context =*/ false,
- /*.single_segment =*/ false,
- /*.print_special =*/ false,
- /*.print_progress =*/ true,
- /*.print_realtime =*/ false,
- /*.print_timestamps =*/ true,
-
- /*.token_timestamps =*/ false,
- /*.thold_pt =*/ 0.01f,
- /*.thold_ptsum =*/ 0.01f,
- /*.max_len =*/ 0,
- /*.max_tokens =*/ 0,
-
- /*.speed_up =*/ false,
- /*.audio_ctx =*/ 0,
-
- /*.prompt_tokens =*/ nullptr,
- /*.prompt_n_tokens =*/ 0,
-
- /*.language =*/ "en",
-
- /*.greedy =*/ {
- /*.n_past =*/ -1,
- },
-
- /*.beam_search =*/ {
- /*.n_past =*/ 0,
- /*.beam_width =*/ 10,
- /*.n_best =*/ 5,
- },
-
- /*.new_segment_callback =*/ nullptr,
- /*.new_segment_callback_user_data =*/ nullptr,
-
- /*.encoder_begin_callback =*/ nullptr,
- /*.encoder_begin_callback_user_data =*/ nullptr,
+ result.beam_search = {
+ /*.beam_size =*/ 5,
+
+ /*.patience =*/ -1.0f,
};
} break;
}
// forward declarations
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
static void whisper_exp_compute_token_level_timestamps(
- struct whisper_context * ctx,
- int i_segment,
- float thold_pt,
- float thold_ptsum);
+ struct whisper_context & ctx,
+ int i_segment,
+ float thold_pt,
+ float thold_ptsum);
// wrap the last segment to max_len characters
// returns the number of new segments
-static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
- auto segment = ctx->result_all.back();
+static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
+ auto segment = ctx.result_all.back();
int res = 1;
int acc = 0;
for (int i = 0; i < (int) segment.tokens.size(); i++) {
const auto & token = segment.tokens[i];
- if (token.id >= whisper_token_eot(ctx)) {
+ if (token.id >= whisper_token_eot(&ctx)) {
continue;
}
- const auto txt = whisper_token_to_str(ctx, token.id);
+ const auto txt = whisper_token_to_str(&ctx, token.id);
const int cur = strlen(txt);
if (acc + cur > max_len && i > 0) {
// split here
- ctx->result_all.back().text = std::move(text);
- ctx->result_all.back().t1 = token.t0;
- ctx->result_all.back().tokens.resize(i);
+ ctx.result_all.back().text = std::move(text);
+ ctx.result_all.back().t1 = token.t0;
+ ctx.result_all.back().tokens.resize(i);
- ctx->result_all.push_back({});
- ctx->result_all.back().t0 = token.t0;
- ctx->result_all.back().t1 = segment.t1;
+ ctx.result_all.push_back({});
+ ctx.result_all.back().t0 = token.t0;
+ ctx.result_all.back().t1 = segment.t1;
// add tokens [i, end] to the new segment
- ctx->result_all.back().tokens.insert(
- ctx->result_all.back().tokens.end(),
+ ctx.result_all.back().tokens.insert(
+ ctx.result_all.back().tokens.end(),
segment.tokens.begin() + i,
segment.tokens.end());
acc = 0;
text = "";
- segment = ctx->result_all.back();
+ segment = ctx.result_all.back();
i = -1;
res++;
}
}
- ctx->result_all.back().text = std::move(text);
+ ctx.result_all.back().text = std::move(text);
return res;
}
-int whisper_full(
- struct whisper_context * ctx,
- struct whisper_full_params params,
- const float * samples,
- int n_samples) {
- // clear old results
- auto & result_all = ctx->result_all;
-
- result_all.clear();
+// process the logits for the selected decoder
+// - applies logit filters
+// - computes logprobs and probs
+static void whisper_process_logits(
+ const struct whisper_context & ctx,
+ const struct whisper_full_params params,
+ struct whisper_decoder & decoder,
+ float temperature) {
+ const auto & vocab = ctx.vocab;
+ const auto & tokens_cur = decoder.sequence.tokens;
+
+ const bool is_initial = tokens_cur.size() == 0;
+ const int n_logits = vocab.id_to_token.size();
+
+ WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab);
+
+ // extract the logits for the last token
+ // we will be mutating these, so we don't want to use the ctx.logits buffer directly
+ auto & probs = decoder.probs;
+ auto & logits = decoder.logits;
+ auto & logprobs = decoder.logprobs;
+ {
+ logits.resize(n_logits);
+ memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
- // compute log mel spectrogram
- if (params.speed_up) {
- if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
- fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
- return -1;
- }
- } else {
- if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
- fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
- return -2;
+ if (temperature > 0.0f) {
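+ // scale the logits by 1/t before the softmax: higher temperatures flatten the
+ // distribution, while t -> 0 approaches greedy argmax sampling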
+ for (int i = 0; i < n_logits; i++) {
+ logits[i] /= temperature;
+ }
}
- }
- // auto-detect language if not specified
- if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
- std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
+ // will be populated a bit later
+ probs.resize(n_logits);
+ logprobs.resize(n_logits);
+ }
- const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
- if (lang_id < 0) {
- fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
- return -3;
+ // apply logit filters here
+ // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493
+ {
+ // suppress blank
+ // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390
+ if (params.suppress_blank) {
+ if (is_initial) {
+ logits[vocab.token_eot] = -INFINITY;
+ logits[vocab.token_to_id.at(" ")] = -INFINITY;
+ }
}
- params.language = whisper_lang_str(lang_id);
+ // suppress <|notimestamps|> token
+ // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
+ logits[vocab.token_not] = -INFINITY;
- fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
+ // suppress sot and solm tokens
+ logits[vocab.token_sot] = -INFINITY;
+ logits[vocab.token_solm] = -INFINITY;
+
+ // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
+ // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
+ {
+ const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
+ const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
+
+ //fprintf(stderr, "last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
+
+ if (last_was_timestamp) {
+ if (penultimate_was_timestamp) {
+ for (int i = vocab.token_beg; i < n_logits; ++i) {
+ logits[i] = -INFINITY;
+ }
+ } else {
+ for (int i = 0; i < vocab.token_eot; ++i) {
+ logits[i] = -INFINITY;
+ }
+ }
+ }
+ }
+
+ // the initial timestamp cannot be larger than max_initial_ts
+ // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
+ if (is_initial && params.max_initial_ts > 0.0f) {
+ const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx;
+ const int tid0 = std::round(params.max_initial_ts/precision);
+
+ for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) {
+ logits[i] = -INFINITY;
+ }
+ }
+
+ // populate the logprobs array (log_softmax)
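+ // logprobs[i] = logits[i] - log(sum_j exp(logits[j])); the max logit is factored
+ // out first so that expf() cannot overflow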
+ {
+ const float logit_max = *std::max_element(logits.begin(), logits.end());
+ float logsumexp = 0.0f;
+ for (int i = 0; i < n_logits; ++i) {
+ if (logits[i] > -INFINITY) {
+ logsumexp += expf(logits[i] - logit_max);
+ }
+ }
+ logsumexp = logf(logsumexp) + logit_max;
+
+ for (int i = 0; i < n_logits; ++i) {
+ if (logits[i] > -INFINITY) {
+ logprobs[i] = logits[i] - logsumexp;
+ } else {
+ logprobs[i] = -INFINITY;
+ }
+ }
+ }
+
+ // if sum of probability over timestamps is above any other token, sample timestamp
+ // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
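+ // i.e. if the total probability mass of the timestamp tokens exceeds the probability
+ // of the single most likely text token, mask out all text tokens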
+ {
+ // logsumexp over timestamps
+ float timestamp_logprob = -INFINITY;
+ {
+ float logsumexp = 0.0f;
+ const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end());
+ for (int i = vocab.token_beg; i < n_logits; ++i) {
+ if (logprobs[i] > -INFINITY) {
+ logsumexp += expf(logprobs[i] - logprob_max);
+ }
+ }
+ if (logsumexp > 0.0f) {
+ timestamp_logprob = logf(logsumexp) + logprob_max;
+ }
+ }
+
+ const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
+
+ //fprintf(stderr, "timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
+
+ if (timestamp_logprob > max_text_token_logprob) {
+ for (int i = 0; i < vocab.token_beg; ++i) {
+ logits[i] = -INFINITY;
+ logprobs[i] = -INFINITY;
+ }
+ }
+ }
+ }
+
+ // compute probs
+ {
+ for (int i = 0; i < n_logits; ++i) {
+ if (logits[i] == -INFINITY) {
+ probs[i] = 0.0f;
+ } else {
+ probs[i] = expf(logprobs[i]);
+ }
+ }
+ }
+
+#if 0
+ // print first 100 logits - token string : logit
+ for (int i = 0; i < 100; i++) {
+ const auto token = vocab.id_to_token.at(i);
+ const auto prob = probs[i];
+ const auto logit = logits[i];
+ const auto logprob = logprobs[i];
+ printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
+ }
+
+ // "And", "and", " And", " and"
+ printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
+ printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
+ printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
+ printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
+ printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
+
+ printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
+ printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
+ printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
+ printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
+ printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
+
+ printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
+ printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
+ printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
+ printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
+ printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
+#endif
+}
+
+static whisper_token_data whisper_sample_token(
+ const whisper_context & ctx,
+ const whisper_decoder & decoder,
+ bool best) {
+ whisper_token_data result = {
+ 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
+ };
+
+ const auto & vocab = ctx.vocab;
+
+ const auto & probs = decoder.probs;
+ const auto & logprobs = decoder.logprobs;
+
+ const int n_logits = vocab.n_vocab;
+
+ {
+ double sum_ts = 0.0;
+ double max_ts = 0.0;
+
+ for (int i = vocab.token_beg; i < n_logits; i++) {
+ if (probs[i] == -INFINITY) {
+ continue;
+ }
+
+ sum_ts += probs[i];
+ if (max_ts < probs[i]) {
+ max_ts = probs[i];
+ result.tid = i;
+ }
+ }
+
+ result.pt = max_ts/(sum_ts + 1e-10);
+ result.ptsum = sum_ts;
+ }
+
+ if (best) {
+ for (int i = 0; i < n_logits; ++i) {
+ if (result.p < probs[i]) {
+ result.id = i;
+ result.p = probs[i];
+ result.plog = logprobs[i];
+ }
+ }
+ } else {
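+ // sample from the temperature-scaled distribution; std::discrete_distribution
+ // normalizes its weights, so probs does not need to sum to 1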
+ std::discrete_distribution<> dist(probs.begin(), probs.end());
+
+ result.id = dist(ctx.rng);
+ result.p = probs[result.id];
+ result.plog = logprobs[result.id];
+ }
+
+ if (result.id >= vocab.token_beg) {
+ result.tid = result.id;
+ result.pt = result.p;
+ }
+
+ return result;
+}
+
+static std::vector<whisper_token_data> whisper_sample_token_topk(
+ whisper_context & ctx,
+ const whisper_decoder & decoder,
+ int k) {
+ const auto & vocab = ctx.vocab;
+
+ const auto & probs = decoder.probs;
+ const auto & logits = decoder.logits;
+ const auto & logprobs = decoder.logprobs;
+
+ const int n_logits = vocab.n_vocab;
+
+ auto & logits_id = ctx.logits_id;
+
+ logits_id.clear();
+ for (int i = 0; i < n_logits; ++i) {
+ logits_id.push_back({ logits[i], i });
+ }
+
+ std::partial_sort(
+ logits_id.begin(),
+ logits_id.begin() + k, logits_id.end(),
+ [](const std::pair<double, whisper_token> & a, const std::pair<double, whisper_token> & b) {
+ return a.first > b.first;
+ });
+
+ std::vector<whisper_token_data> result;
+ result.reserve(k);
+
+ // initialized up-front in case no timestamp token has non-zero probability below
+ whisper_token tid = vocab.token_beg;
+
+ float pt = 0.0f;
+ float ptsum = 0.0f;
+
+ {
+ double sum_ts = 0.0;
+ double max_ts = 0.0;
+
+ for (int i = vocab.token_beg; i < n_logits; i++) {
+ if (probs[i] == -INFINITY) {
+ continue;
+ }
+
+ sum_ts += probs[i];
+ if (max_ts < probs[i]) {
+ max_ts = probs[i];
+ tid = i;
+ }
+ }
+
+ pt = max_ts/(sum_ts + 1e-10);
+ ptsum = sum_ts;
+ }
+
+ for (int i = 0; i < k; ++i) {
+ const auto id = logits_id[i].second;
+
+ result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
+
+ if (result[i].id >= vocab.token_beg) {
+ result[i].tid = result[i].id;
+ result[i].pt = result[i].p;
+ }
+ }
+
+ return result;
+}
+
+// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192
+static void whisper_sequence_score(
+ const struct whisper_full_params & params,
+ whisper_sequence & sequence) {
+ if (sequence.result_len == 0) {
+ return;
+ }
+
+ double result = 0.0f;
+
+ for (int i = 0; i < sequence.result_len; ++i) {
+ result += sequence.tokens[i].plog;
+ }
+
+ sequence.sum_logprobs = result;
+ sequence.avg_logprobs = result/sequence.result_len;
+
+ double penalty = sequence.result_len;
+
+ if (params.length_penalty > 0.0f) {
+ penalty = pow((5.0 + penalty)/6.0, params.length_penalty);
+ }
+
+ sequence.score = result/penalty;
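+ // e.g. length_penalty = 2.0f, result_len = 10: penalty = ((5 + 10)/6)^2 = 6.25;
+ // with the default length_penalty = -1.0f the score is simply the average logprob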
+
+ // compute the entropy of the last 32 tokens of the sequence
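+ // low entropy indicates a repetition loop: 32 copies of a single token give entropy 0,
+ // while 32 distinct tokens give log(32) ~= 3.47; the value is later compared against
+ // params.entropy_thold (default 2.4f) by the temperature fallback logic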
+ {
+ const int n = 32;
+
+ int cnt = 0;
+ double entropy = 0.0f;
+
+ std::map<whisper_token, int> token_counts;
+ for (int i = std::max(0, sequence.result_len - n); i < sequence.result_len; ++i) {
+ token_counts[sequence.tokens[i].id]++;
+ cnt++;
+ }
+
+ for (const auto & kv : token_counts) {
+ const auto p = kv.second/(double)cnt;
+ entropy -= p*log(p);
+
+ //WHISPER_PRINT_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
+ }
+
+ sequence.entropy = entropy;
+ }
+}
+
+int whisper_full(
+ struct whisper_context * ctx,
+ struct whisper_full_params params,
+ const float * samples,
+ int n_samples) {
+ // clear old results
+ auto & result_all = ctx->result_all;
+
+ result_all.clear();
+
+ // compute log mel spectrogram
+ if (params.speed_up) {
+ if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
+ fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
+ return -1;
+ }
+ } else {
+ if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
+ fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
+ return -2;
+ }
+ }
+
+ // auto-detect language if not specified
+ if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
+ std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
+
+ const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
+ if (lang_id < 0) {
+ fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
+ return -3;
+ }
+
+ params.language = whisper_lang_str(lang_id);
+
+ fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
}
if (params.token_timestamps) {
- ctx->t_beg = 0;
- ctx->t_last = 0;
+ ctx->t_beg = 0;
+ ctx->t_last = 0;
ctx->tid_last = 0;
ctx->energy = get_signal_energy(samples, n_samples, 32);
}
return 0;
}
+ // a set of temperatures to use
+ // [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ]
+ std::vector<float> temperatures;
+ if (params.temperature_inc > 0.0f) {
+ for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) {
+ temperatures.push_back(t);
+ }
+ } else {
+ temperatures.push_back(params.temperature);
+ }
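+ // e.g. temperature = 0.0f, temperature_inc = 0.2f -> { 0.0, 0.2, 0.4, 0.6, 0.8, 1.0 }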
+
+ // initialize the decoders
+ int n_decoders = 1;
+
+ switch (params.strategy) {
+ case WHISPER_SAMPLING_GREEDY:
+ {
+ n_decoders = params.greedy.best_of;
+ } break;
+ case WHISPER_SAMPLING_BEAM_SEARCH:
+ {
+ n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size);
+ } break;
+ }
+
+ n_decoders = std::max(1, n_decoders);
+
+ // TAGS: WHISPER_DECODER_INIT
+ for (int j = 1; j < n_decoders; j++) {
+ auto & decoder = ctx->decoders[j];
+
+ if (decoder.kv_self.ctx == nullptr) {
+ decoder.kv_self = ctx->decoders[0].kv_self;
+ if (!kv_cache_reinit(decoder.kv_self)) {
+ fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
+ return -4;
+ }
+
+ WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
+
+ decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
+
+ decoder.probs.resize (ctx->vocab.n_vocab);
+ decoder.logits.resize (ctx->vocab.n_vocab);
+ decoder.logprobs.resize(ctx->vocab.n_vocab);
+ }
+ }
+
// the accumulated text context so far
auto & prompt_past = ctx->prompt_past;
if (params.no_context) {
// overwrite audio_ctx, max allowed is hparams.n_audio_ctx
if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
- return -4;
+ return -5;
}
ctx->exp_n_audio_ctx = params.audio_ctx;
int progress_prev = 0;
int progress_step = 5;
- std::vector<whisper_token_data> tokens_cur;
- tokens_cur.reserve(whisper_n_text_ctx(ctx));
+ int seek = seek_start;
std::vector<whisper_token> prompt;
prompt.reserve(whisper_n_text_ctx(ctx));
+ // beam-search helpers
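+ // kv_bufs snapshots the self-attention KV cache of each live decoder before every
+ // sampling step: after the candidates are sorted, a decoder may continue a sequence
+ // grown by a different decoder, so its cache is restored from the source snapshot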
+ struct kv_buf {
+ std::vector<uint8_t> k;
+ std::vector<uint8_t> v;
+ };
+
+ std::vector<kv_buf> kv_bufs;
+
+ struct beam_candidate {
+ int decoder_idx;
+ int seek_delta;
+
+ bool has_ts;
+
+ whisper_sequence sequence;
+ };
+
+ std::vector<beam_candidate> beam_candidates;
+
// main loop
- int seek = seek_start;
while (true) {
const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
while (progress_cur >= progress_prev + progress_step) {
break;
}
- // if there is a very short audio segment left to process, we remove any past prompt since it tends
- // to confuse the decoder and often make it repeat or hallucinate stuff
- if (seek > seek_start && seek + 500 >= seek_end) {
- prompt_past.clear();
- }
-
if (params.encoder_begin_callback) {
if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
}
// encode audio features starting at offset seek
- if (whisper_encode(ctx, seek, params.n_threads) != 0) {
+ if (!whisper_encode(*ctx, seek, params.n_threads)) {
fprintf(stderr, "%s: failed to encode\n", __func__);
- return -4;
+ return -6;
+ }
+
+ // if there is a very short audio segment left to process, we remove any past prompt since it tends
+ // to confuse the decoder and often make it repeat or hallucinate stuff
+ if (seek > seek_start && seek + 500 >= seek_end) {
+ prompt_past.clear();
}
- int n_past = 0;
- prompt.clear();
+ int best_decoder_id = 0;
- // if we have already generated some text, use it as a prompt to condition the next generation
- if (!prompt_past.empty()) {
- int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
+ for (int it = 0; it < (int) temperatures.size(); ++it) {
+ const float t_cur = temperatures[it];
- prompt = { whisper_token_prev(ctx) };
- prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
+ int n_decoders_cur = 1;
- prompt_past.clear();
- prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
- }
+ switch (params.strategy) {
+ case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
+ {
+ if (t_cur > 0.0f) {
+ n_decoders_cur = params.greedy.best_of;
+ }
+ } break;
+ case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
+ {
+ if (t_cur > 0.0f) {
+ n_decoders_cur = params.greedy.best_of;
+ } else {
+ n_decoders_cur = params.beam_search.beam_size;
+ }
+ } break;
+ }
- prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
+ n_decoders_cur = std::max(1, n_decoders_cur);
- int seek_delta = 100*WHISPER_CHUNK_SIZE;
+ WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
- // print the prompt
- //printf("\n\n");
- //for (int i = 0; i < prompt.size(); i++) {
- // printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str());
- //}
- //printf("\n\n");
+ // TAGS: WHISPER_DECODER_INIT
+ for (int j = 0; j < n_decoders_cur; ++j) {
+ auto & decoder = ctx->decoders[j];
- // the accumulated transcription in the current interation
- int result_len = 0;
- tokens_cur.clear();
+ decoder.kv_self.n = 0;
- bool failed = false;
- bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
+ decoder.sequence.tokens.clear();
+ decoder.sequence.result_len = 0;
+ decoder.sequence.sum_logprobs_all = 0.0;
+ decoder.sequence.sum_logprobs = -INFINITY;
+ decoder.sequence.avg_logprobs = -INFINITY;
+ decoder.sequence.entropy = 0.0;
+ decoder.sequence.score = -INFINITY;
- for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
- if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
- fprintf(stderr, "%s: failed to decode\n", __func__);
- return -5;
- }
+ decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;
- n_past += prompt.size();
- prompt.clear();
+ decoder.failed = false;
+ decoder.completed = false;
+ decoder.has_ts = false;
+ }
- // very basic greedy sampling strategy:
- //
- // - always take the most probable token
- //
- // more sophisticated sampling strategies could be implemented here, but we keep it simple
- // feel free to experiment!
- //
+ // init prompt and kv cache for the current iteration
+ // run whisper_decode() only for decoder 0 and copy the results for the other decoders
{
- const auto token = (i == 0) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
+ prompt.clear();
- // timestamp token - update sliding window
- if (token.id > whisper_token_beg(ctx)) {
- const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
+ // if we have already generated some text, use it as a prompt to condition the next generation
+ if (!prompt_past.empty() && t_cur < 0.5f) {
+ int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
- // do not allow to go back in time
- if (has_ts && seek_delta > seek_delta_new && result_len < i) {
- break;
+ prompt = { whisper_token_prev(ctx) };
+ prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
+ }
+
+ // init new transcription with sot, language (opt) and task tokens
+ prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
+
+ // print the prompt
+ //WHISPER_PRINT_DEBUG("\n\n");
+ //for (int i = 0; i < (int) prompt.size(); i++) {
+ // WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
+ //}
+ //WHISPER_PRINT_DEBUG("\n\n");
+
+ if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
+ fprintf(stderr, "%s: failed to decode\n", __func__);
+ return -7;
+ }
+
+ {
+ const int64_t t_start_sample_us = ggml_time_us();
+
+ whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur);
+
+ ctx->decoders[0].kv_self.n += prompt.size();
+
+ for (int j = 1; j < n_decoders_cur; ++j) {
+ auto & decoder = ctx->decoders[j];
+
+ memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
+ memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
+
+ decoder.kv_self.n += prompt.size();
+
+ memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
+ memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
+ memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
}
- seek_delta = seek_delta_new;
- result_len = i + 1;
- has_ts = true;
+ ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}
+ }
- // add it to the context
- prompt.push_back(token.id);
- tokens_cur.push_back(token);
+ for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
+ const int64_t t_start_sample_us = ggml_time_us();
- //{
- // const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
- // printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
- //}
+ // store the KV caches of all decoders when doing beam-search
+ if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
+ kv_bufs.resize(n_decoders_cur);
+ for (int j = 0; j < n_decoders_cur; ++j) {
+ auto & decoder = ctx->decoders[j];
- // end of segment
- if (token.id == whisper_token_eot(ctx) || // end of text token
- (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
- (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
- ) {
- if (result_len == 0) {
- if (seek + seek_delta + 100 >= seek_end) {
- result_len = i + 1;
- } else {
- failed = true;
- break;
+ if (decoder.completed || decoder.failed) {
+ continue;
}
- }
- if (params.single_segment) {
- result_len = i + 1;
- seek_delta = 100*WHISPER_CHUNK_SIZE;
+ kv_bufs[j].k.resize(ggml_nbytes(decoder.kv_self.k));
+ kv_bufs[j].v.resize(ggml_nbytes(decoder.kv_self.v));
+
+ memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size());
+ memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size());
}
- break;
+ beam_candidates.clear();
}
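+
+ // the snapshots above are needed because after the candidates are sorted, a
+ // decoder may continue from a sequence that originated in a different
+ // decoder, and therefore needs that decoder's self-attention KV state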
- // TESTS: if no tensors are loaded, it means we are running tests
- if (ctx->model.n_loaded == 0) {
- seek_delta = 100*WHISPER_CHUNK_SIZE;
- break;
+ // generate new sequence candidates for each decoder
+ for (int j = 0; j < n_decoders_cur; ++j) {
+ auto & decoder = ctx->decoders[j];
+
+ if (decoder.completed || decoder.failed) {
+ continue;
+ }
+
+ switch (params.strategy) {
+ case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
+ {
+ // at (near-)zero temperature always pick the most probable token,
+ // otherwise sample from the distribution
+ decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, t_cur < 1e-6f));
+
+ decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
+ } break;
+ case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
+ {
+ const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
+
+ for (const auto & token : tokens_new) {
+ beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
+ beam_candidates.back().sequence.tokens.push_back(token);
+ beam_candidates.back().sequence.sum_logprobs_all += token.plog;
+
+ //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all);
+ }
+ } break;
+ }
}
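+
+ // with greedy sampling each active decoder appended exactly one token;
+ // with beam search each active decoder contributed beam_size candidates
+ // to beam_candidates, which are pruned back to the active decoders below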
- }
- // sometimes, the decoding can get stuck in a repetition loop
- // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance
- // the sliding window by 1 second
- if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
- failed = true;
- break;
- }
- }
+ // for beam-search, choose the top candidates and update the KV caches
+ if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
+ std::sort(
+ beam_candidates.begin(),
+ beam_candidates.end(),
+ [](const beam_candidate & a, const beam_candidate & b) {
+ return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
+ });
- if (failed) {
- // when we fail to sample timestamp token, retry by clearing the past prompt
- // if it fails again, then we advance the window by 1 second
- if (!prompt_past.empty()) {
- prompt_past.clear();
- } else {
- fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__);
- seek += 100;
- }
- continue;
- }
+ int cur_c = 0;
- // shrink down to result_len
- tokens_cur.resize(result_len);
+ for (int j = 0; j < n_decoders_cur; ++j) {
+ auto & decoder = ctx->decoders[j];
- for (const auto & r : tokens_cur) {
- prompt_past.push_back(r.id);
- }
+ if (decoder.completed || decoder.failed) {
+ continue;
+ }
- // store the text from this iteration
- if (!tokens_cur.empty()) {
- int i0 = 0;
- auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
+ auto & cur = beam_candidates[cur_c++];
- std::string text;
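+ // candidates are sorted by sum_logprobs_all, so candidates with equal
+ // scores (typically the same token sampled by several decoders) end up
+ // adjacent; skip them to keep the surviving sequences distinct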
+ while (cur_c < (int) beam_candidates.size() && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
+ ++cur_c;
+ }
- for (int i = 0; i < (int) tokens_cur.size(); i++) {
- //printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
- // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
- // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
+ decoder.sequence = cur.sequence;
+ decoder.seek_delta = cur.seek_delta;
+ decoder.has_ts = cur.has_ts;
- if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
- } else {
- text += whisper_token_to_str(ctx, tokens_cur[i].id);
+ memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size());
+ memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size());
+
+ WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
+ __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
+ }
}
- if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
- const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
- if (!text.empty()) {
- const auto tt0 = params.speed_up ? 2*t0 : t0;
- const auto tt1 = params.speed_up ? 2*t1 : t1;
-
- if (params.print_realtime) {
- if (params.print_timestamps) {
- printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
- } else {
- printf("%s", text.c_str());
- fflush(stdout);
+
+ // update the decoder state
+ // - check if the sequence is completed
+ // - check if the sequence is failed
+ // - update sliding window based on timestamp tokens
+ for (int j = 0; j < n_decoders_cur; ++j) {
+ auto & decoder = ctx->decoders[j];
+
+ if (decoder.completed || decoder.failed) {
+ continue;
+ }
+
+ auto & has_ts = decoder.has_ts;
+ auto & failed = decoder.failed;
+ auto & completed = decoder.completed;
+ auto & seek_delta = decoder.seek_delta;
+ auto & result_len = decoder.sequence.result_len;
+
+ {
+ const auto & token = decoder.sequence.tokens.back();
+
+ // timestamp token - update sliding window
+ if (token.id > whisper_token_beg(ctx)) {
+ const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
+
+ // do not allow going back in time
+ if (has_ts && seek_delta > seek_delta_new && result_len < i) {
+ failed = true; // TODO: maybe this is not a failure?
+ continue;
}
- }
- result_all.push_back({ tt0, tt1, text, {} });
- for (int j = i0; j <= i; j++) {
- result_all.back().tokens.push_back(tokens_cur[j]);
+ seek_delta = seek_delta_new;
+ result_len = i + 1;
+ has_ts = true;
}
- int n_new = 1;
+#ifdef WHISPER_DEBUG
+ {
+ const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
+ WHISPER_PRINT_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
+ __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str());
+ }
+#endif
- if (params.token_timestamps) {
- whisper_exp_compute_token_level_timestamps(
- ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+ // end of segment
+ if (token.id == whisper_token_eot(ctx) || // end of text token
+ (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
+ (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
+ ) {
+ if (result_len == 0) {
+ if (seek + seek_delta + 100 >= seek_end) {
+ result_len = i + 1;
+ } else {
+ failed = true;
+ continue;
+ }
+ }
- if (params.max_len > 0) {
- n_new = whisper_wrap_segment(ctx, params.max_len);
+ if (params.single_segment) {
+ result_len = i + 1;
+ seek_delta = 100*WHISPER_CHUNK_SIZE;
}
+
+ completed = true;
+ continue;
}
- if (params.new_segment_callback) {
- params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
+
+ // TESTS: if no tensors are loaded, it means we are running tests
+ if (ctx->model.n_loaded == 0) {
+ seek_delta = 100*WHISPER_CHUNK_SIZE;
+ completed = true;
+ continue;
}
}
- text = "";
- while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
- i++;
+
+ // sometimes, the decoding can get stuck in a repetition loop
+ // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
+ if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
+ failed = true;
+ continue;
}
- i--;
- t0 = t1;
- i0 = i + 1;
}
- }
- if (!text.empty()) {
- const auto t1 = seek + seek_delta;
+ // check if all decoders have finished (i.e. completed or failed)
+ {
+ bool completed_all = true;
- const auto tt0 = params.speed_up ? 2*t0 : t0;
- const auto tt1 = params.speed_up ? 2*t1 : t1;
+ for (int j = 0; j < n_decoders_cur; ++j) {
+ auto & decoder = ctx->decoders[j];
- if (params.print_realtime) {
- if (params.print_timestamps) {
- printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
- } else {
- printf("%s", text.c_str());
- fflush(stdout);
+ if (decoder.completed || decoder.failed) {
+ continue;
+ }
+
+ completed_all = false;
+ }
+
+ if (completed_all) {
+ break;
}
}
- result_all.push_back({ tt0, tt1, text, {} });
- for (int j = i0; j < (int) tokens_cur.size(); j++) {
- result_all.back().tokens.push_back(tokens_cur[j]);
+ ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+
+ // obtain logits for the next token
+ for (int j = 0; j < n_decoders_cur; ++j) {
+ auto & decoder = ctx->decoders[j];
+
+ if (decoder.failed || decoder.completed) {
+ continue;
+ }
+
+ decoder.tokens_tmp.resize(1);
+ decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
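+
+ // only the newly sampled token is fed to the decoder - all previous
+ // tokens are already represented in the self-attention KV cache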
+
+ //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
+
+ if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
+ fprintf(stderr, "%s: failed to decode\n", __func__);
+ return -8;
+ }
+
+ {
+ const int64_t t_start_sample_us = ggml_time_us();
+
+ whisper_process_logits(*ctx, params, decoder, t_cur);
+
+ ++decoder.kv_self.n;
+
+ ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+ }
}
+ }
+
+ // rank the resulting sequences and select the best one
+ {
+ double best_score = -INFINITY;
+
+ for (int j = 0; j < n_decoders_cur; ++j) {
+ auto & decoder = ctx->decoders[j];
+
+ if (decoder.failed) {
+ continue;
+ }
+
+ decoder.sequence.tokens.resize(decoder.sequence.result_len);
+ whisper_sequence_score(params, decoder.sequence);
- int n_new = 1;
+ WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
+ __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
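+
+ // the score is a length-normalized likelihood; a minimal sketch of the
+ // idea (assumption - the actual whisper_sequence_score() may use a
+ // different normalization):
+ //
+ //   score = sum_logprobs / double(result_len);
+ //
+ // so longer sequences are not penalized simply for containing more tokens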
- if (params.token_timestamps) {
- whisper_exp_compute_token_level_timestamps(
- ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+ if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_thold) {
+ WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
+ __func__, j, decoder.sequence.entropy, params.entropy_thold);
- if (params.max_len > 0) {
- n_new = whisper_wrap_segment(ctx, params.max_len);
+ decoder.failed = true;
+
+ continue;
+ }
+
+ if (best_score < decoder.sequence.score) {
+ best_score = decoder.sequence.score;
+ best_decoder_id = j;
}
}
- if (params.new_segment_callback) {
- params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
+
+ WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
+ }
+
+ // was the decoding successful for the current temperature?
+ {
+ bool success = true;
+
+ const auto & decoder = ctx->decoders[best_decoder_id];
+
+ if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
+ success = false;
+ }
+
+ if (success) {
+ //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
+ // WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
+ //}
+
+ break;
}
}
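+
+ // no break here: control falls through to the next, higher temperature in
+ // the fallback ladder (e.g. 0.0 -> 0.2 -> 0.4 -> ...; illustrative values,
+ // the actual schedule comes from the temperature parameters)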
+
+ WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
}
- seek += seek_delta;
+ // output results through a user-provided callback
+ {
+ const auto & best_decoder = ctx->decoders[best_decoder_id];
+
+ const auto seek_delta = best_decoder.seek_delta;
+ const auto result_len = best_decoder.sequence.result_len;
+
+ const auto & tokens_cur = best_decoder.sequence.tokens;
+
+ //WHISPER_PRINT_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
+
+ // update prompt_past
+ prompt_past.clear();
+ if (prompt.front() == whisper_token_prev(ctx)) {
+ prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
+ }
+
+ for (int i = 0; i < result_len; ++i) {
+ prompt_past.push_back(tokens_cur[i].id);
+ }
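+
+ // illustration: with result_len == 3 and tokens "Hello", ",", " world",
+ // prompt_past now ends with those 3 token ids and will condition the next
+ // window (subject to the temperature threshold above)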
+
+ // store the text from this iteration
+ if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
+ int i0 = 0;
+ auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
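+
+ // timestamp tokens encode time in 20 ms steps relative to the window
+ // start, while seek/t0/t1 are in 10 ms units - hence the factor of 2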
+
+ std::string text;
+
+ for (int i = 0; i < (int) tokens_cur.size(); i++) {
+ //printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
+ // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
+ // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
+
+ // skip special tokens unless print_special is enabled
+ if (params.print_special || tokens_cur[i].id < whisper_token_eot(ctx)) {
+ text += whisper_token_to_str(ctx, tokens_cur[i].id);
+ }
+
+ if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
+ const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
+ if (!text.empty()) {
+ const auto tt0 = params.speed_up ? 2*t0 : t0;
+ const auto tt1 = params.speed_up ? 2*t1 : t1;
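+
+ // when speed_up is enabled the audio is processed at 2x speed, so the
+ // computed timestamps are doubled to map back to the original time base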
+
+ if (params.print_realtime) {
+ if (params.print_timestamps) {
+ printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
+ } else {
+ printf("%s", text.c_str());
+ fflush(stdout);
+ }
+ }
+
+ //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
+
+ result_all.push_back({ tt0, tt1, text, {} });
+ for (int j = i0; j <= i; j++) {
+ result_all.back().tokens.push_back(tokens_cur[j]);
+ }
+
+ int n_new = 1;
+
+ if (params.token_timestamps) {
+ whisper_exp_compute_token_level_timestamps(
+ *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+
+ if (params.max_len > 0) {
+ n_new = whisper_wrap_segment(*ctx, params.max_len);
+ }
+ }
+ if (params.new_segment_callback) {
+ params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
+ }
+ }
+ text = "";
+ while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
+ i++;
+ }
+ i--;
+ t0 = t1;
+ i0 = i + 1;
+ }
+ }
+
+ if (!text.empty()) {
+ const auto t1 = seek + seek_delta;
+
+ const auto tt0 = params.speed_up ? 2*t0 : t0;
+ const auto tt1 = params.speed_up ? 2*t1 : t1;
+
+ if (params.print_realtime) {
+ if (params.print_timestamps) {
+ printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
+ } else {
+ printf("%s", text.c_str());
+ fflush(stdout);
+ }
+ }
+
+ result_all.push_back({ tt0, tt1, text, {} });
+ for (int j = i0; j < (int) tokens_cur.size(); j++) {
+ result_all.back().tokens.push_back(tokens_cur[j]);
+ }
+
+ int n_new = 1;
+
+ if (params.token_timestamps) {
+ whisper_exp_compute_token_level_timestamps(
+ *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+
+ if (params.max_len > 0) {
+ n_new = whisper_wrap_segment(*ctx, params.max_len);
+ }
+ }
+ if (params.new_segment_callback) {
+ params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
+ }
+ }
+ }
+
+ // update audio window
+ seek += seek_delta;
+
+ WHISPER_PRINT_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta);
+ }
}
return 0;
std::vector<struct whisper_context> ctxs(n_processors - 1);
for (int i = 0; i < n_processors - 1; ++i) {
- ctxs[i] = *ctx;
-
- auto & model = ctxs[i].model;
-
- // create the ggml memory context
- {
- struct ggml_init_params params;
- params.mem_size = ctxs[i].buf_memory.size();
- params.mem_buffer = ctxs[i].buf_memory.data();
+ auto & ctx_p = ctxs[i];
- model.ctx_mem = ggml_init(params);
- if (!model.ctx_mem) {
- fprintf(stderr, "%s: ggml_init() failed\n", __func__);
- return false;
- }
- }
+ ctx_p = *ctx;
- // separate key + value memory for each processor
- {
- auto & mctx = model.ctx_mem;
-
- const auto & hparams = model.hparams;
+ ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
- const int n_text_state = hparams.n_text_state;
- const int n_text_layer = hparams.n_text_layer;
- const int n_text_ctx = hparams.n_text_ctx;
+ ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab);
- // key/value memory for the self-attention layer
- {
- const int n_mem = n_text_layer*n_text_ctx;
- const int n_elements = n_text_state*n_mem;
+ if (!kv_cache_reinit(ctx_p.kv_cross)) {
+ fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i);
+ return false;
+ }
- model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
- model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
+ // TAGS: WHISPER_DECODER_INIT
+ for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
+ if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
+ fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i);
+ return false;
}
- // key/value memory for the cross-attention layer
- {
- const int n_audio_ctx = hparams.n_audio_ctx;
-
- const int n_mem = n_text_layer*n_audio_ctx;
- const int n_elements = n_text_state*n_mem;
+ ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
- model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
- model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
- }
+ ctx_p.decoders[j].probs.reserve (ctx_p.vocab.n_vocab);
+ ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab);
+ ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab);
}
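+
+ // each processor decodes concurrently, so the KV caches and scratch
+ // vectors above must be private copies - only the read-only model
+ // weights are shared between the contexts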
}
ctx->t_sample_us += ctxs[i].t_sample_us;
ctx->t_encode_us += ctxs[i].t_encode_us;
ctx->t_decode_us += ctxs[i].t_decode_us;
+
+ kv_cache_free(ctx->kv_cross);
+
+ for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
+ kv_cache_free(ctx->decoders[j].kv_self);
+ }
}
// average the timings
}
static void whisper_exp_compute_token_level_timestamps(
- struct whisper_context * ctx,
- int i_segment,
- float thold_pt,
- float thold_ptsum) {
- auto & segment = ctx->result_all[i_segment];
+ struct whisper_context & ctx,
+ int i_segment,
+ float thold_pt,
+ float thold_ptsum) {
+ auto & segment = ctx.result_all[i_segment];
auto & tokens = segment.tokens;
- const int n_samples = ctx->energy.size();
+ const int n_samples = ctx.energy.size();
if (n_samples == 0) {
fprintf(stderr, "%s: no signal data available\n", __func__);
return;
}
- auto & t_beg = ctx->t_beg;
- auto & t_last = ctx->t_last;
- auto & tid_last = ctx->tid_last;
+ auto & t_beg = ctx.t_beg;
+ auto & t_last = ctx.t_last;
+ auto & tid_last = ctx.tid_last;
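+
+ // t_beg/t_last/tid_last persist in the context across calls, keeping
+ // token-level timestamps monotonic from one segment to the next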
for (int j = 0; j < n; ++j) {
auto & token = tokens[j];
if (j == 0) {
- if (token.id == whisper_token_beg(ctx)) {
+ if (token.id == whisper_token_beg(&ctx)) {
tokens[j ].t0 = t0;
tokens[j ].t1 = t0;
tokens[j + 1].t0 = t0;
t_beg = t0;
t_last = t0;
- tid_last = whisper_token_beg(ctx);
+ tid_last = whisper_token_beg(&ctx);
} else {
tokens[j ].t0 = t_last;
}
}
- const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
+ const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx));
tokens[j].id = token.id;
tokens[j].tid = token.tid;
tokens[j].pt = token.pt;
tokens[j].ptsum = token.ptsum;
- tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id));
+ tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id));
if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
if (j > 0) {
p1--;
}
+ //printf("p0=%d p1=%d t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1);
+
if (p1 > p0) {
double psum = 0.0;
for (int j = p0; j <= p1; j++) {
const int hw = WHISPER_SAMPLE_RATE/8;
for (int j = 0; j < n; j++) {
- if (tokens[j].id >= whisper_token_eot(ctx)) {
+ if (tokens[j].id >= whisper_token_eot(&ctx)) {
continue;
}
float sum = 0.0f;
for (int k = ss0; k < ss1; k++) {
- sum += ctx->energy[k];
+ sum += ctx.energy[k];
}
const float thold = 0.5*sum/ns;
{
int k = s0;
- if (ctx->energy[k] > thold && j > 0) {
- while (k > 0 && ctx->energy[k] > thold) {
+ if (ctx.energy[k] > thold && j > 0) {
+ while (k > 0 && ctx.energy[k] > thold) {
k--;
}
tokens[j].t0 = sample_to_timestamp(k);
s0 = k;
}
} else {
- while (ctx->energy[k] < thold && k < s1) {
+ while (ctx.energy[k] < thold && k < s1) {
k++;
}
s0 = k;
{
int k = s1;
- if (ctx->energy[k] > thold) {
- while (k < n_samples - 1 && ctx->energy[k] > thold) {
+ if (ctx.energy[k] > thold) {
+ while (k < n_samples - 1 && ctx.energy[k] > thold) {
k++;
}
tokens[j].t1 = sample_to_timestamp(k);
s1 = k;
}
} else {
- while (ctx->energy[k] < thold && k > s0) {
+ while (ctx.energy[k] < thold && k > s0) {
k--;
}
s1 = k;
// debug info
//for (int j = 0; j < n; ++j) {
// const auto & token = tokens[j];
- // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
+ // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(&ctx, token.tid) : "[?]";
// printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
- // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id));
+ // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(&ctx, token.id));
- // if (tokens[j].id >= whisper_token_eot(ctx)) {
+ // if (tokens[j].id >= whisper_token_eot(&ctx)) {
// continue;
// }
//}