#include "ggml-alloc.h"
#include "ggml-backend.h"
+#include <atomic>
#include <algorithm>
#include <cassert>
#define _USE_MATH_DEFINES
//#define WHISPER_USE_FLASH_ATTN
//#define WHISPER_USE_FLASH_FF
-#define WHISPER_MAX_DECODERS 16
+#define WHISPER_MAX_DECODERS 8
#define WHISPER_MAX_NODES 4096
//
bool speaker_turn_next;
};
+struct whisper_batch {
+ int32_t n_tokens;
+
+ whisper_token * token;
+ whisper_pos * pos;
+ int32_t * n_seq_id;
+ whisper_seq_id ** seq_id; // null terminated
+ int8_t * logits;
+};
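+// note on the layout: token/pos/n_seq_id/logits are parallel arrays of length n_tokens, seq_id is
+// a null-terminated array of per-token sequence-id lists, and logits[i] != 0 marks the tokens whose
+// logits will be extracted after the decode call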
+
+static struct whisper_batch whisper_batch_init(int32_t n_tokens, int32_t n_seq_max) {
+ whisper_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, };
+
+ batch.token = (whisper_token * ) malloc(sizeof(whisper_token) * (n_tokens));
+ batch.pos = (whisper_pos *) malloc(sizeof(whisper_pos) * (n_tokens));
+ batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * (n_tokens));
+ batch.seq_id = (whisper_seq_id **) malloc(sizeof(whisper_seq_id *) * (n_tokens + 1));
+ for (int i = 0; i < n_tokens; ++i) {
+ batch.seq_id[i] = (whisper_seq_id *) malloc(sizeof(whisper_seq_id) * n_seq_max);
+ }
+ batch.seq_id[n_tokens] = nullptr;
+ batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+
+ return batch;
+}
+
+static void whisper_batch_free(struct whisper_batch batch) {
+ if (batch.token) free(batch.token);
+ if (batch.pos) free(batch.pos);
+ if (batch.n_seq_id) free(batch.n_seq_id);
+ if (batch.seq_id) {
+ for (int i = 0; batch.seq_id[i]; ++i) {
+ free(batch.seq_id[i]);
+ }
+ free(batch.seq_id);
+ }
+ if (batch.logits) free(batch.logits);
+}
+
+static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past, int seq_id) {
+ batch.n_tokens = n_tokens;
+ for (int i = 0; i < n_tokens; ++i) {
+ if (tokens) {
+ batch.token[i] = tokens[i];
+ }
+ batch.pos [i] = n_past + i;
+ batch.n_seq_id[i] = 1;
+ batch.seq_id [i][0] = seq_id;
+ batch.logits [i] = 0;
+ }
+ batch.logits[n_tokens - 1] = 1;
+}
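+// the legacy single-sequence API (whisper_decode / whisper_decode_with_state below) is mapped onto
+// the batch API roughly as:
+//   whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0); // all tokens -> seq 0
+//   whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr);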
+
+// replace std::pair with a customized pair struct (reason: std::pair is very slow)
+template<typename A, typename B>
+struct whisper_pair {
+ A first;
+ B second;
+
+ // Define a constructor that takes two arguments.
+ whisper_pair(const A& a, const B& b) : first(a), second(b) {}
+ // Define a constructor that takes no arguments.
+ whisper_pair() : first(A()), second(B()) {}
+};
+
+// ggml_allocr wrapper for whisper usage
+struct whisper_allocr {
+ ggml_allocr * alloc = nullptr;
+
+ std::vector<uint8_t> meta;
+
+ ggml_backend_buffer_t buffer;
+};
+
+static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
+ return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
+}
+
+// measure the memory usage of a graph and prepare the allocr's internal data buffer
+static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
+ auto & alloc = allocr.alloc;
+ auto & meta = allocr.meta;
+
+ alloc = ggml_allocr_new_measure_from_backend(backend);
+
+ meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
+
+ ggml_allocr_alloc_graph(alloc, get_graph());
+}
+
+static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
+ if (allocr.alloc == nullptr) {
+ // this can be null if we use an external encoder such as CoreML or OpenVINO
+ return;
+ }
+
+ auto & alloc = allocr.alloc;
+ auto & buffer = allocr.buffer;
+
+ size_t size = ggml_allocr_max_size(alloc);
+
+ ggml_allocr_free(alloc);
+
+ buffer = ggml_backend_alloc_buffer(backend, size);
+ alloc = ggml_allocr_new_from_buffer(buffer);
+}
+
+static void whisper_allocr_free(struct whisper_allocr & allocr) {
+ if (allocr.alloc) {
+ ggml_allocr_free(allocr.alloc);
+ ggml_backend_buffer_free(allocr.buffer);
+ allocr.alloc = nullptr;
+ }
+}
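+// typical lifecycle of a whisper_allocr (used below for the conv/encode/cross/decode graphs):
+//   1. whisper_allocr_graph_init()    - create a measure allocator and build the graph once to
+//                                       record the worst-case memory requirement
+//   2. whisper_allocr_graph_realloc() - free the measure allocator and allocate a real backend
+//                                       buffer of the measured size
+//   3. whisper_allocr_free()          - release the allocator and its buffer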
+
// medium
// hparams: {
// 'n_mels': 80,
struct ggml_tensor * mlp_1_b;
};
+struct whisper_kv_cell {
+ whisper_pos pos = -1;
+
+ std::set<whisper_seq_id> seq_id;
+
+ bool has_seq_id(const whisper_seq_id & id) const {
+ return seq_id.find(id) != seq_id.end();
+ }
+};
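+// one cell per slot of the k/v tensors: pos == -1 marks a free cell, otherwise the cell records
+// the token position stored in that slot and the set of sequence ids that currently reference it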
+
struct whisper_kv_cache {
+ uint32_t head = 0;
+ uint32_t size = 0;
+
+ // computed before each graph build
+ uint32_t n = 0;
+
+ std::vector<whisper_kv_cell> cells;
+
struct ggml_tensor * k;
struct ggml_tensor * v;
struct ggml_context * ctx;
ggml_backend_buffer_t buffer;
-
- int n; // number of tokens currently in the cache
};
struct whisper_model {
std::map<std::string, struct ggml_tensor *> tensors;
};
+struct whisper_partial_utf8 {
+ uint32_t value; // bit value so far (unshifted)
+ int n_remain; // num bytes remaining; -1 indicates invalid sequence
+};
+
+struct whisper_grammar {
+ /*const*/ std::vector<std::vector<whisper_grammar_element>> rules;
+ std::vector<std::vector<const whisper_grammar_element *>> stacks;
+
+ // buffer for partially generated UTF-8 sequence from accepted tokens
+ whisper_partial_utf8 partial_utf8;
+};
+
+struct whisper_grammar_candidate {
+ whisper_token id;
+ const uint32_t * code_points;
+ whisper_partial_utf8 partial_utf8;
+};
+
struct whisper_sequence {
std::vector<whisper_token_data> tokens;
// TAGS: WHISPER_DECODER_INIT
struct whisper_decoder {
- // each decoder keeps its own KV-cache
- whisper_kv_cache kv_self;
-
// the currently generated sequence of tokens
whisper_sequence sequence;
+ // grammar parse state of generated sequence of tokens
+ whisper_grammar grammar;
+
+ int i_batch; // the index of the token in the current batch
int seek_delta; // the window shift found so far based on the decoded timestamp tokens
bool failed; // has the current segment failed to decode?
std::vector<float> logits;
std::vector<float> logprobs;
- std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
-};
-
-// replace std::pair by using customized pair struct (reason: std::pair is very slow)
-template<typename A, typename B>
-struct whisper_pair {
- A first;
- B second;
-
- // Define a constructor that takes two arguments.
- whisper_pair(const A& a, const B& b) : first(a), second(b) {}
- // Define a constructor that takes no argument.
- whisper_pair() : first(A()), second(B()) {}
-};
-
-// beam-search helpers
-struct kv_buf {
- std::vector<uint8_t> k;
- std::vector<uint8_t> v;
-};
-
-// ggml_allocr wrapper for whisper usage
-struct whisper_allocr {
- ggml_allocr * alloc = nullptr;
-
- std::vector<uint8_t> meta;
+ // work container used to avoid memory allocations
+ std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
- ggml_backend_buffer_t buffer;
+ mutable std::mt19937 rng; // used for sampling at t > 0.0
};
-static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
- return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
-}
-
-// measure the memory usage of a graph and prepare the allocr's internal data buffer
-static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
- auto & alloc = allocr.alloc;
- auto & meta = allocr.meta;
-
- alloc = ggml_allocr_new_measure_from_backend(backend);
-
- meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
-
- ggml_allocr_alloc_graph(alloc, get_graph());
-}
-
-static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
- if (allocr.alloc == nullptr) {
- // this can be null if we use external encoder like CoreML or OpenVINO
- return;
- }
-
- auto & alloc = allocr.alloc;
- auto & buffer = allocr.buffer;
-
- size_t size = ggml_allocr_max_size(alloc);
-
- ggml_allocr_free(alloc);
-
- buffer = ggml_backend_alloc_buffer(backend, size);
- alloc = ggml_allocr_new_from_buffer(buffer);
-}
-
-static void whisper_allocr_free(struct whisper_allocr & allocr) {
- if (allocr.alloc) {
- ggml_allocr_free(allocr.alloc);
- ggml_backend_buffer_free(allocr.buffer);
- allocr.alloc = nullptr;
- }
-}
-
struct whisper_state {
int64_t t_sample_us = 0;
int64_t t_encode_us = 0;
int64_t t_decode_us = 0;
+ int64_t t_batchd_us = 0;
int64_t t_prompt_us = 0;
int64_t t_mel_us = 0;
int32_t n_sample = 0; // number of tokens sampled
int32_t n_encode = 0; // number of encoder calls
- int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
- int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
+ int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
+ int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16 (batch decoding)
+ int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
int32_t n_fail_p = 0; // number of logprob threshold failures
int32_t n_fail_h = 0; // number of entropy threshold failures
+ // unified self-attention KV cache for all decoders
+ whisper_kv_cache kv_self;
+
// cross-attention KV cache for the decoders
// shared between all decoders
whisper_kv_cache kv_cross;
+
whisper_mel mel;
- whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
+ whisper_batch batch;
- // buffer for swapping KV caches between decoders during beam-search
- std::vector<kv_buf> kv_swap_bufs;
+ whisper_decoder decoders[WHISPER_MAX_DECODERS];
ggml_backend_t backend = nullptr;
struct ggml_tensor * embd_conv = nullptr;
struct ggml_tensor * embd_enc = nullptr;
- // helper for GPU offloading
+ // helpers for GPU offloading
std::vector<float> inp_mel;
+ std::vector<float> inp_mask;
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
std::vector<whisper_segment> result_all;
std::vector<whisper_token> prompt_past;
- // work container used to avoid memory allocations
- std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
-
- mutable std::mt19937 rng; // used for sampling at t > 0.0
-
int lang_id = 0; // english by default
std::string path_model; // populated by whisper_init_from_file_with_params()
/*.no_alloc =*/ true,
};
+ cache.head = 0;
+ cache.size = n_ctx;
+
+ cache.cells.clear();
+ cache.cells.resize(n_ctx);
+
cache.ctx = ggml_init(params);
if (!cache.ctx) {
return true;
}
-// TODO: remove after batched decoding
-static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t backend) {
- WHISPER_ASSERT(cache.ctx);
+static void kv_cache_free(struct whisper_kv_cache & cache) {
+ if (cache.ctx) {
+ ggml_free(cache.ctx);
+ ggml_backend_buffer_free(cache.buffer);
+ cache.ctx = nullptr;
+ }
+}
- const int n_elements = ggml_nelements(cache.k);
- WHISPER_ASSERT(n_elements == ggml_nelements(cache.v));
+static bool whisper_kv_cache_find_slot(
+ struct whisper_kv_cache & cache,
+ const struct whisper_batch & batch) {
+ const uint32_t n_ctx = cache.size;
+ const uint32_t n_tokens = batch.n_tokens;
- const ggml_type wtype = cache.k->type;
- WHISPER_ASSERT(wtype == cache.v->type);
+ if (n_tokens > n_ctx) {
+ WHISPER_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
+ return false;
+ }
- struct ggml_init_params params = {
- /*.mem_size =*/ 2*ggml_tensor_overhead(),
- /*.mem_buffer =*/ nullptr,
- /*.no_alloc =*/ true,
- };
+ uint32_t n_tested = 0;
- cache.ctx = ggml_init(params);
+ while (true) {
+ if (cache.head + n_tokens > n_ctx) {
+ n_tested += n_ctx - cache.head;
+ cache.head = 0;
+ continue;
+ }
- if (!cache.ctx) {
- WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
- return false;
+ bool found = true;
+ for (uint32_t i = 0; i < n_tokens; i++) {
+ if (cache.cells[cache.head + i].pos >= 0) {
+ found = false;
+ cache.head += i + 1;
+ n_tested += i + 1;
+ break;
+ }
+ }
+
+ if (found) {
+ break;
+ }
+
+ if (n_tested >= n_ctx) {
+ //WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+ return false;
+ }
}
- cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
- cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+ for (uint32_t i = 0; i < n_tokens; i++) {
+ cache.cells[cache.head + i].pos = batch.pos[i];
- const size_t mem_bytes = ggml_nbytes(cache.k) + ggml_nbytes(cache.v);
+ for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
+ cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
+ }
+ }
- cache.buffer = ggml_backend_alloc_buffer(backend, mem_bytes);
+ return true;
+}
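+// the search above treats the cache as a ring buffer: starting at cache.head it looks for
+// n_tokens consecutive free cells (pos < 0), wrapping to 0 at the end, and gives up once all
+// n_ctx cells have been inspected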
- // allocate the tensors into the backend buffer
- {
- ggml_allocr * alloc = ggml_allocr_new_from_buffer(cache.buffer);
+// find how many cells are currently in use
+static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache) {
+ for (uint32_t i = cache.size - 1; i > 0; --i) {
+ if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
+ return i + 1;
+ }
+ }
- ggml_allocr_alloc(alloc, cache.k);
- ggml_allocr_alloc(alloc, cache.v);
+ return 1;
+}
- ggml_allocr_free(alloc);
+static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
+ for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
+ cache.cells[i].pos = -1;
+ cache.cells[i].seq_id.clear();
}
+ cache.head = 0;
+}
- return true;
+static void whisper_kv_cache_seq_rm(
+ struct whisper_kv_cache & cache,
+ whisper_seq_id seq_id,
+ whisper_pos p0,
+ whisper_pos p1) {
+ uint32_t new_head = cache.size;
+
+ if (p0 < 0) p0 = 0;
+ if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
+
+ for (uint32_t i = 0; i < cache.size; ++i) {
+ if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+ if (seq_id < 0) {
+ cache.cells[i].seq_id.clear();
+ } else if (cache.cells[i].has_seq_id(seq_id)) {
+ cache.cells[i].seq_id.erase(seq_id);
+ } else {
+ continue;
+ }
+ if (cache.cells[i].seq_id.empty()) {
+ cache.cells[i].pos = -1;
+ if (new_head == cache.size) new_head = i;
+ }
+ }
+ }
+
+ // If we freed up a slot, set head to it so searching can start there.
+ if (new_head != cache.size) cache.head = new_head;
}
-static void kv_cache_free(struct whisper_kv_cache & cache) {
- if (cache.ctx) {
- ggml_free(cache.ctx);
- ggml_backend_buffer_free(cache.buffer);
- cache.ctx = nullptr;
+static void whisper_kv_cache_seq_cp(
+ struct whisper_kv_cache & cache,
+ whisper_seq_id seq_id_src,
+ whisper_seq_id seq_id_dst,
+ whisper_pos p0,
+ whisper_pos p1) {
+ if (p0 < 0) p0 = 0;
+ if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
+
+ cache.head = 0;
+
+ for (uint32_t i = 0; i < cache.size; ++i) {
+ if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+ cache.cells[i].seq_id.insert(seq_id_dst);
+ }
}
}
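+// seq_rm / seq_cp follow the same KV-cache API style as llama.cpp: seq_rm drops a position range
+// for one sequence (or for all sequences when seq_id < 0), while seq_cp makes the destination
+// sequence share the source sequence's cells; presumably this lets multiple decoders reuse a
+// common prompt without copying KV data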
// initialize the backends
#ifdef GGML_USE_CUBLAS
- if (params.use_gpu) {
+ if (params.use_gpu && ggml_cublas_loaded()) {
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
backend_gpu = ggml_backend_cuda_init();
if (!backend_gpu) {
word = "[_EOT_]";
} else if (i == vocab.token_sot) {
word = "[_SOT_]";
+ } else if (i == vocab.token_translate) {
+ word = "[_TRANSLATE_]";
+ } else if (i == vocab.token_transcribe) {
+ word = "[_TRANSCRIBE_]";
} else if (i == vocab.token_solm) {
word = "[_SOLM_]";
} else if (i == vocab.token_prev) {
word = "[_NOT_]";
} else if (i == vocab.token_beg) {
word = "[_BEG_]";
+ } else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) {
+ word = "[_LANG_" + std::string(whisper_lang_str(i - vocab.token_sot - 1)) + "]";
} else {
word = "[_extra_token_" + std::to_string(i) + "]";
}
model.buffer = ggml_backend_alloc_buffer(wctx.backend, size_main);
- WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend), size_main / 1024.0 / 1024.0);
+ WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend), size_main / 1e6);
}
ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer);
ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
}
- //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1024.0/1024.0);
+ //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1e6);
total_size += ggml_nbytes(tensor);
model.n_loaded++;
}
- WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
+ WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6);
if (model.n_loaded == 0) {
WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
////////////////////////////////////////////////////////////////////////////
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
- // ggml_used_mem(ctx0)/1024.0/1024.0,
- // wstate.get_buf_max_mem(0)/1024.0/1024.0,
- // wstate.get_buf_max_mem(1)/1024.0/1024.0,
- // wstate.get_buf_max_mem(2)/1024.0/1024.0,
- // wstate.get_buf_max_mem(3)/1024.0/1024.0);
+ // ggml_used_mem(ctx0)/1e6,
+ // wstate.get_buf_max_mem(0)/1e6,
+ // wstate.get_buf_max_mem(1)/1e6,
+ // wstate.get_buf_max_mem(2)/1e6,
+ // wstate.get_buf_max_mem(3)/1e6);
ggml_free(ctx0);
static struct ggml_cgraph * whisper_build_graph_decoder(
whisper_context & wctx,
whisper_state & wstate,
- whisper_decoder & decoder,
- const whisper_token * tokens,
- int n_tokens,
- int n_past) {
+ const whisper_batch & batch) {
const auto & model = wctx.model;
const auto & hparams = model.hparams;
- auto & kv_self = decoder.kv_self;
+ auto & kv_self = wstate.kv_self;
WHISPER_ASSERT(!!kv_self.ctx);
- const int n_ctx = hparams.n_text_ctx;
+ ggml_allocr * alloc = wstate.alloc_decode.alloc;
+
+ const int n_ctx = kv_self.size;
const int n_state = hparams.n_text_state;
const int n_head = hparams.n_text_head;
const int n_layer = hparams.n_text_layer;
- const int N = n_tokens;
- const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
+ const int n_tokens = batch.n_tokens;
+ const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
+
+ const int32_t n_kv = ggml_allocr_is_measure(alloc) ? n_ctx : kv_self.n;
+ const int32_t kv_head = ggml_allocr_is_measure(alloc) ? n_ctx - n_tokens : kv_self.head;
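+ // during the measure pass the allocator only records sizes, so assume the worst case
+ // (n_kv = n_ctx, kv_head at the very end); during actual decoding use the current cache state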
- //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
+ //WHISPER_PRINT_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
struct ggml_init_params params = {
/*.mem_size =*/ wstate.alloc_decode.meta.size(),
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
- ggml_allocr * alloc = wstate.alloc_decode.alloc;
-
- struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+ struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_allocr_alloc(alloc, embd);
if (!ggml_allocr_is_measure(alloc)) {
- ggml_backend_tensor_set(embd, tokens, 0, N*ggml_element_size(embd));
+ ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*ggml_element_size(embd));
}
- struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+ struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_allocr_alloc(alloc, position);
if (!ggml_allocr_is_measure(alloc)) {
- for (int i = 0; i < N; ++i) {
- const int32_t val = n_past + i;
+ for (int i = 0; i < n_tokens; ++i) {
+ const int32_t val = batch.pos[i];
ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
}
}
ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
}
- // token encoding + position encoding
- struct ggml_tensor * cur =
- ggml_add(ctx0,
- ggml_get_rows(ctx0, model.d_te, embd),
- ggml_get_rows(ctx0, model.d_pe, position));
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+ ggml_allocr_alloc(alloc, KQ_mask);
- struct ggml_tensor * inpL = cur;
+ if (!ggml_allocr_is_measure(alloc)) {
+ wstate.inp_mask.resize(n_kv*n_tokens);
- for (int il = 0; il < n_layer; ++il) {
- const auto & layer = model.layers_decoder[il];
+ float * data = wstate.inp_mask.data();
+ memset(data, 0, ggml_nbytes(KQ_mask));
- // norm
+ for (int h = 0; h < 1; ++h) {
+ for (int j = 0; j < n_tokens; ++j) {
+ const whisper_pos pos = batch.pos[j];
+ const whisper_seq_id seq_id = batch.seq_id[j][0];
+
+ for (int i = 0; i < n_kv; ++i) {
+ if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+ }
+ }
+ }
+ }
+
+ ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
+ }
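+
+ // KQ_mask is 0 where a token may attend to a KV cell and -INFINITY where it may not:
+ // a cell is masked if it belongs to a different sequence or to a position after the token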
+
+ // token encoding + position encoding
+ struct ggml_tensor * cur =
+ ggml_add(ctx0,
+ ggml_get_rows(ctx0, model.d_te, embd),
+ ggml_get_rows(ctx0, model.d_pe, position));
+
+ struct ggml_tensor * inpL = cur;
+
+ for (int il = 0; il < n_layer; ++il) {
+ const auto & layer = model.layers_decoder[il];
+
+ // norm
{
cur = ggml_norm(ctx0, inpL, hparams.eps);
Vcur,
layer.attn_v_b);
- Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
+ Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
- struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
- struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_state,
+ struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
+ struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
( n_ctx)*ggml_element_size(kv_self.v),
- (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
+ (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
struct ggml_tensor * Q =
ggml_permute(ctx0,
- ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
+ ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
0, 2, 1, 3);
struct ggml_tensor * K =
ggml_view_3d(ctx0, kv_self.k,
- n_state/n_head, n_past + N, n_head,
+ n_state/n_head, n_kv, n_head,
ggml_element_size(kv_self.k)*n_state,
ggml_element_size(kv_self.k)*n_state/n_head,
ggml_element_size(kv_self.k)*n_state*n_ctx*il);
//struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
- struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
+ //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
+ struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ, KQ_mask);
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,
- n_past + N, n_state/n_head, n_head,
+ n_kv, n_state/n_head, n_head,
n_ctx*ggml_element_size(kv_self.v),
n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
- il*n_ctx*ggml_element_size(kv_self.v)*n_state);
+ n_ctx*ggml_element_size(kv_self.v)*n_state*il);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
cur = ggml_cpy(ctx0,
KQV_merged,
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
}
// projection
// Kcross is already scaled
struct ggml_tensor * Kcross =
ggml_view_3d(ctx0, wstate.kv_cross.k,
- n_state/n_head, M, n_head,
+ n_state/n_head, n_audio_ctx, n_head,
ggml_element_size(wstate.kv_cross.k)*n_state,
ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
- ggml_element_size(wstate.kv_cross.k)*n_state*M*il);
+ ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
//struct ggml_tensor * Vcross =
// ggml_reshape_3d(ctx0,
- // ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
- // n_state/n_head, n_head, M);
+ // ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
+ // n_state/n_head, n_head, n_audio_ctx);
//struct ggml_tensor * V_trans =
// ggml_cpy(ctx0,
// ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
- // ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
+ // ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
struct ggml_tensor * V =
ggml_view_3d(ctx0, wstate.kv_cross.v,
- M, n_state/n_head, n_head,
- M*ggml_element_size(wstate.kv_cross.v),
- M*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
- il*M*ggml_element_size(wstate.kv_cross.v)*n_state);
+ n_audio_ctx, n_state/n_head, n_head,
+ n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
+ n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
+ n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
// ------
struct ggml_tensor * Q =
ggml_permute(ctx0,
- ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
+ ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
0, 2, 1, 3);
// K * Q
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
- // cur = KQV_merged.contiguous().view(n_state, N)
+ // cur = KQV_merged.contiguous().view(n_state, n_tokens)
cur = ggml_cpy(ctx0,
KQV_merged,
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
}
// projection
}
// compute logits only for the last token
- // comment this line to compute logits for all N tokens
+ // comment this line to compute logits for all n_tokens
// might be useful in the future
- cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
+ //cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
static bool whisper_decode_internal(
whisper_context & wctx,
whisper_state & wstate,
- whisper_decoder & decoder,
- const whisper_token * tokens,
- const int n_tokens,
- const int n_past,
+ const whisper_batch & batch,
const int n_threads,
whisper_abort_callback abort_callback,
void * abort_callback_data) {
const auto & model = wctx.model;
const auto & hparams = model.hparams;
- const int n_vocab = hparams.n_vocab;
+ const int n_vocab = hparams.n_vocab;
+ const int n_tokens = batch.n_tokens;
auto & logits_out = wstate.logits;
struct ggml_tensor * logits;
+ // find KV slot for the batch
+ {
+ auto & kv_self = wstate.kv_self;
+
+ if (!whisper_kv_cache_find_slot(kv_self, batch)) {
+ return false;
+ }
+
+ kv_self.n = whisper_kv_cache_cell_max(kv_self);
+ //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
+ //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
+ }
+
// decoder
{
auto & alloc = wstate.alloc_decode.alloc;
ggml_allocr_reset(alloc);
- ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);
+ ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch);
ggml_allocr_alloc_graph(alloc, gf);
ggml_graph_compute_helper(wstate.backend, gf, n_threads);
}
- // extract logits for all N tokens
- //logits_out.resize(n_tokens*n_vocab);
- //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
- //ggml_backend_tensor_get(logits, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), sizeof(float)*n_vocab);
-
- // extract logits only for the last token
- logits_out.resize(n_vocab);
- //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
- ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab);
+ logits_out.resize(n_tokens*n_vocab);
+ for (int i = 0; i < n_tokens; i++) {
+ if (batch.logits[i] == 0) {
+ continue;
+ }
+ ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab*i), sizeof(float)*(n_vocab*i), sizeof(float)*n_vocab);
+ }
- if (n_tokens > 1) {
+ if (batch.n_tokens > 1) {
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
- // ggml_used_mem(ctx0)/1024.0/1024.0,
- // wstate.get_buf_max_mem(0)/1024.0/1024.0,
- // wstate.get_buf_max_mem(1)/1024.0/1024.0,
- // wstate.get_buf_max_mem(2)/1024.0/1024.0,
- // wstate.get_buf_max_mem(3)/1024.0/1024.0);
+ // ggml_used_mem(ctx0)/1e6,
+ // wstate.get_buf_max_mem(0)/1e6,
+ // wstate.get_buf_max_mem(1)/1e6,
+ // wstate.get_buf_max_mem(2)/1e6,
+ // wstate.get_buf_max_mem(3)/1e6);
}
- if (n_tokens == 1) {
+ if (batch.n_tokens == 1) {
wstate.t_decode_us += ggml_time_us() - t_start_us;
wstate.n_decode++;
+ } else if (batch.n_tokens < 16) {
+ wstate.t_batchd_us += ggml_time_us() - t_start_us;
+ wstate.n_batchd += n_tokens;
} else {
wstate.t_prompt_us += ggml_time_us() - t_start_us;
- wstate.n_prompt++;
+ wstate.n_prompt += n_tokens;
}
return !(abort_callback && abort_callback(abort_callback_data));
}
-
// 500 -> 00:05.000
// 6000 -> 01:00.000
static std::string to_timestamp(int64_t t, bool comma = false) {
state->backend = whisper_backend_init(ctx->params);
- if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
+ // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
+ // in theory, there can be a case where this is not enough, but in practice it should always be enough
+ const int factor = 3;
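+ // with the usual Whisper text context of 448 tokens, the factor of 3 gives roughly 1344 cells
+ // shared by up to WHISPER_MAX_DECODERS sequences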
+
+ if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
delete state;
return nullptr;
}
{
- const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
- WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
+ const size_t memory_size = ggml_nbytes(state->kv_self.k) + ggml_nbytes(state->kv_self.v);
+ WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
}
if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
{
const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v);
- WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
+ WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
}
#ifdef WHISPER_USE_COREML
state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
- state->logits_id.reserve(ctx->model.hparams.n_vocab);
+ state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS);
// TAGS: WHISPER_DECODER_INIT
state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
- state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
- state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
- state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
+ state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
+ state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
+ state->decoders[0].logprobs.reserve (ctx->vocab.n_vocab);
+ state->decoders[0].logits_id.reserve(ctx->model.hparams.n_vocab);
+
+ state->decoders[0].rng = std::mt19937(0);
// conv allocator
{
return whisper_build_graph_conv(*ctx, *state, 0);
});
- WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
+ WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1e6);
}
// encoder allocator
return whisper_build_graph_encoder(*ctx, *state);
});
- WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
+ WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1e6);
}
// cross allocator
return whisper_build_graph_cross(*ctx, *state);
});
- WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
+ WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1e6);
}
// decoder allocator
const int n_tokens = hparams.n_text_ctx;
const int n_past = 0;
- return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
+ whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
+
+ return whisper_build_graph_decoder(*ctx, *state, state->batch);
});
- WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
+ WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1e6);
}
whisper_allocr_graph_realloc(state->alloc_conv, ctx->backend);
whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend);
whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
- state->rng = std::mt19937(0);
-
return state;
}
void whisper_free_state(struct whisper_state * state)
{
if (state) {
+ kv_cache_free(state->kv_self);
kv_cache_free(state->kv_cross);
- for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
- kv_cache_free(state->decoders[i].kv_self);
- }
-
#ifdef WHISPER_USE_COREML
if (state->ctx_coreml != nullptr) {
whisper_coreml_free(state->ctx_coreml);
}
#endif
+ whisper_batch_free(state->batch);
+
whisper_allocr_free(state->alloc_conv);
whisper_allocr_free(state->alloc_encode);
whisper_allocr_free(state->alloc_cross);
}
int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
- const int selected_decoder_id = 0;
+ whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0);
- if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
+ whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1);
+
+ if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
return 1;
}
}
int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
- // TODO: add selected_decoder_id to state
- const int selected_decoder_id = 0;
-
if (ctx->state == nullptr) {
WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
return false;
}
- if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
+ whisper_kv_cache_seq_rm(ctx->state->kv_self, 0, n_past, -1);
+
+ whisper_batch_prep_legacy(ctx->state->batch, tokens, n_tokens, n_past, 0);
+
+ if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->batch, n_threads, nullptr, nullptr)) {
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
return 1;
}
return -7;
}
- auto & logits_id = state->logits_id;
+ auto & logits_id = state->decoders[0].logits_id;
logits_id.clear();
for (const auto & kv : g_lang) {
const int32_t n_sample = std::max(1, ctx->state->n_sample);
const int32_t n_encode = std::max(1, ctx->state->n_encode);
const int32_t n_decode = std::max(1, ctx->state->n_decode);
+ const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
+ WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
}
WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
ctx->state->n_sample = 0;
ctx->state->n_encode = 0;
ctx->state->n_decode = 0;
+ ctx->state->n_batchd = 0;
ctx->state->n_prompt = 0;
}
}
return s.c_str();
}
+//////////////////////////////////
+// Grammar - ported from llama.cpp
+//////////////////////////////////
+
+// Decodes a UTF-8 string which may end in an incomplete sequence. Appends a terminating 0 so the
+// result can be used as a null-terminated sequence of code points. If an invalid sequence is
+// encountered, returns `whisper_partial_utf8.n_remain == -1`.
+std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
+ const char * src,
+ whisper_partial_utf8 partial_start) {
+ static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
+ const char * pos = src;
+ std::vector<uint32_t> code_points;
+ uint32_t value = partial_start.value;
+ int n_remain = partial_start.n_remain;
+
+ // continue previous decode, if applicable
+ while (*pos != 0 && n_remain > 0) {
+ uint8_t next_byte = static_cast<uint8_t>(*pos);
+ if ((next_byte >> 6) != 2) {
+ // invalid sequence, abort
+ code_points.push_back(0);
+ return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
+ }
+ value = (value << 6) + (next_byte & 0x3F);
+ ++pos;
+ --n_remain;
+ }
+
+ if (partial_start.n_remain > 0 && n_remain == 0) {
+ code_points.push_back(value);
+ }
+
+ // decode any subsequent utf-8 sequences, which may end in an incomplete one
+ while (*pos != 0) {
+ uint8_t first_byte = static_cast<uint8_t>(*pos);
+ uint8_t highbits = first_byte >> 4;
+ n_remain = lookup[highbits] - 1;
+
+ if (n_remain < 0) {
+ // invalid sequence, abort
+ code_points.clear();
+ code_points.push_back(0);
+ return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
+ }
+
+ uint8_t mask = (1 << (7 - n_remain)) - 1;
+ value = first_byte & mask;
+ ++pos;
+ while (*pos != 0 && n_remain > 0) {
+ value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+ ++pos;
+ --n_remain;
+ }
+ if (n_remain == 0) {
+ code_points.push_back(value);
+ }
+ }
+ code_points.push_back(0);
+
+ return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
+}
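+// worked example: decoding "\xC3\xA9" ("é") reads the lead byte 0xC3 (a 2-byte sequence per the
+// lookup table), keeps its low bits 0x03, then folds in 0xA9 & 0x3F to produce code point 0xE9;
+// if the input ended right after 0xC3, the function would instead return value = 0x03 and
+// n_remain = 1 so that a later call can complete the sequence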
+
+// returns true iff pos points to the end of one of the definitions of a rule
+static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
+ switch (pos->type) {
+ case WHISPER_GRETYPE_END: return true; // NOLINT
+ case WHISPER_GRETYPE_ALT: return true; // NOLINT
+ default: return false;
+ }
+}
+
+// returns true iff chr satisfies the char range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
+ const whisper_grammar_element * pos,
+ const uint32_t chr) {
+
+ bool found = false;
+ bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
+
+ WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
+
+ do {
+ if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
+ // inclusive range, e.g. [a-z]
+ found = found || (pos->value <= chr && chr <= pos[1].value);
+ pos += 2;
+ } else {
+ // exact char match, e.g. [a] or "a"
+ found = found || pos->value == chr;
+ pos += 1;
+ }
+ } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+
+ return std::make_pair(found == is_positive_char, pos);
+}
+
+// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
+// range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static bool whisper_grammar_match_partial_char(
+ const whisper_grammar_element * pos,
+ const whisper_partial_utf8 partial_utf8) {
+
+ bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
+ WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT);
+
+ uint32_t partial_value = partial_utf8.value;
+ int n_remain = partial_utf8.n_remain;
+
+ // invalid sequence or 7-bit char split across 2 bytes (overlong)
+ if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
+ return false;
+ }
+
+ // range of possible code points this partial UTF-8 sequence could complete to
+ uint32_t low = partial_value << (n_remain * 6);
+ uint32_t high = low | ((1 << (n_remain * 6)) - 1);
+
+ if (low == 0) {
+ if (n_remain == 2) {
+ low = 1 << 11;
+ } else if (n_remain == 3) {
+ low = 1 << 16;
+ }
+ }
+
+ do {
+ if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
+ // inclusive range, e.g. [a-z]
+ if (pos->value <= high && low <= pos[1].value) {
+ return is_positive_char;
+ }
+ pos += 2;
+ } else {
+ // exact char match, e.g. [a] or "a"
+ if (low <= pos->value && pos->value <= high) {
+ return is_positive_char;
+ }
+ pos += 1;
+ }
+ } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+
+ return !is_positive_char;
+}
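+// example: a pending 2-byte sequence whose lead byte was 0xC3 arrives here as value = 0x03,
+// n_remain = 1, so the possible completions span code points 0xC0..0xFF and the character ranges
+// on the stack are checked against that whole interval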
+
+
+// transforms a grammar pushdown stack into N possible stacks, all ending
+// at a character range (terminal element)
+static void whisper_grammar_advance_stack(
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
+ const std::vector<const whisper_grammar_element *> & stack,
+ std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
+
+ if (stack.empty()) {
+ new_stacks.push_back(stack);
+ return;
+ }
+
+ const whisper_grammar_element * pos = stack.back();
+
+ switch (pos->type) {
+ case WHISPER_GRETYPE_RULE_REF: {
+ const size_t rule_id = static_cast<size_t>(pos->value);
+ const whisper_grammar_element * subpos = rules[rule_id].data();
+ do {
+ // init new stack without the top (pos)
+ std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
+ if (!whisper_grammar_is_end_of_sequence(pos + 1)) {
+ // if this rule ref is followed by another element, add that to stack
+ new_stack.push_back(pos + 1);
+ }
+ if (!whisper_grammar_is_end_of_sequence(subpos)) {
+ // if alternate is nonempty, add to stack
+ new_stack.push_back(subpos);
+ }
+ whisper_grammar_advance_stack(rules, new_stack, new_stacks);
+ while (!whisper_grammar_is_end_of_sequence(subpos)) {
+ // scan to end of alternate def
+ subpos++;
+ }
+ if (subpos->type == WHISPER_GRETYPE_ALT) {
+ // there's another alternate def of this rule to process
+ subpos++;
+ } else {
+ break;
+ }
+ } while (true);
+ break;
+ }
+ case WHISPER_GRETYPE_CHAR:
+ case WHISPER_GRETYPE_CHAR_NOT:
+ new_stacks.push_back(stack);
+ break;
+ default:
+ // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range
+ // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
+ // those
+ WHISPER_ASSERT(false);
+ }
+}
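+// i.e. rule references are expanded depth-first until every resulting stack has a terminal
+// (char / char-not) element on top, so matching can then proceed one code point at a time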
+
+// takes a set of possible pushdown stacks on a grammar, which are required to
+// be positioned at a character range (see `whisper_grammar_advance_stack`), and
+// produces the N possible stacks if the given char is accepted at those
+// positions
+static std::vector<std::vector<const whisper_grammar_element *>> whisper_grammar_accept(
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
+ const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+ const uint32_t chr) {
+
+ std::vector<std::vector<const whisper_grammar_element *>> new_stacks;
+
+ for (const auto & stack : stacks) {
+ if (stack.empty()) {
+ continue;
+ }
+
+ auto match = whisper_grammar_match_char(stack.back(), chr);
+ if (match.first) {
+ const whisper_grammar_element * pos = match.second;
+
+ // update top of stack to next element, if any
+ std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
+ if (!whisper_grammar_is_end_of_sequence(pos)) {
+ new_stack.push_back(pos);
+ }
+ whisper_grammar_advance_stack(rules, new_stack, new_stacks);
+ }
+ }
+
+ return new_stacks;
+}
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
+ const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+ const std::vector<whisper_grammar_candidate> & candidates);
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_for_stack(
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
+ const std::vector<const whisper_grammar_element *> & stack,
+ const std::vector<whisper_grammar_candidate> & candidates) {
+
+ std::vector<whisper_grammar_candidate> rejects;
+
+ if (stack.empty()) {
+ for (auto tok : candidates) {
+ if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
+ rejects.push_back(tok);
+ }
+ }
+ return rejects;
+ }
+
+ const whisper_grammar_element * stack_pos = stack.back();
+
+ std::vector<whisper_grammar_candidate> next_candidates;
+ for (auto tok : candidates) {
+ if (*tok.code_points == 0) {
+ // reached end of full codepoints in token, reject iff it ended in a partial sequence
+ // that cannot satisfy this position in grammar
+ if (tok.partial_utf8.n_remain != 0 && !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
+ rejects.push_back(tok);
+ }
+ } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) {
+ next_candidates.push_back({ tok.id, tok.code_points + 1, tok.partial_utf8 });
+ } else {
+ rejects.push_back(tok);
+ }
+ }
+
+ const auto * stack_pos_after = whisper_grammar_match_char(stack_pos, 0).second;
+
+ // update top of stack to next element, if any
+ std::vector<const whisper_grammar_element *> stack_after(stack.begin(), stack.end() - 1);
+ if (!whisper_grammar_is_end_of_sequence(stack_pos_after)) {
+ stack_after.push_back(stack_pos_after);
+ }
+ std::vector<std::vector<const whisper_grammar_element *>> next_stacks;
+ whisper_grammar_advance_stack(rules, stack_after, next_stacks);
+
+ auto next_rejects = whisper_grammar_reject_candidates(rules, next_stacks, next_candidates);
+ for (auto tok : next_rejects) {
+ rejects.push_back({ tok.id, tok.code_points - 1, tok.partial_utf8 });
+ }
+
+ return rejects;
+}
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
+ const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+ const std::vector<whisper_grammar_candidate> & candidates) {
+ if (candidates.empty() || stacks.empty()) {
+ return std::vector<whisper_grammar_candidate>();
+ }
+
+ auto rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
+
+ for (size_t i = 1, size = stacks.size(); i < size; ++i) {
+ rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
+ }
+ return rejects;
+}
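+// a candidate survives if at least one stack accepts it: the rejects of each stack become the
+// candidates of the next, so the final list contains only tokens rejected by every stack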
+
+static struct whisper_grammar whisper_grammar_init(
+ const whisper_grammar_element ** rules,
+ size_t n_rules,
+ size_t i_start_rule) {
+ const whisper_grammar_element * pos;
+
+ // copy rule definitions into vectors
+ std::vector<std::vector<whisper_grammar_element>> vec_rules(n_rules);
+ for (size_t i = 0; i < n_rules; i++) {
+ for (pos = rules[i]; pos->type != WHISPER_GRETYPE_END; pos++) {
+ vec_rules[i].push_back(*pos);
+ }
+ vec_rules[i].push_back({WHISPER_GRETYPE_END, 0});
+ }
+
+ // loop over alternates of start rule to build initial stacks
+ std::vector<std::vector<const whisper_grammar_element *>> stacks;
+ pos = rules[i_start_rule];
+ do {
+ std::vector<const whisper_grammar_element *> stack;
+ if (!whisper_grammar_is_end_of_sequence(pos)) {
+ // if alternate is nonempty, add to stack
+ stack.push_back(pos);
+ }
+ whisper_grammar_advance_stack(vec_rules, stack, stacks);
+ while (!whisper_grammar_is_end_of_sequence(pos)) {
+ // scan to end of alternate def
+ pos++;
+ }
+ if (pos->type == WHISPER_GRETYPE_ALT) {
+ // there's another alternate def of this rule to process
+ pos++;
+ } else {
+ break;
+ }
+ } while (true);
+
+ return { std::move(vec_rules), std::move(stacks), {} };
+}
+
+static void whisper_suppress_invalid_grammar(
+ whisper_context & ctx,
+ const whisper_full_params & params,
+ std::vector<float> & logits,
+ const whisper_grammar & grammar) {
+
+ if (grammar.rules.empty() || grammar.stacks.empty()) {
+ return;
+ }
+
+ //bool allow_eot = false;
+ //for (const auto & stack : grammar.stacks) {
+ // if (stack.empty()) {
+ // allow_eot = true;
+ // break;
+ // }
+ //}
+
+ const whisper_token eot = whisper_token_eot(&ctx);
+
+ std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
+ std::vector<whisper_grammar_candidate> candidates_grammar;
+
+ for (whisper_token id = 0; id < eot; ++id) {
+ const std::string & text = ctx.vocab.id_to_token[id];
+ if (!text.empty()) {
+ candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8));
+ candidates_grammar.push_back({ id, candidates_decoded.back().first.data(), candidates_decoded.back().second });
+ }
+ }
+
+ const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
+
+ for (const auto & reject : rejects) {
+ logits[reject.id] -= params.grammar_penalty;
+ }
+
+ // when the grammar allows a continuation, we penalize the end-of-text token
+ //if (!allow_eot) {
+ // logits[eot] -= params.grammar_penalty;
+ //}
+ //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
+}
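+// the suppression above is soft: rejected tokens have their logit reduced by params.grammar_penalty
+// (100.0f by default) instead of being set to -INFINITY, so the grammar biases decoding without
+// making out-of-grammar tokens impossible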
+
+static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) {
+ if (grammar.rules.empty() || grammar.stacks.empty()) {
+ return;
+ }
+
+ //fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str());
+
+ const std::string & text = ctx.vocab.id_to_token[token];
+
+ if (text.rfind("[_", 0) == 0) {
+ // fprintf(stderr, " (skipped)\n");
+ return;
+ }
+ // fprintf(stderr, "\n");
+
+ // Note terminating 0 in decoded string
+ const auto decoded = decode_utf8(text.c_str(), grammar.partial_utf8);
+ const auto & code_points = decoded.first;
+ for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+ grammar.stacks = whisper_grammar_accept(grammar.rules, grammar.stacks, *it);
+ }
+ grammar.partial_utf8 = decoded.second;
+}
+
+//////////////
+// END grammar
+//////////////
+
////////////////////////////////////////////////////////////////////////////
struct whisper_context_params * whisper_context_default_params_by_ref() {
/*.translate =*/ false,
/*.no_context =*/ true,
+ /*.no_timestamps =*/ false,
/*.single_segment =*/ false,
/*.print_special =*/ false,
/*.print_progress =*/ true,
/*.max_initial_ts =*/ 1.0f,
/*.length_penalty =*/ -1.0f,
- /*.temperature_inc =*/ 0.4f,
+ /*.temperature_inc =*/ 0.2f,
/*.entropy_thold =*/ 2.4f,
/*.logprob_thold =*/ -1.0f,
/*.no_speech_thold =*/ 0.6f,
/*.logits_filter_callback =*/ nullptr,
/*.logits_filter_callback_user_data =*/ nullptr,
+
+ /*.grammar_rules =*/ nullptr,
+ /*.n_grammar_rules =*/ 0,
+ /*.i_start_rule =*/ 0,
+ /*.grammar_penalty =*/ 100.0f,
};
switch (strategy) {
case WHISPER_SAMPLING_GREEDY:
{
result.greedy = {
- /*.best_of =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
+ /*.best_of =*/ 5,
};
} break;
case WHISPER_SAMPLING_BEAM_SEARCH:
{
result.beam_search = {
- /*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
+ /*.beam_size =*/ 5,
/*.patience =*/ -1.0f,
};
// process the logits for the selected decoder
// - applies logit filters
// - computes logprobs and probs
+// TODO: optimize
static void whisper_process_logits(
struct whisper_context & ctx,
struct whisper_state & state,
- const struct whisper_full_params params,
struct whisper_decoder & decoder,
+ const struct whisper_full_params params,
float temperature) {
const auto & vocab = ctx.vocab;
const auto & tokens_cur = decoder.sequence.tokens;
auto & logprobs = decoder.logprobs;
{
logits.resize(n_logits);
- memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float));
+ memcpy(logits.data(), state.logits.data() + decoder.i_batch*n_logits, n_logits*sizeof(float));
if (temperature > 0.0f) {
for (int i = 0; i < n_logits; i++) {
// suppress <|notimestamps|> token
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
logits[vocab.token_not] = -INFINITY;
+ if (params.no_timestamps) {
+ for (int i = vocab.token_beg; i < n_logits; ++i) {
+ logits[i] = -INFINITY;
+ }
+ }
// suppress sot and nosp tokens
logits[vocab.token_sot] = -INFINITY;
logits[vocab.token_transcribe] = -INFINITY;
logits[vocab.token_prev] = -INFINITY;
+ // suppress lang tokens
+ for (size_t i = 0; i < g_lang.size(); ++i) {
+ logits[whisper_token_lang(&ctx, i)] = -INFINITY;
+ }
+
+ // suppress prev token
+ logits[vocab.token_prev] = -INFINITY;
+
if (params.logits_filter_callback) {
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
}
logits[i] = -INFINITY;
logprobs[i] = -INFINITY;
}
+ } else {
+ if (params.n_grammar_rules > 0) {
+ whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
+
+ // populate the logprobs array (log_softmax)
+ {
+ const float logit_max = *std::max_element(logits.begin(), logits.end());
+ float logsumexp = 0.0f;
+ for (int i = 0; i < n_logits; ++i) {
+ if (logits[i] > -INFINITY) {
+ logsumexp += expf(logits[i] - logit_max);
+ }
+ }
+ logsumexp = logf(logsumexp) + logit_max;
+
+ for (int i = 0; i < n_logits; ++i) {
+ if (logits[i] > -INFINITY) {
+ logprobs[i] = logits[i] - logsumexp;
+ } else {
+ logprobs[i] = -INFINITY;
+ }
+ }
+ }
+ }
}
}
}
#if 0
// print first 100 logits - token string : logit
- for (int i = 0; i < 100; i++) {
- const auto token = vocab.id_to_token.at(i);
- const auto prob = probs[i];
- const auto logit = logits[i];
- const auto logprob = logprobs[i];
- printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
+ //for (int i = 0; i < 10; i++) {
+ // const auto token = vocab.id_to_token.at(i);
+ // const auto prob = probs[i];
+ // const auto logit = logits[i];
+ // const auto logprob = logprobs[i];
+ // printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
+ //}
+
+ // print sorted
+ {
+ std::vector<std::pair<float, int>> pairs;
+
+ for (int i = 0; i < n_logits; ++i) {
+ pairs.push_back(std::make_pair(probs[i], i));
+ }
+
+ std::sort(pairs.begin(), pairs.end(), [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
+ return a.first > b.first;
+ });
+
+ for (int i = 0; i < 10; i++) {
+ const auto token = vocab.id_to_token.at(pairs[i].second);
+ const auto prob = pairs[i].first;
+ const auto logit = logits[pairs[i].second];
+ const auto logprob = logprobs[pairs[i].second];
+ printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str());
+ }
+
+ printf("----------------\n");
}
// "And", "and", " And", " and"
- printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
- printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
- printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
- printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
- printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
-
- printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
- printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
- printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
- printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
- printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
-
- printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
- printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
- printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
- printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
- printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
+ //printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
+ //printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
+ //printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
+ //printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
+ //printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
+
+ //printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
+ //printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
+ //printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
+ //printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
+ //printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
+
+ //printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
+ //printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
+ //printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
+ //printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
+ //printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
#endif
}
static whisper_token_data whisper_sample_token(
whisper_context & ctx,
- whisper_state & state,
- const whisper_decoder & decoder,
+ whisper_decoder & decoder,
bool best) {
whisper_token_data result = {
} else {
std::discrete_distribution<> dist(probs.begin(), probs.end());
- result.id = dist(state.rng);
+ result.id = dist(decoder.rng);
result.p = probs[result.id];
result.plog = logprobs[result.id];
}
result.pt = result.p;
}
- state.n_sample++;
-
return result;
}
static std::vector<whisper_token_data> whisper_sample_token_topk(
whisper_context & ctx,
- whisper_state & state,
- const whisper_decoder & decoder,
+ whisper_decoder & decoder,
int k) {
const auto & vocab = ctx.vocab;
const int n_logits = vocab.n_vocab;
- auto & logits_id = state.logits_id;
+ auto & logits_id = decoder.logits_id;
logits_id.resize(n_logits);
for (int i = 0; i < n_logits; ++i) {
ptsum = sum_ts;
}
+ std::discrete_distribution<> dist(probs.begin(), probs.end());
+
for (int i = 0; i < k; ++i) {
- const auto id = logits_id[i].second;
+ const auto id = dist(decoder.rng);
+ //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
}
}
- state.n_sample++;
-
return result;
}
}
}
-static bool whisper_kv_swap_fast(
- std::vector<int> & view,
- whisper_decoder src[],
- std::vector<kv_buf> & kv_swap_bufs,
- const int & n_decoders) {
- WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
-
- // (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
- std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
-
- // (buffer->decoder or decoder->decoder)
- std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
-
- // (decoder<->decoder)
- std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
- std::vector<whisper_pair<int, int>> p_swap_vec;
- p_swap_vec.reserve(n_decoders);
-
- // see https://github.com/ggerganov/whisper.cpp/wiki
- for (int i = 0; i < n_decoders; i++) {
- // zero-copy (no modification)
- if (i == view[i] || view[i] < 0) {
- continue;
- }
-
- bool is_one_copy = true;
- // since we modify data sequentially, we only consider decoder indices after current index
- for (int j = i + 1; j < n_decoders; j++) {
- if (i == view[j]) {
- // detect symmetric diagram
- if (j == view[i]) {
- p_swap_set.insert(i);
- p_swap_set.insert(j);
- p_swap_vec.emplace_back(i, j);
- } else {
- two_copy.insert(i);
- is_one_copy = false;
- }
- break;
- }
- }
- if (is_one_copy) {
- one_copy.insert(i);
- }
- }
-
- kv_swap_bufs.resize(n_decoders);
-
- for (int i = 0; i < n_decoders; i++) {
- kv_swap_bufs[i].k.resize(ggml_nbytes(src[i].kv_self.k));
- kv_swap_bufs[i].v.resize(ggml_nbytes(src[i].kv_self.v));
- }
-
- for (auto & i : two_copy) {
- // make a copy of KV caches
- WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
- //memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
- //memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
- ggml_backend_tensor_get(src[i].kv_self.k, kv_swap_bufs[i].k.data(), 0, kv_swap_bufs[i].k.size());
- ggml_backend_tensor_get(src[i].kv_self.v, kv_swap_bufs[i].v.data(), 0, kv_swap_bufs[i].v.size());
- }
-
- // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
- for (auto & i : two_copy) {
- // skip the decoder indices that require pointer swapping
- if (p_swap_set.find(i) != p_swap_set.end()) {
- continue;
- }
-
- if (two_copy.find(view[i]) != two_copy.end()) {
- // modify KV caches of decoder using data from kv_swap_bufs
- WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
- //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
- //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
- ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
- ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
- } else {
- // modify KV caches of decoder using data from correspond decoder KV caches directly
- WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
- //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
- //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
- ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
- ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
- }
- }
-
- // then modify one-copy decoder KV caches
- for (auto & i : one_copy) {
- // skip the decoder indices that require pointer swapping
- if (p_swap_set.find(i) != p_swap_set.end()) {
- continue;
- }
-
- if (two_copy.find(view[i]) != two_copy.end()) {
- // modify KV caches of decoder using data from kv_swap_bufs
- WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
- //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
- //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
- ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
- ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
- } else {
- // modify KV caches of decoder using data from correspond decoder KV caches directly
- WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
- //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
- //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
- ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
- ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
- }
- }
-
- // swap the pointers
- for (auto & i : p_swap_vec) {
- WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
- std::swap(src[i.first].kv_self, src[i.second].kv_self);
- }
-
- return true;
-}
-
int whisper_full_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
n_decoders = std::max(1, n_decoders);
+ if (n_decoders > WHISPER_MAX_DECODERS) {
+ WHISPER_LOG_ERROR("%s: too many decoders requested (%d), max = %d\n", __func__, n_decoders, WHISPER_MAX_DECODERS);
+ return -4;
+ }
+
// TAGS: WHISPER_DECODER_INIT
for (int j = 1; j < n_decoders; j++) {
auto & decoder = state->decoders[j];
- if (decoder.kv_self.ctx == nullptr) {
- decoder.kv_self = state->decoders[0].kv_self;
- if (!kv_cache_reinit(decoder.kv_self, ctx->backend)) {
- WHISPER_LOG_ERROR("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
- return -4;
- }
-
- WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
+ decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
- decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
+ decoder.probs.resize (ctx->vocab.n_vocab);
+ decoder.logits.resize (ctx->vocab.n_vocab);
+ decoder.logprobs.resize(ctx->vocab.n_vocab);
+ decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
- decoder.probs.resize (ctx->vocab.n_vocab);
- decoder.logits.resize (ctx->vocab.n_vocab);
- decoder.logprobs.resize(ctx->vocab.n_vocab);
- }
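+ // each decoder now keeps its own RNG (previously a single RNG in whisper_state
+ // was shared), so sampling can run independently per decoder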
+ decoder.rng = std::mt19937(0);
}
// the accumulated text context so far
state->exp_n_audio_ctx = params.audio_ctx;
// these tokens determine the task that will be performed
- std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
+ std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx), };
if (whisper_is_multilingual(ctx)) {
const int lang_id = whisper_lang_id(params.language);
}
}
+ // distilled models require the "no_timestamps" token
{
const bool is_distil = ctx->model.hparams.n_text_layer == 2;
-
- // distilled models require the "no_timestamps" token
- // TODO: add input parameter (#1229)
- if (is_distil) {
+ if (is_distil && !params.no_timestamps) {
WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__);
- prompt_init.push_back(whisper_token_not(ctx));
+ params.no_timestamps = true;
}
}
+ if (params.no_timestamps) {
+ prompt_init.push_back(whisper_token_not(ctx));
+ }
+
int seek = seek_start;
std::vector<whisper_token> prompt;
bool has_ts;
whisper_sequence sequence;
+ whisper_grammar grammar;
};
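+ // one set of beam candidates per decoder - filled concurrently during the
+ // parallel sampling below and merged into a single list afterwards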
+ std::vector<std::vector<beam_candidate>> bc_per_dec(n_decoders);
std::vector<beam_candidate> beam_candidates;
// main loop
n_decoders_cur = std::max(1, n_decoders_cur);
- WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
+ WHISPER_PRINT_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur);
// TAGS: WHISPER_DECODER_INIT
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
- decoder.kv_self.n = 0;
-
decoder.sequence.tokens.clear();
decoder.sequence.result_len = 0;
decoder.sequence.sum_logprobs_all = 0.0;
decoder.failed = false;
decoder.completed = false;
decoder.has_ts = false;
+
+ if (params.grammar_rules != nullptr) {
+ decoder.grammar = whisper_grammar_init(params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
+ } else {
+ decoder.grammar = {};
+ }
}
// init prompt and kv cache for the current iteration
- // run whisper_decoder() only for decoder 0 and copy the results for the other decoders
+ // TODO: do not recompute the prompt if it is the same as in the previous iteration
{
prompt.clear();
}
WHISPER_PRINT_DEBUG("\n\n");
- if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
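+ // clear the self-attention KV cache, then decode the prompt once as
+ // sequence 0; its KV data is copied to the other decoders' sequences below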
+ whisper_kv_cache_clear(state->kv_self);
+
+ whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
+
+ if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
return -7;
}
{
const int64_t t_start_sample_us = ggml_time_us();
- whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
+ state->decoders[0].i_batch = prompt.size() - 1;
- state->decoders[0].kv_self.n += prompt.size();
+ whisper_process_logits(*ctx, *state, state->decoders[0], params, t_cur);
for (int j = 1; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
- // TODO: fix CUDA
- //memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
- //memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
- ggml_backend_tensor_copy(state->decoders[0].kv_self.k, decoder.kv_self.k);
- ggml_backend_tensor_copy(state->decoders[0].kv_self.v, decoder.kv_self.v);
-
- decoder.kv_self.n += prompt.size();
+ whisper_kv_cache_seq_cp(state->kv_self, 0, j, -1, -1);
memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
const int64_t t_start_sample_us = ggml_time_us();
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
- beam_candidates.clear();
+ for (auto & bc : bc_per_dec) {
+ bc.clear();
+ }
}
- // generate new sequence candidates for each decoder
- for (int j = 0; j < n_decoders_cur; ++j) {
- auto & decoder = state->decoders[j];
+ // sampling
+ // TODO: avoid memory allocations, optimize, avoid threads?
+ {
+ std::atomic<int> j_cur(0);
- if (decoder.completed || decoder.failed) {
- continue;
- }
+ auto process = [&]() {
+ while (true) {
+ const int j = j_cur.fetch_add(1);
- switch (params.strategy) {
- case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
- {
- if (t_cur < 1e-6f) {
- decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
- } else {
- decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
- }
+ if (j >= n_decoders_cur) {
+ break;
+ }
- decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
- } break;
- case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
- {
- const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
+ auto & decoder = state->decoders[j];
- for (const auto & token : tokens_new) {
- beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
- beam_candidates.back().sequence.tokens.push_back(token);
- beam_candidates.back().sequence.sum_logprobs_all += token.plog;
+ if (decoder.completed || decoder.failed) {
+ continue;
+ }
- //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all);
- }
- } break;
+ switch (params.strategy) {
+ case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
+ {
+ if (t_cur < 1e-6f) {
+ decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
+ } else {
+ decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
+ }
+
+ decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
+ } break;
+ case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
+ {
+ const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
+
+ for (const auto & token : tokens_new) {
+ bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
+ bc_per_dec[j].back().sequence.tokens.push_back(token);
+ bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
+ }
+ } break;
+ };
+ }
};
+
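+ // spawn n_threads - 1 workers and run process() on the calling thread too;
+ // the atomic counter j_cur hands each thread the next unprocessed decoder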
+ const int n_threads = std::min(params.n_threads, n_decoders_cur);
+
+ if (n_threads == 1) {
+ process();
+ } else {
+ std::vector<std::thread> threads(n_threads - 1);
+
+ for (int t = 0; t < n_threads - 1; ++t) {
+ threads[t] = std::thread(process);
+ }
+
+ process();
+
+ for (int t = 0; t < n_threads - 1; ++t) {
+ threads[t].join();
+ }
+ }
+ }
+
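+ // merge the per-decoder candidates into a single list for beam selection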
+ beam_candidates.clear();
+ for (const auto & bc : bc_per_dec) {
+ beam_candidates.insert(beam_candidates.end(), bc.begin(), bc.end());
+
+ if (!bc.empty()) {
+ state->n_sample += 1;
+ }
}
// for beam-search, choose the top candidates and update the KV caches
});
uint32_t cur_c = 0;
- std::vector<int> decoder_idx(n_decoders_cur, -1);
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
continue;
}
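+ // if all candidates have been consumed, reuse them from the beginning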
+ if (cur_c >= beam_candidates.size()) {
+ cur_c = 0;
+ }
+
auto & cur = beam_candidates[cur_c++];
while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
++cur_c;
}
- decoder.sequence = cur.sequence;
decoder.seek_delta = cur.seek_delta;
decoder.has_ts = cur.has_ts;
+ decoder.sequence = cur.sequence;
+ decoder.grammar = cur.grammar;
+
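+ // stage the selected candidate's KV data under a temporary sequence id, so
+ // that a later copy cannot clobber data another decoder still has to read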
+ whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
- decoder_idx[j] = cur.decoder_idx;
WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
__func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
}
- // update KV caches
- whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur);
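+ // once all candidates are staged, move each staged copy back into its
+ // decoder's own sequence and drop the temporary sequence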
+ for (int j = 0; j < n_decoders_cur; ++j) {
+ auto & decoder = state->decoders[j];
+
+ if (decoder.completed || decoder.failed) {
+ continue;
+ }
+
+ whisper_kv_cache_seq_rm(state->kv_self, j, -1, -1);
+ whisper_kv_cache_seq_cp(state->kv_self, WHISPER_MAX_DECODERS + j, j, -1, -1);
+ whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j, -1, -1);
+ }
}
// update the decoder state
has_ts = true;
}
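+ // advance the decoder's grammar state with the newly accepted token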
+ whisper_grammar_accept_token(*ctx, decoder.grammar, token.id);
+
#ifdef WHISPER_DEBUG
{
const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
state->t_sample_us += ggml_time_us() - t_start_sample_us;
// obtain logits for the next token
- for (int j = 0; j < n_decoders_cur; ++j) {
- auto & decoder = state->decoders[j];
+ {
+ auto & batch = state->batch;
- if (decoder.failed || decoder.completed) {
- continue;
- }
+ batch.n_tokens = 0;
- decoder.tokens_tmp.resize(1);
- decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
+ const int n_past = prompt.size() + i;
- //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
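+ // pack the last sampled token of every active decoder into a single batch;
+ // each decoder uses its own sequence id, so one decode call serves them all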
+ for (int j = 0; j < n_decoders_cur; ++j) {
+ auto & decoder = state->decoders[j];
- if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+ if (decoder.failed || decoder.completed) {
+ continue;
+ }
+
+ //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
+
+ decoder.i_batch = batch.n_tokens;
+
+ batch.token [batch.n_tokens] = decoder.sequence.tokens.back().id;
+ batch.pos [batch.n_tokens] = n_past;
+ batch.n_seq_id[batch.n_tokens] = 1;
+ batch.seq_id [batch.n_tokens][0] = j;
+ batch.logits [batch.n_tokens] = 1;
+ batch.n_tokens++;
+ }
+
+ assert(batch.n_tokens > 0);
+
+ if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
return -8;
}
+ const int64_t t_start_sample_us = ggml_time_us();
+
+ // TODO: avoid memory allocations, optimize, avoid threads?
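+ // process the new logits of each active decoder in parallel, using the same
+ // atomic-counter work distribution as in the sampling step above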
{
- const int64_t t_start_sample_us = ggml_time_us();
+ std::atomic<int> j_cur(0);
+
+ auto process = [&]() {
+ while (true) {
+ const int j = j_cur.fetch_add(1);
- whisper_process_logits(*ctx, *state, params, decoder, t_cur);
+ if (j >= n_decoders_cur) {
+ break;
+ }
- ++decoder.kv_self.n;
+ auto & decoder = state->decoders[j];
+
+ if (decoder.failed || decoder.completed) {
+ continue;
+ }
+
+ whisper_process_logits(*ctx, *state, decoder, params, t_cur);
+ }
+ };
- state->t_sample_us += ggml_time_us() - t_start_sample_us;
+ const int n_threads = std::min(params.n_threads, n_decoders_cur);
+
+ if (n_threads == 1) {
+ process();
+ } else {
+ std::vector<std::thread> threads(n_threads - 1);
+
+ for (int t = 0; t < n_threads - 1; ++t) {
+ threads[t] = std::thread(process);
+ }
+
+ process();
+
+ for (int t = 0; t < n_threads - 1; ++t) {
+ threads[t].join();
+ }
+ }
}
+
+ state->t_sample_us += ggml_time_us() - t_start_sample_us;
}
}
ctx->state->t_sample_us += states[i]->t_sample_us;
ctx->state->t_encode_us += states[i]->t_encode_us;
ctx->state->t_decode_us += states[i]->t_decode_us;
+ ctx->state->t_batchd_us += states[i]->t_batchd_us;
ctx->state->t_prompt_us += states[i]->t_prompt_us;
ctx->state->n_sample += states[i]->n_sample;
ctx->state->n_encode += states[i]->n_encode;
ctx->state->n_decode += states[i]->n_decode;
+ ctx->state->n_batchd += states[i]->n_batchd;
ctx->state->n_prompt += states[i]->n_prompt;
whisper_free_state(states[i]);
size_t n = 20;
size_t arr = n_threads > 0 ? 1024llu : n_threads; // trick to avoid compiler optimizations
- // 1GB MB array
- const size_t size = arr*1024llu*1024llu;
+ // 1GB array
+ const size_t size = arr*1e6;
// single-thread
{
src[rand() % size] = rand() % 256;
}
- snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s (1 thread)\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
+ snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s (1 thread)\n", (double) (n*size)/(tsum*1e9));
s += strbuf;
// needed to prevent the compiler from optimizing the memcpy away