From: Georgi Gerganov
Date: Fri, 17 Nov 2023 08:00:11 +0000 (+0200)
Subject: sync : whisper.cpp (update whisper example + minor) (#613)
X-Git-Tag: upstream/0.0.1642~1194
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=98b3155750db8c23e1cc577fdee4b7e67f3c8ef0;p=pkg%2Fggml%2Fsources%2Fggml

sync : whisper.cpp (update whisper example + minor) (#613)

ggml-ci
---

diff --git a/examples/common-ggml.cpp b/examples/common-ggml.cpp
index 33ae03ae..e69bd510 100644
--- a/examples/common-ggml.cpp
+++ b/examples/common-ggml.cpp
@@ -9,6 +9,11 @@ static const std::map<std::string, enum ggml_ftype> GGML_FTYPE_MAP = {
     {"q5_0", GGML_FTYPE_MOSTLY_Q5_0},
     {"q5_1", GGML_FTYPE_MOSTLY_Q5_1},
     {"q8_0", GGML_FTYPE_MOSTLY_Q8_0},
+    {"q2_k", GGML_FTYPE_MOSTLY_Q2_K},
+    {"q3_k", GGML_FTYPE_MOSTLY_Q3_K},
+    {"q4_k", GGML_FTYPE_MOSTLY_Q4_K},
+    {"q5_k", GGML_FTYPE_MOSTLY_Q5_K},
+    {"q6_k", GGML_FTYPE_MOSTLY_Q6_K},
 };

 void ggml_print_ftypes(FILE * fp) {
@@ -48,15 +53,15 @@ bool ggml_common_quantize_0(
         case GGML_FTYPE_MOSTLY_Q5_0: qtype = GGML_TYPE_Q5_0; break;
         case GGML_FTYPE_MOSTLY_Q5_1: qtype = GGML_TYPE_Q5_1; break;
         case GGML_FTYPE_MOSTLY_Q8_0: qtype = GGML_TYPE_Q8_0; break;
+        case GGML_FTYPE_MOSTLY_Q2_K: qtype = GGML_TYPE_Q2_K; break;
+        case GGML_FTYPE_MOSTLY_Q3_K: qtype = GGML_TYPE_Q3_K; break;
+        case GGML_FTYPE_MOSTLY_Q4_K: qtype = GGML_TYPE_Q4_K; break;
+        case GGML_FTYPE_MOSTLY_Q5_K: qtype = GGML_TYPE_Q5_K; break;
+        case GGML_FTYPE_MOSTLY_Q6_K: qtype = GGML_TYPE_Q6_K; break;
         case GGML_FTYPE_UNKNOWN:
         case GGML_FTYPE_ALL_F32:
         case GGML_FTYPE_MOSTLY_F16:
         case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16:
-        case GGML_FTYPE_MOSTLY_Q2_K:
-        case GGML_FTYPE_MOSTLY_Q3_K:
-        case GGML_FTYPE_MOSTLY_Q4_K:
-        case GGML_FTYPE_MOSTLY_Q5_K:
-        case GGML_FTYPE_MOSTLY_Q6_K:
                 {
                     fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
                     return false;
@@ -167,24 +172,17 @@ bool ggml_common_quantize_0(

             switch ((ggml_type) ttype) {
                 case GGML_TYPE_Q4_0:
-                    {
-                        cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
-                    } break;
                 case GGML_TYPE_Q4_1:
-                    {
-                        cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
-                    } break;
                 case GGML_TYPE_Q5_0:
-                    {
-                        cur_size = ggml_quantize_q5_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
-                    } break;
                 case GGML_TYPE_Q5_1:
-                    {
-                        cur_size = ggml_quantize_q5_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
-                    } break;
                 case GGML_TYPE_Q8_0:
+                case GGML_TYPE_Q2_K:
+                case GGML_TYPE_Q3_K:
+                case GGML_TYPE_Q4_K:
+                case GGML_TYPE_Q5_K:
+                case GGML_TYPE_Q6_K:
                     {
-                        cur_size = ggml_quantize_q8_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                        cur_size = ggml_quantize_chunk((ggml_type) ttype, data_f32.data(), work.data(), 0, nelements, hist_cur.data());
                     } break;
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
@@ -192,11 +190,6 @@ bool ggml_common_quantize_0(
                 case GGML_TYPE_I16:
                 case GGML_TYPE_I32:
                 case GGML_TYPE_Q8_1:
-                case GGML_TYPE_Q2_K:
-                case GGML_TYPE_Q3_K:
-                case GGML_TYPE_Q4_K:
-                case GGML_TYPE_Q5_K:
-                case GGML_TYPE_Q6_K:
                 case GGML_TYPE_Q8_K:
                 case GGML_TYPE_COUNT:
                     {
diff --git a/examples/whisper/main.cpp b/examples/whisper/main.cpp
index e43dfe3f..98af5839 100644
--- a/examples/whisper/main.cpp
+++ b/examples/whisper/main.cpp
@@ -62,8 +62,8 @@ struct whisper_params {
     int32_t progress_step = 5;
     int32_t max_context = -1;
     int32_t max_len = 0;
-    int32_t best_of = 2;
-    int32_t beam_size = -1;
+    int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
+    int32_t beam_size =
whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size; float word_thold = 0.01f; float entropy_thold = 2.40f; @@ -925,9 +925,9 @@ int main(int argc, char ** argv) { if (params.detect_language) { params.language = "auto"; } - fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n", + fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n", __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, - params.n_threads, params.n_processors, + params.n_threads, params.n_processors, params.beam_size, params.best_of, params.language.c_str(), params.translate ? "translate" : "transcribe", params.tinydiarize ? "tdrz = 1, " : "", diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp index 1c7d7e94..e2197ff2 100644 --- a/examples/whisper/whisper.cpp +++ b/examples/whisper/whisper.cpp @@ -20,6 +20,7 @@ #include "ggml-alloc.h" #include "ggml-backend.h" +#include #include #include #define _USE_MATH_DEFINES @@ -147,7 +148,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text //#define WHISPER_USE_FLASH_ATTN //#define WHISPER_USE_FLASH_FF -#define WHISPER_MAX_DECODERS 16 +#define WHISPER_MAX_DECODERS 8 #define WHISPER_MAX_NODES 4096 // @@ -406,6 +407,121 @@ struct whisper_segment { bool speaker_turn_next; }; +struct whisper_batch { + int32_t n_tokens; + + whisper_token * token; + whisper_pos * pos; + int32_t * n_seq_id; + whisper_seq_id ** seq_id; // null terminated + int8_t * logits; +}; + +static struct whisper_batch whisper_batch_init(int32_t n_tokens, int32_t n_seq_max) { + whisper_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, }; + + batch.token = (whisper_token * ) malloc(sizeof(whisper_token) * (n_tokens)); + batch.pos = (whisper_pos *) malloc(sizeof(whisper_pos) * (n_tokens)); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * (n_tokens)); + batch.seq_id = (whisper_seq_id **) malloc(sizeof(whisper_seq_id *) * (n_tokens + 1)); + for (int i = 0; i < n_tokens; ++i) { + batch.seq_id[i] = (whisper_seq_id *) malloc(sizeof(whisper_seq_id) * n_seq_max); + } + batch.seq_id[n_tokens] = nullptr; + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + + return batch; +} + +static void whisper_batch_free(struct whisper_batch batch) { + if (batch.token) free(batch.token); + if (batch.pos) free(batch.pos); + if (batch.n_seq_id) free(batch.n_seq_id); + if (batch.seq_id) { + for (int i = 0; batch.seq_id[i]; ++i) { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) free(batch.logits); +} + +static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past, int seq_id) { + batch.n_tokens = n_tokens; + for (int i = 0; i < n_tokens; ++i) { + if (tokens) { + batch.token[i] = tokens[i]; + } + batch.pos [i] = n_past + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i][0] = seq_id; + batch.logits [i] = 0; + } + batch.logits[n_tokens - 1] = 1; +} + +// replace std::pair by using customized pair struct (reason: std::pair is very slow) +template +struct whisper_pair { + A first; + B second; + + // Define a constructor that takes two arguments. + whisper_pair(const A& a, const B& b) : first(a), second(b) {} + // Define a constructor that takes no argument. 
+ whisper_pair() : first(A()), second(B()) {} +}; + +// ggml_allocr wrapper for whisper usage +struct whisper_allocr { + ggml_allocr * alloc = nullptr; + + std::vector meta; + + ggml_backend_buffer_t buffer; +}; + +static size_t whisper_allocr_size(struct whisper_allocr & allocr) { + return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc); +} + +// measure the memory usage of a graph and prepare the allocr's internal data buffer +static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function && get_graph) { + auto & alloc = allocr.alloc; + auto & meta = allocr.meta; + + alloc = ggml_allocr_new_measure_from_backend(backend); + + meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead()); + + ggml_allocr_alloc_graph(alloc, get_graph()); +} + +static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) { + if (allocr.alloc == nullptr) { + // this can be null if we use external encoder like CoreML or OpenVINO + return; + } + + auto & alloc = allocr.alloc; + auto & buffer = allocr.buffer; + + size_t size = ggml_allocr_max_size(alloc); + + ggml_allocr_free(alloc); + + buffer = ggml_backend_alloc_buffer(backend, size); + alloc = ggml_allocr_new_from_buffer(buffer); +} + +static void whisper_allocr_free(struct whisper_allocr & allocr) { + if (allocr.alloc) { + ggml_allocr_free(allocr.alloc); + ggml_backend_buffer_free(allocr.buffer); + allocr.alloc = nullptr; + } +} + // medium // hparams: { // 'n_mels': 80, @@ -523,15 +639,31 @@ struct whisper_layer_decoder { struct ggml_tensor * mlp_1_b; }; +struct whisper_kv_cell { + whisper_pos pos = -1; + + std::set seq_id; + + bool has_seq_id(const whisper_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } +}; + struct whisper_kv_cache { + uint32_t head = 0; + uint32_t size = 0; + + // computed before each graph build + uint32_t n = 0; + + std::vector cells; + struct ggml_tensor * k; struct ggml_tensor * v; struct ggml_context * ctx; ggml_backend_buffer_t buffer; - - int n; // number of tokens currently in the cache }; struct whisper_model { @@ -579,6 +711,25 @@ struct whisper_model { std::map tensors; }; +struct whisper_partial_utf8 { + uint32_t value; // bit value so far (unshifted) + int n_remain; // num bytes remaining; -1 indicates invalid sequence +}; + +struct whisper_grammar { + /*const*/ std::vector> rules; + std::vector> stacks; + + // buffer for partially generated UTF-8 sequence from accepted tokens + whisper_partial_utf8 partial_utf8; +}; + +struct whisper_grammar_candidate { + whisper_token id; + const uint32_t * code_points; + whisper_partial_utf8 partial_utf8; +}; + struct whisper_sequence { std::vector tokens; @@ -594,12 +745,13 @@ struct whisper_sequence { // TAGS: WHISPER_DECODER_INIT struct whisper_decoder { - // each decoder keeps its own KV-cache - whisper_kv_cache kv_self; - // the currently generated sequence of tokens whisper_sequence sequence; + // grammar parse state of generated sequence of tokens + whisper_grammar grammar; + + int i_batch; // the index of the token in the current batch int seek_delta; // the window shift found so far based on the decoded timestamp tokens bool failed; // has the current segment failed to decode? 
@@ -611,100 +763,40 @@ struct whisper_decoder { std::vector logits; std::vector logprobs; - std::vector tokens_tmp; // used for whisper_decode calls -}; - -// replace std::pair by using customized pair struct (reason: std::pair is very slow) -template -struct whisper_pair { - A first; - B second; - - // Define a constructor that takes two arguments. - whisper_pair(const A& a, const B& b) : first(a), second(b) {} - // Define a constructor that takes no argument. - whisper_pair() : first(A()), second(B()) {} -}; - -// beam-search helpers -struct kv_buf { - std::vector k; - std::vector v; -}; - -// ggml_allocr wrapper for whisper usage -struct whisper_allocr { - ggml_allocr * alloc = nullptr; - - std::vector meta; + // work container used to avoid memory allocations + std::vector> logits_id; - ggml_backend_buffer_t buffer; + mutable std::mt19937 rng; // used for sampling at t > 0.0 }; -static size_t whisper_allocr_size(struct whisper_allocr & allocr) { - return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc); -} - -// measure the memory usage of a graph and prepare the allocr's internal data buffer -static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function && get_graph) { - auto & alloc = allocr.alloc; - auto & meta = allocr.meta; - - alloc = ggml_allocr_new_measure_from_backend(backend); - - meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead()); - - ggml_allocr_alloc_graph(alloc, get_graph()); -} - -static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) { - if (allocr.alloc == nullptr) { - // this can be null if we use external encoder like CoreML or OpenVINO - return; - } - - auto & alloc = allocr.alloc; - auto & buffer = allocr.buffer; - - size_t size = ggml_allocr_max_size(alloc); - - ggml_allocr_free(alloc); - - buffer = ggml_backend_alloc_buffer(backend, size); - alloc = ggml_allocr_new_from_buffer(buffer); -} - -static void whisper_allocr_free(struct whisper_allocr & allocr) { - if (allocr.alloc) { - ggml_allocr_free(allocr.alloc); - ggml_backend_buffer_free(allocr.buffer); - allocr.alloc = nullptr; - } -} - struct whisper_state { int64_t t_sample_us = 0; int64_t t_encode_us = 0; int64_t t_decode_us = 0; + int64_t t_batchd_us = 0; int64_t t_prompt_us = 0; int64_t t_mel_us = 0; int32_t n_sample = 0; // number of tokens sampled int32_t n_encode = 0; // number of encoder calls - int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation) - int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding) + int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation) + int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16 (batch decoding) + int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding) int32_t n_fail_p = 0; // number of logprob threshold failures int32_t n_fail_h = 0; // number of entropy threshold failures + // unified self-attention KV cache for all decoders + whisper_kv_cache kv_self; + // cross-attention KV cache for the decoders // shared between all decoders whisper_kv_cache kv_cross; + whisper_mel mel; - whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; + whisper_batch batch; - // buffer for swapping KV caches between decoders during beam-search - std::vector kv_swap_bufs; + whisper_decoder decoders[WHISPER_MAX_DECODERS]; ggml_backend_t backend = nullptr; @@ -720,8 +812,9 @@ struct whisper_state { struct ggml_tensor * embd_conv = 
nullptr; struct ggml_tensor * embd_enc = nullptr; - // helper for GPU offloading + // helpers for GPU offloading std::vector inp_mel; + std::vector inp_mask; // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; @@ -729,11 +822,6 @@ struct whisper_state { std::vector result_all; std::vector prompt_past; - // work container used to avoid memory allocations - std::vector> logits_id; - - mutable std::mt19937 rng; // used for sampling at t > 0.0 - int lang_id = 0; // english by default std::string path_model; // populated by whisper_init_from_file_with_params() @@ -809,6 +897,12 @@ static bool kv_cache_init( /*.no_alloc =*/ true, }; + cache.head = 0; + cache.size = n_ctx; + + cache.cells.clear(); + cache.cells.resize(n_ctx); + cache.ctx = ggml_init(params); if (!cache.ctx) { @@ -836,54 +930,129 @@ static bool kv_cache_init( return true; } -// TODO: remove after batched decoding -static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t backend) { - WHISPER_ASSERT(cache.ctx); +static void kv_cache_free(struct whisper_kv_cache & cache) { + if (cache.ctx) { + ggml_free(cache.ctx); + ggml_backend_buffer_free(cache.buffer); + cache.ctx = nullptr; + } +} - const int n_elements = ggml_nelements(cache.k); - WHISPER_ASSERT(n_elements == ggml_nelements(cache.v)); +static bool whisper_kv_cache_find_slot( + struct whisper_kv_cache & cache, + const struct whisper_batch & batch) { + const uint32_t n_ctx = cache.size; + const uint32_t n_tokens = batch.n_tokens; - const ggml_type wtype = cache.k->type; - WHISPER_ASSERT(wtype == cache.v->type); + if (n_tokens > n_ctx) { + WHISPER_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx); + return false; + } - struct ggml_init_params params = { - /*.mem_size =*/ 2*ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; + uint32_t n_tested = 0; - cache.ctx = ggml_init(params); + while (true) { + if (cache.head + n_tokens > n_ctx) { + n_tested += n_ctx - cache.head; + cache.head = 0; + continue; + } - if (!cache.ctx) { - WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__); - return false; + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + if (cache.cells[cache.head + i].pos >= 0) { + found = false; + cache.head += i + 1; + n_tested += i + 1; + break; + } + } + + if (found) { + break; + } + + if (n_tested >= n_ctx) { + //WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return false; + } } - cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); - cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + for (uint32_t i = 0; i < n_tokens; i++) { + cache.cells[cache.head + i].pos = batch.pos[i]; - const size_t mem_bytes = ggml_nbytes(cache.k) + ggml_nbytes(cache.v); + for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { + cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]); + } + } - cache.buffer = ggml_backend_alloc_buffer(backend, mem_bytes); + return true; +} - // allocate the tensors into the backend buffer - { - ggml_allocr * alloc = ggml_allocr_new_from_buffer(cache.buffer); +// find how many cells are currently in use +static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache) { + for (uint32_t i = cache.size - 1; i > 0; --i) { + if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) { + return i + 1; + } + } - ggml_allocr_alloc(alloc, cache.k); - ggml_allocr_alloc(alloc, cache.v); + return 1; +} - ggml_allocr_free(alloc); +static void 
whisper_kv_cache_clear(struct whisper_kv_cache & cache) { + for (int32_t i = 0; i < (int32_t) cache.size; ++i) { + cache.cells[i].pos = -1; + cache.cells[i].seq_id.clear(); } + cache.head = 0; +} - return true; +static void whisper_kv_cache_seq_rm( + struct whisper_kv_cache & cache, + whisper_seq_id seq_id, + whisper_pos p0, + whisper_pos p1) { + uint32_t new_head = cache.size; + + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + if (seq_id < 0) { + cache.cells[i].seq_id.clear(); + } else if (cache.cells[i].has_seq_id(seq_id)) { + cache.cells[i].seq_id.erase(seq_id); + } else { + continue; + } + if (cache.cells[i].seq_id.empty()) { + cache.cells[i].pos = -1; + if (new_head == cache.size) new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.size) cache.head = new_head; } -static void kv_cache_free(struct whisper_kv_cache & cache) { - if (cache.ctx) { - ggml_free(cache.ctx); - ggml_backend_buffer_free(cache.buffer); - cache.ctx = nullptr; +static void whisper_kv_cache_seq_cp( + struct whisper_kv_cache & cache, + whisper_seq_id seq_id_src, + whisper_seq_id seq_id_dst, + whisper_pos p0, + whisper_pos p1) { + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + + cache.head = 0; + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.cells[i].seq_id.insert(seq_id_dst); + } } } @@ -892,7 +1061,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params // initialize the backends #ifdef GGML_USE_CUBLAS - if (params.use_gpu) { + if (params.use_gpu && ggml_cublas_loaded()) { WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__); backend_gpu = ggml_backend_cuda_init(); if (!backend_gpu) { @@ -1094,6 +1263,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con word = "[_EOT_]"; } else if (i == vocab.token_sot) { word = "[_SOT_]"; + } else if (i == vocab.token_translate) { + word = "[_TRANSLATE_]"; + } else if (i == vocab.token_transcribe) { + word = "[_TRANSCRIBE_]"; } else if (i == vocab.token_solm) { word = "[_SOLM_]"; } else if (i == vocab.token_prev) { @@ -1104,6 +1277,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con word = "[_NOT_]"; } else if (i == vocab.token_beg) { word = "[_BEG_]"; + } else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) { + word = "[_LANG_" + std::string(whisper_lang_str(i - vocab.token_sot - 1)) + "]"; } else { word = "[_extra_token_" + std::to_string(i) + "]"; } @@ -1347,7 +1522,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con model.buffer = ggml_backend_alloc_buffer(wctx.backend, size_main); - WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend), size_main / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend), size_main / 1e6); } ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer); @@ -1462,12 +1637,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor)); } - //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), 
ggml_nbytes(tensor)/1024.0/1024.0); + //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1e6); total_size += ggml_nbytes(tensor); model.n_loaded++; } - WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0); + WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6); if (model.n_loaded == 0) { WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); @@ -1852,11 +2027,11 @@ static struct ggml_cgraph * whisper_build_graph_encoder( //////////////////////////////////////////////////////////////////////////// //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, - // ggml_used_mem(ctx0)/1024.0/1024.0, - // wstate.get_buf_max_mem(0)/1024.0/1024.0, - // wstate.get_buf_max_mem(1)/1024.0/1024.0, - // wstate.get_buf_max_mem(2)/1024.0/1024.0, - // wstate.get_buf_max_mem(3)/1024.0/1024.0); + // ggml_used_mem(ctx0)/1e6, + // wstate.get_buf_max_mem(0)/1e6, + // wstate.get_buf_max_mem(1)/1e6, + // wstate.get_buf_max_mem(2)/1e6, + // wstate.get_buf_max_mem(3)/1e6); ggml_free(ctx0); @@ -2009,26 +2184,28 @@ static bool whisper_encode_internal( static struct ggml_cgraph * whisper_build_graph_decoder( whisper_context & wctx, whisper_state & wstate, - whisper_decoder & decoder, - const whisper_token * tokens, - int n_tokens, - int n_past) { + const whisper_batch & batch) { const auto & model = wctx.model; const auto & hparams = model.hparams; - auto & kv_self = decoder.kv_self; + auto & kv_self = wstate.kv_self; WHISPER_ASSERT(!!kv_self.ctx); - const int n_ctx = hparams.n_text_ctx; + ggml_allocr * alloc = wstate.alloc_decode.alloc; + + const int n_ctx = kv_self.size; const int n_state = hparams.n_text_state; const int n_head = hparams.n_text_head; const int n_layer = hparams.n_text_layer; - const int N = n_tokens; - const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_tokens = batch.n_tokens; + const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + + const int32_t n_kv = ggml_allocr_is_measure(alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(alloc) ? 
n_ctx - n_tokens : kv_self.head; - //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx); + //WHISPER_PRINT_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx); struct ggml_init_params params = { /*.mem_size =*/ wstate.alloc_decode.meta.size(), @@ -2040,21 +2217,19 @@ static struct ggml_cgraph * whisper_build_graph_decoder( ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false); - ggml_allocr * alloc = wstate.alloc_decode.alloc; - - struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_allocr_alloc(alloc, embd); if (!ggml_allocr_is_measure(alloc)) { - ggml_backend_tensor_set(embd, tokens, 0, N*ggml_element_size(embd)); + ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*ggml_element_size(embd)); } - struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_allocr_alloc(alloc, position); if (!ggml_allocr_is_measure(alloc)) { - for (int i = 0; i < N; ++i) { - const int32_t val = n_past + i; + for (int i = 0; i < n_tokens; ++i) { + const int32_t val = batch.pos[i]; ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t)); } } @@ -2067,18 +2242,43 @@ static struct ggml_cgraph * whisper_build_graph_decoder( ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float)); } - // token encoding + position encoding - struct ggml_tensor * cur = - ggml_add(ctx0, - ggml_get_rows(ctx0, model.d_te, embd), - ggml_get_rows(ctx0, model.d_pe, position)); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + ggml_allocr_alloc(alloc, KQ_mask); - struct ggml_tensor * inpL = cur; + if (!ggml_allocr_is_measure(alloc)) { + wstate.inp_mask.resize(n_kv*n_tokens); - for (int il = 0; il < n_layer; ++il) { - const auto & layer = model.layers_decoder[il]; + float * data = wstate.inp_mask.data(); + memset(data, 0, ggml_nbytes(KQ_mask)); - // norm + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const whisper_pos pos = batch.pos[j]; + const whisper_seq_id seq_id = batch.seq_id[j][0]; + + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + + ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float)); + } + + // token encoding + position encoding + struct ggml_tensor * cur = + ggml_add(ctx0, + ggml_get_rows(ctx0, model.d_te, embd), + ggml_get_rows(ctx0, model.d_pe, position)); + + struct ggml_tensor * inpL = cur; + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers_decoder[il]; + + // norm { cur = ggml_norm(ctx0, inpL, hparams.eps); @@ -2119,12 +2319,12 @@ static struct ggml_cgraph * whisper_build_graph_decoder( Vcur, layer.attn_v_b); - Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N)); + Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens)); - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past)); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_state, + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head)); + 
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state, ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v)); + (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); @@ -2134,12 +2334,12 @@ static struct ggml_cgraph * whisper_build_graph_decoder( struct ggml_tensor * Q = ggml_permute(ctx0, - ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N), + ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens), 0, 2, 1, 3); struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k, - n_state/n_head, n_past + N, n_head, + n_state/n_head, n_kv, n_head, ggml_element_size(kv_self.k)*n_state, ggml_element_size(kv_self.k)*n_state/n_head, ggml_element_size(kv_self.k)*n_state*n_ctx*il); @@ -2149,16 +2349,17 @@ static struct ggml_cgraph * whisper_build_graph_decoder( //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past); + //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past); + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ, KQ_mask); struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, - n_past + N, n_state/n_head, n_head, + n_kv, n_state/n_head, n_head, n_ctx*ggml_element_size(kv_self.v), n_ctx*ggml_element_size(kv_self.v)*n_state/n_head, - il*n_ctx*ggml_element_size(kv_self.v)*n_state); + n_ctx*ggml_element_size(kv_self.v)*n_state*il); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); @@ -2166,7 +2367,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder( cur = ggml_cpy(ctx0, KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N)); + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens)); } // projection @@ -2210,33 +2411,33 @@ static struct ggml_cgraph * whisper_build_graph_decoder( // Kcross is already scaled struct ggml_tensor * Kcross = ggml_view_3d(ctx0, wstate.kv_cross.k, - n_state/n_head, M, n_head, + n_state/n_head, n_audio_ctx, n_head, ggml_element_size(wstate.kv_cross.k)*n_state, ggml_element_size(wstate.kv_cross.k)*n_state/n_head, - ggml_element_size(wstate.kv_cross.k)*n_state*M*il); + ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il); //struct ggml_tensor * Vcross = // ggml_reshape_3d(ctx0, - // ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state), - // n_state/n_head, n_head, M); + // ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state), + // n_state/n_head, n_head, n_audio_ctx); //struct ggml_tensor * V_trans = // ggml_cpy(ctx0, // ggml_permute(ctx0, Vcross, 1, 2, 0, 3), - // ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head)); + // ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head)); struct ggml_tensor * V = ggml_view_3d(ctx0, wstate.kv_cross.v, - M, n_state/n_head, n_head, - M*ggml_element_size(wstate.kv_cross.v), - M*ggml_element_size(wstate.kv_cross.v)*n_state/n_head, - il*M*ggml_element_size(wstate.kv_cross.v)*n_state); + n_audio_ctx, n_state/n_head, n_head, + n_audio_ctx*ggml_element_size(wstate.kv_cross.v), + n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head, + n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il); // ------ struct ggml_tensor * Q = 
ggml_permute(ctx0, - ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N), + ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens), 0, 2, 1, 3); // K * Q @@ -2257,10 +2458,10 @@ static struct ggml_cgraph * whisper_build_graph_decoder( struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - // cur = KQV_merged.contiguous().view(n_state, N) + // cur = KQV_merged.contiguous().view(n_state, n_tokens) cur = ggml_cpy(ctx0, KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N)); + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens)); } // projection @@ -2332,9 +2533,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder( } // compute logits only for the last token - // comment this line to compute logits for all N tokens + // comment this line to compute logits for all n_tokens // might be useful in the future - cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]); + //cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]); struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); @@ -2358,10 +2559,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder( static bool whisper_decode_internal( whisper_context & wctx, whisper_state & wstate, - whisper_decoder & decoder, - const whisper_token * tokens, - const int n_tokens, - const int n_past, + const whisper_batch & batch, const int n_threads, whisper_abort_callback abort_callback, void * abort_callback_data) { @@ -2370,19 +2568,33 @@ static bool whisper_decode_internal( const auto & model = wctx.model; const auto & hparams = model.hparams; - const int n_vocab = hparams.n_vocab; + const int n_vocab = hparams.n_vocab; + const int n_tokens = batch.n_tokens; auto & logits_out = wstate.logits; struct ggml_tensor * logits; + // find KV slot for the batch + { + auto & kv_self = wstate.kv_self; + + if (!whisper_kv_cache_find_slot(kv_self, batch)) { + return false; + } + + kv_self.n = whisper_kv_cache_cell_max(kv_self); + //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self))); + //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]); + } + // decoder { auto & alloc = wstate.alloc_decode.alloc; ggml_allocr_reset(alloc); - ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past); + ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch); ggml_allocr_alloc_graph(alloc, gf); @@ -2391,37 +2603,37 @@ static bool whisper_decode_internal( ggml_graph_compute_helper(wstate.backend, gf, n_threads); } - // extract logits for all N tokens - //logits_out.resize(n_tokens*n_vocab); - //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab); - //ggml_backend_tensor_get(logits, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), sizeof(float)*n_vocab); - - // extract logits only for the last token - logits_out.resize(n_vocab); - //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab); - ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab); + logits_out.resize(n_tokens*n_vocab); + for (int i = 0; i < n_tokens; i++) { + if (batch.logits[i] == 0) { + continue; + } + ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab*i), sizeof(float)*(n_vocab*i), sizeof(float)*n_vocab); + } - if (n_tokens > 1) { + if (batch.n_tokens > 1) { //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", 
__func__, - // ggml_used_mem(ctx0)/1024.0/1024.0, - // wstate.get_buf_max_mem(0)/1024.0/1024.0, - // wstate.get_buf_max_mem(1)/1024.0/1024.0, - // wstate.get_buf_max_mem(2)/1024.0/1024.0, - // wstate.get_buf_max_mem(3)/1024.0/1024.0); + // ggml_used_mem(ctx0)/1e6, + // wstate.get_buf_max_mem(0)/1e6, + // wstate.get_buf_max_mem(1)/1e6, + // wstate.get_buf_max_mem(2)/1e6, + // wstate.get_buf_max_mem(3)/1e6); } - if (n_tokens == 1) { + if (batch.n_tokens == 1) { wstate.t_decode_us += ggml_time_us() - t_start_us; wstate.n_decode++; + } else if (batch.n_tokens < 16) { + wstate.t_batchd_us += ggml_time_us() - t_start_us; + wstate.n_batchd += n_tokens; } else { wstate.t_prompt_us += ggml_time_us() - t_start_us; - wstate.n_prompt++; + wstate.n_prompt += n_tokens; } return !(abort_callback && abort_callback(abort_callback_data)); } - // 500 -> 00:05.000 // 6000 -> 01:00.000 static std::string to_timestamp(int64_t t, bool comma = false) { @@ -2833,15 +3045,19 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { state->backend = whisper_backend_init(ctx->params); - if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) { + // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx + // in theory, there can be a case where this is not enough, but in practice it should always be enough + const int factor = 3; + + if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) { WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); delete state; return nullptr; } { - const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v); - WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + const size_t memory_size = ggml_nbytes(state->kv_self.k) + ggml_nbytes(state->kv_self.v); + WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6); } if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) { @@ -2852,7 +3068,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { { const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v); - WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6); } #ifdef WHISPER_USE_COREML @@ -2875,14 +3091,17 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx); - state->logits_id.reserve(ctx->model.hparams.n_vocab); + state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS); // TAGS: WHISPER_DECODER_INIT state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx); - state->decoders[0].probs.reserve (ctx->vocab.n_vocab); - state->decoders[0].logits.reserve (ctx->vocab.n_vocab); - state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab); + state->decoders[0].probs.reserve (ctx->vocab.n_vocab); + state->decoders[0].logits.reserve (ctx->vocab.n_vocab); + state->decoders[0].logprobs.reserve (ctx->vocab.n_vocab); + state->decoders[0].logits_id.reserve(ctx->model.hparams.n_vocab); + + state->decoders[0].rng = std::mt19937(0); // conv allocator { @@ -2891,7 +3110,7 @@ struct whisper_state * 
whisper_init_state(whisper_context * ctx) { return whisper_build_graph_conv(*ctx, *state, 0); }); - WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1e6); } // encoder allocator @@ -2901,7 +3120,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { return whisper_build_graph_encoder(*ctx, *state); }); - WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1e6); } // cross allocator @@ -2911,7 +3130,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { return whisper_build_graph_cross(*ctx, *state); }); - WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1e6); } // decoder allocator @@ -2924,10 +3143,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { const int n_tokens = hparams.n_text_ctx; const int n_past = 0; - return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past); + whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0); + + return whisper_build_graph_decoder(*ctx, *state, state->batch); }); - WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1e6); } whisper_allocr_graph_realloc(state->alloc_conv, ctx->backend); @@ -2935,8 +3156,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend); whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend); - state->rng = std::mt19937(0); - return state; } @@ -3161,12 +3380,9 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa void whisper_free_state(struct whisper_state * state) { if (state) { + kv_cache_free(state->kv_self); kv_cache_free(state->kv_cross); - for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) { - kv_cache_free(state->decoders[i].kv_self); - } - #ifdef WHISPER_USE_COREML if (state->ctx_coreml != nullptr) { whisper_coreml_free(state->ctx_coreml); @@ -3181,6 +3397,8 @@ void whisper_free_state(struct whisper_state * state) } #endif + whisper_batch_free(state->batch); + whisper_allocr_free(state->alloc_conv); whisper_allocr_free(state->alloc_encode); whisper_allocr_free(state->alloc_cross); @@ -3307,9 +3525,11 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { } int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { - const int selected_decoder_id = 0; + whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0); - if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) { + whisper_kv_cache_seq_rm(ctx->state->kv_self, 0, n_past, -1); + + if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) { 
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); return 1; } @@ -3318,15 +3538,16 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state } int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { - // TODO: add selected_decoder_id to state - const int selected_decoder_id = 0; - if (ctx->state == nullptr) { WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__); return false; } - if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) { + whisper_kv_cache_seq_rm(ctx->state->kv_self, 0, n_past, -1); + + whisper_batch_prep_legacy(ctx->state->batch, tokens, n_tokens, n_past, 0); + + if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->batch, n_threads, nullptr, nullptr)) { WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); return 1; } @@ -3414,7 +3635,7 @@ int whisper_lang_auto_detect_with_state( return -7; } - auto & logits_id = state->logits_id; + auto & logits_id = state->decoders[0].logits_id; logits_id.clear(); for (const auto & kv : g_lang) { @@ -3617,6 +3838,7 @@ void whisper_print_timings(struct whisper_context * ctx) { const int32_t n_sample = std::max(1, ctx->state->n_sample); const int32_t n_encode = std::max(1, ctx->state->n_encode); const int32_t n_decode = std::max(1, ctx->state->n_decode); + const int32_t n_batchd = std::max(1, ctx->state->n_batchd); const int32_t n_prompt = std::max(1, ctx->state->n_prompt); WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); @@ -3624,6 +3846,7 @@ void whisper_print_timings(struct whisper_context * ctx) { WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); + WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd); WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt); } WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); @@ -3640,6 +3863,7 @@ void whisper_reset_timings(struct whisper_context * ctx) { ctx->state->n_sample = 0; ctx->state->n_encode = 0; ctx->state->n_decode = 0; + ctx->state->n_batchd = 0; ctx->state->n_prompt = 0; } } @@ -3685,6 +3909,424 @@ const char * whisper_print_system_info(void) { return s.c_str(); } +////////////////////////////////// +// Grammar - ported from llama.cpp +////////////////////////////////// + +// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as +// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`. 
+std::pair, whisper_partial_utf8> decode_utf8( + const char * src, + whisper_partial_utf8 partial_start) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; + const char * pos = src; + std::vector code_points; + uint32_t value = partial_start.value; + int n_remain = partial_start.n_remain; + + // continue previous decode, if applicable + while (*pos != 0 && n_remain > 0) { + uint8_t next_byte = static_cast(*pos); + if ((next_byte >> 6) != 2) { + // invalid sequence, abort + code_points.push_back(0); + return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 }); + } + value = (value << 6) + (next_byte & 0x3F); + ++pos; + --n_remain; + } + + if (partial_start.n_remain > 0 && n_remain == 0) { + code_points.push_back(value); + } + + // decode any subsequent utf-8 sequences, which may end in an incomplete one + while (*pos != 0) { + uint8_t first_byte = static_cast(*pos); + uint8_t highbits = first_byte >> 4; + n_remain = lookup[highbits] - 1; + + if (n_remain < 0) { + // invalid sequence, abort + code_points.clear(); + code_points.push_back(0); + return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain }); + } + + uint8_t mask = (1 << (7 - n_remain)) - 1; + value = first_byte & mask; + ++pos; + while (*pos != 0 && n_remain > 0) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + ++pos; + --n_remain; + } + if (n_remain == 0) { + code_points.push_back(value); + } + } + code_points.push_back(0); + + return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain }); +} + +// returns true iff pos points to the end of one of the definitions of a rule +static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) { + switch (pos->type) { + case WHISPER_GRETYPE_END: return true; // NOLINT + case WHISPER_GRETYPE_ALT: return true; // NOLINT + default: return false; + } +} + +// returns true iff chr satisfies the char range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static std::pair whisper_grammar_match_char( + const whisper_grammar_element * pos, + const uint32_t chr) { + + bool found = false; + bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR; + + WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT + + do { + if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + found = found || (pos->value <= chr && chr <= pos[1].value); + pos += 2; + } else { + // exact char match, e.g. 
[a] or "a" + found = found || pos->value == chr; + pos += 1; + } + } while (pos->type == WHISPER_GRETYPE_CHAR_ALT); + + return std::make_pair(found == is_positive_char, pos); +} + +// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char +// range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static bool whisper_grammar_match_partial_char( + const whisper_grammar_element * pos, + const whisper_partial_utf8 partial_utf8) { + + bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR; + WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); + + uint32_t partial_value = partial_utf8.value; + int n_remain = partial_utf8.n_remain; + + // invalid sequence or 7-bit char split across 2 bytes (overlong) + if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) { + return false; + } + + // range of possible code points this partial UTF-8 sequence could complete to + uint32_t low = partial_value << (n_remain * 6); + uint32_t high = low | ((1 << (n_remain * 6)) - 1); + + if (low == 0) { + if (n_remain == 2) { + low = 1 << 11; + } else if (n_remain == 3) { + low = 1 << 16; + } + } + + do { + if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + if (pos->value <= high && low <= pos[1].value) { + return is_positive_char; + } + pos += 2; + } else { + // exact char match, e.g. [a] or "a" + if (low <= pos->value && pos->value <= high) { + return is_positive_char; + } + pos += 1; + } + } while (pos->type == WHISPER_GRETYPE_CHAR_ALT); + + return !is_positive_char; +} + + +// transforms a grammar pushdown stack into N possible stacks, all ending +// at a character range (terminal element) +static void whisper_grammar_advance_stack( + const std::vector> & rules, + const std::vector & stack, + std::vector> & new_stacks) { + + if (stack.empty()) { + new_stacks.push_back(stack); + return; + } + + const whisper_grammar_element * pos = stack.back(); + + switch (pos->type) { + case WHISPER_GRETYPE_RULE_REF: { + const size_t rule_id = static_cast(pos->value); + const whisper_grammar_element * subpos = rules[rule_id].data(); + do { + // init new stack without the top (pos) + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!whisper_grammar_is_end_of_sequence(pos + 1)) { + // if this rule ref is followed by another element, add that to stack + new_stack.push_back(pos + 1); + } + if (!whisper_grammar_is_end_of_sequence(subpos)) { + // if alternate is nonempty, add to stack + new_stack.push_back(subpos); + } + whisper_grammar_advance_stack(rules, new_stack, new_stacks); + while (!whisper_grammar_is_end_of_sequence(subpos)) { + // scan to end of alternate def + subpos++; + } + if (subpos->type == WHISPER_GRETYPE_ALT) { + // there's another alternate def of this rule to process + subpos++; + } else { + break; + } + } while (true); + break; + } + case WHISPER_GRETYPE_CHAR: + case WHISPER_GRETYPE_CHAR_NOT: + new_stacks.push_back(stack); + break; + default: + // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range + // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on + // those + WHISPER_ASSERT(false); + } +} + +// takes a set of possible pushdown stacks on a grammar, which are required to +// be positioned at a character range (see `whisper_grammar_advance_stack`), and +// produces the N possible stacks if the given char is accepted at those +// positions +static std::vector> whisper_grammar_accept( + 
const std::vector> & rules, + const std::vector> & stacks, + const uint32_t chr) { + + std::vector> new_stacks; + + for (const auto & stack : stacks) { + if (stack.empty()) { + continue; + } + + auto match = whisper_grammar_match_char(stack.back(), chr); + if (match.first) { + const whisper_grammar_element * pos = match.second; + + // update top of stack to next element, if any + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!whisper_grammar_is_end_of_sequence(pos)) { + new_stack.push_back(pos); + } + whisper_grammar_advance_stack(rules, new_stack, new_stacks); + } + } + + return new_stacks; +} + +static std::vector whisper_grammar_reject_candidates( + const std::vector> & rules, + const std::vector> & stacks, + const std::vector & candidates); + +static std::vector whisper_grammar_reject_candidates_for_stack( + const std::vector> & rules, + const std::vector & stack, + const std::vector & candidates) { + + std::vector rejects; + + if (stack.empty()) { + for (auto tok : candidates) { + if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) { + rejects.push_back(tok); + } + } + return rejects; + } + + const whisper_grammar_element * stack_pos = stack.back(); + + std::vector next_candidates; + for (auto tok : candidates) { + if (*tok.code_points == 0) { + // reached end of full codepoints in token, reject iff it ended in a partial sequence + // that cannot satisfy this position in grammar + if (tok.partial_utf8.n_remain != 0 && !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) { + rejects.push_back(tok); + } + } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) { + next_candidates.push_back({ tok.id, tok.code_points + 1, tok.partial_utf8 }); + } else { + rejects.push_back(tok); + } + } + + const auto * stack_pos_after = whisper_grammar_match_char(stack_pos, 0).second; + + // update top of stack to next element, if any + std::vector stack_after(stack.begin(), stack.end() - 1); + if (!whisper_grammar_is_end_of_sequence(stack_pos_after)) { + stack_after.push_back(stack_pos_after); + } + std::vector> next_stacks; + whisper_grammar_advance_stack(rules, stack_after, next_stacks); + + auto next_rejects = whisper_grammar_reject_candidates(rules, next_stacks, next_candidates); + for (auto tok : next_rejects) { + rejects.push_back({ tok.id, tok.code_points - 1, tok.partial_utf8 }); + } + + return rejects; +} + +static std::vector whisper_grammar_reject_candidates( + const std::vector> & rules, + const std::vector> & stacks, + const std::vector & candidates) { + if (candidates.empty() || stacks.empty()) { + return std::vector(); + } + + auto rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates); + + for (size_t i = 1, size = stacks.size(); i < size; ++i) { + rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks[i], rejects); + } + return rejects; +} + +static struct whisper_grammar whisper_grammar_init( + const whisper_grammar_element ** rules, + size_t n_rules, + size_t i_start_rule) { + const whisper_grammar_element * pos; + + // copy rule definitions into vectors + std::vector> vec_rules(n_rules); + for (size_t i = 0; i < n_rules; i++) { + for (pos = rules[i]; pos->type != WHISPER_GRETYPE_END; pos++) { + vec_rules[i].push_back(*pos); + } + vec_rules[i].push_back({WHISPER_GRETYPE_END, 0}); + } + + // loop over alternates of start rule to build initial stacks + std::vector> stacks; + pos = rules[i_start_rule]; + do { + std::vector stack; + if (!whisper_grammar_is_end_of_sequence(pos)) { + // if 
alternate is nonempty, add to stack + stack.push_back(pos); + } + whisper_grammar_advance_stack(vec_rules, stack, stacks); + while (!whisper_grammar_is_end_of_sequence(pos)) { + // scan to end of alternate def + pos++; + } + if (pos->type == WHISPER_GRETYPE_ALT) { + // there's another alternate def of this rule to process + pos++; + } else { + break; + } + } while (true); + + return { std::move(vec_rules), std::move(stacks), {} }; +} + +static void whisper_suppress_invalid_grammar( + whisper_context & ctx, + const whisper_full_params & params, + std::vector & logits, + const whisper_grammar & grammar) { + + if (grammar.rules.empty() || grammar.stacks.empty()) { + return; + } + + //bool allow_eot = false; + //for (const auto & stack : grammar.stacks) { + // if (stack.empty()) { + // allow_eot = true; + // break; + // } + //} + + const whisper_token eot = whisper_token_eot(&ctx); + + std::vector, whisper_partial_utf8>> candidates_decoded; + std::vector candidates_grammar; + + for (whisper_token id = 0; id < eot; ++id) { + const std::string & text = ctx.vocab.id_to_token[id]; + if (!text.empty()) { + candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8)); + candidates_grammar.push_back({ id, candidates_decoded.back().first.data(), candidates_decoded.back().second }); + } + } + + const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); + + for (const auto & reject : rejects) { + logits[reject.id] -= params.grammar_penalty; + } + + // when the grammar allows a continuation, we penalize the end-of-text token + //if (!allow_eot) { + // logits[eot] -= params.grammar_penalty; + //} + //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size()); +} + +static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) { + if (grammar.rules.empty() || grammar.stacks.empty()) { + return; + } + + //fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str()); + + const std::string & text = ctx.vocab.id_to_token[token]; + + if (text.rfind("[_", 0) == 0) { + // fprintf(stderr, " (skipped)\n"); + return; + } + // fprintf(stderr, "\n"); + + // Note terminating 0 in decoded string + const auto decoded = decode_utf8(text.c_str(), grammar.partial_utf8); + const auto & code_points = decoded.first; + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + grammar.stacks = whisper_grammar_accept(grammar.rules, grammar.stacks, *it); + } + grammar.partial_utf8 = decoded.second; +} + +////////////// +// END grammar +////////////// + //////////////////////////////////////////////////////////////////////////// struct whisper_context_params * whisper_context_default_params_by_ref() { @@ -3714,6 +4356,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.translate =*/ false, /*.no_context =*/ true, + /*.no_timestamps =*/ false, /*.single_segment =*/ false, /*.print_special =*/ false, /*.print_progress =*/ true, @@ -3747,7 +4390,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.max_initial_ts =*/ 1.0f, /*.length_penalty =*/ -1.0f, - /*.temperature_inc =*/ 0.4f, + /*.temperature_inc =*/ 0.2f, /*.entropy_thold =*/ 2.4f, /*.logprob_thold =*/ -1.0f, /*.no_speech_thold =*/ 0.6f, @@ -3776,19 +4419,24 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.logits_filter_callback =*/ nullptr, /*.logits_filter_callback_user_data =*/ nullptr, + + /*.grammar_rules 
=*/ nullptr, + /*.n_grammar_rules =*/ 0, + /*.i_start_rule =*/ 0, + /*.grammar_penalty =*/ 100.0f, }; switch (strategy) { case WHISPER_SAMPLING_GREEDY: { result.greedy = { - /*.best_of =*/ 2, // TODO: increase to 5 when we speed-up batch decoding + /*.best_of =*/ 5, }; } break; case WHISPER_SAMPLING_BEAM_SEARCH: { result.beam_search = { - /*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding + /*.beam_size =*/ 5, /*.patience =*/ -1.0f, }; @@ -3878,11 +4526,12 @@ static const std::vector non_speech_tokens = { // process the logits for the selected decoder // - applies logit filters // - computes logprobs and probs +// TODO: optimize static void whisper_process_logits( struct whisper_context & ctx, struct whisper_state & state, - const struct whisper_full_params params, struct whisper_decoder & decoder, + const struct whisper_full_params params, float temperature) { const auto & vocab = ctx.vocab; const auto & tokens_cur = decoder.sequence.tokens; @@ -3899,7 +4548,7 @@ static void whisper_process_logits( auto & logprobs = decoder.logprobs; { logits.resize(n_logits); - memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float)); + memcpy(logits.data(), state.logits.data() + decoder.i_batch*n_logits, n_logits*sizeof(float)); if (temperature > 0.0f) { for (int i = 0; i < n_logits; i++) { @@ -3927,6 +4576,11 @@ static void whisper_process_logits( // suppress <|notimestamps|> token // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412 logits[vocab.token_not] = -INFINITY; + if (params.no_timestamps) { + for (int i = vocab.token_beg; i < n_logits; ++i) { + logits[i] = -INFINITY; + } + } // suppress sot and nosp tokens logits[vocab.token_sot] = -INFINITY; @@ -3942,6 +4596,14 @@ static void whisper_process_logits( logits[vocab.token_transcribe] = -INFINITY; logits[vocab.token_prev] = -INFINITY; + // suppress lang tokens + for (size_t i = 0; i < g_lang.size(); ++i) { + logits[whisper_token_lang(&ctx, i)] = -INFINITY; + } + + // suppress prev token + logits[vocab.token_prev] = -INFINITY; + if (params.logits_filter_callback) { params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data); } @@ -4056,6 +4718,30 @@ static void whisper_process_logits( logits[i] = -INFINITY; logprobs[i] = -INFINITY; } + } else { + if (params.n_grammar_rules > 0) { + whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar); + + // populate the logprobs array (log_softmax) + { + const float logit_max = *std::max_element(logits.begin(), logits.end()); + float logsumexp = 0.0f; + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logsumexp += expf(logits[i] - logit_max); + } + } + logsumexp = logf(logsumexp) + logit_max; + + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logprobs[i] = logits[i] - logsumexp; + } else { + logprobs[i] = -INFINITY; + } + } + } + } } } } @@ -4073,38 +4759,60 @@ static void whisper_process_logits( #if 0 // print first 100 logits - token string : logit - for (int i = 0; i < 100; i++) { - const auto token = vocab.id_to_token.at(i); - const auto prob = probs[i]; - const auto logit = logits[i]; - const auto logprob = logprobs[i]; - printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob); + //for (int i = 0; i < 10; i++) { + // const auto token = vocab.id_to_token.at(i); + // const auto prob = probs[i]; + // const 
auto logit = logits[i]; + // const auto logprob = logprobs[i]; + // printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob); + //} + + // print sorted + { + std::vector> pairs; + + for (int i = 0; i < n_logits; ++i) { + pairs.push_back(std::make_pair(probs[i], i)); + } + + std::sort(pairs.begin(), pairs.end(), [](const std::pair& a, const std::pair& b) { + return a.first > b.first; + }); + + for (int i = 0; i < 10; i++) { + const auto token = vocab.id_to_token.at(pairs[i].second); + const auto prob = pairs[i].first; + const auto logit = logits[pairs[i].second]; + const auto logprob = logprobs[pairs[i].second]; + printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str()); + } + + printf("----------------\n"); } // "And", "and", " And", " and" - printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]); - printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]); - printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]); - printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]); - printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]); - - printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]); - printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]); - printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); - printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); - printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); - - printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]); - printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]); - printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]); - printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]); - printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]); + //printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]); + //printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]); + //printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]); + //printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]); + //printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]); + + //printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]); + //printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]); + //printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); + //printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); + //printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); + + //printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]); + //printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]); + //printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]); + //printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]); + //printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]); #endif } static whisper_token_data whisper_sample_token( whisper_context & ctx, - whisper_state & state, const whisper_decoder & decoder, bool best) { whisper_token_data result = { @@ -4149,7 +4857,7 @@ static whisper_token_data whisper_sample_token( } else { std::discrete_distribution<> dist(probs.begin(), probs.end()); - result.id = dist(state.rng); + result.id = dist(decoder.rng); result.p = probs[result.id]; 
result.plog = logprobs[result.id]; } @@ -4159,15 +4867,12 @@ static whisper_token_data whisper_sample_token( result.pt = result.p; } - state.n_sample++; - return result; } static std::vector whisper_sample_token_topk( whisper_context & ctx, - whisper_state & state, - const whisper_decoder & decoder, + whisper_decoder & decoder, int k) { const auto & vocab = ctx.vocab; @@ -4177,7 +4882,7 @@ static std::vector whisper_sample_token_topk( const int n_logits = vocab.n_vocab; - auto & logits_id = state.logits_id; + auto & logits_id = decoder.logits_id; logits_id.resize(n_logits); for (int i = 0; i < n_logits; ++i) { @@ -4223,8 +4928,11 @@ static std::vector whisper_sample_token_topk( ptsum = sum_ts; } + std::discrete_distribution<> dist(probs.begin(), probs.end()); + for (int i = 0; i < k; ++i) { - const auto id = logits_id[i].second; + const auto id = dist(decoder.rng); + //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum); result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, }); @@ -4234,8 +4942,6 @@ static std::vector whisper_sample_token_topk( } } - state.n_sample++; - return result; } @@ -4288,125 +4994,6 @@ static void whisper_sequence_score( } } -static bool whisper_kv_swap_fast( - std::vector & view, - whisper_decoder src[], - std::vector & kv_swap_bufs, - const int & n_decoders) { - WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders); - - // (decoder->buffer->decoder or decoder->buffer + decoder->decoder) - std::set two_copy; // decoder indices require two copies to safely modify KV caches - - // (buffer->decoder or decoder->decoder) - std::set one_copy; // decoder indices require one copy to safely modify KV caches - - // (decoder<->decoder) - std::set p_swap_set; // decoder indices able to swap KV-cache pointers - std::vector> p_swap_vec; - p_swap_vec.reserve(n_decoders); - - // see https://github.com/ggerganov/whisper.cpp/wiki - for (int i = 0; i < n_decoders; i++) { - // zero-copy (no modification) - if (i == view[i] || view[i] < 0) { - continue; - } - - bool is_one_copy = true; - // since we modify data sequentially, we only consider decoder indices after current index - for (int j = i + 1; j < n_decoders; j++) { - if (i == view[j]) { - // detect symmetric diagram - if (j == view[i]) { - p_swap_set.insert(i); - p_swap_set.insert(j); - p_swap_vec.emplace_back(i, j); - } else { - two_copy.insert(i); - is_one_copy = false; - } - break; - } - } - if (is_one_copy) { - one_copy.insert(i); - } - } - - kv_swap_bufs.resize(n_decoders); - - for (int i = 0; i < n_decoders; i++) { - kv_swap_bufs[i].k.resize(ggml_nbytes(src[i].kv_self.k)); - kv_swap_bufs[i].v.resize(ggml_nbytes(src[i].kv_self.v)); - } - - for (auto & i : two_copy) { - // make a copy of KV caches - WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i); - //memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size()); - //memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size()); - ggml_backend_tensor_get(src[i].kv_self.k, kv_swap_bufs[i].k.data(), 0, kv_swap_bufs[i].k.size()); - ggml_backend_tensor_get(src[i].kv_self.v, kv_swap_bufs[i].v.data(), 0, kv_swap_bufs[i].v.size()); - } - - // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first - for (auto & i : two_copy) { - // skip the decoder indices that require pointer swapping - if (p_swap_set.find(i) != p_swap_set.end()) { - continue; - } - - if (two_copy.find(view[i]) != two_copy.end()) { - // modify KV caches of 
decoder using data from kv_swap_bufs - WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i); - //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size()); - //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size()); - ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size()); - ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size()); - } else { - // modify KV caches of decoder using data from correspond decoder KV caches directly - WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i); - //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k)); - //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v)); - ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k); - ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v); - } - } - - // then modify one-copy decoder KV caches - for (auto & i : one_copy) { - // skip the decoder indices that require pointer swapping - if (p_swap_set.find(i) != p_swap_set.end()) { - continue; - } - - if (two_copy.find(view[i]) != two_copy.end()) { - // modify KV caches of decoder using data from kv_swap_bufs - WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i); - //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size()); - //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size()); - ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size()); - ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size()); - } else { - // modify KV caches of decoder using data from correspond decoder KV caches directly - WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i); - //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k)); - //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v)); - ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k); - ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v); - } - } - - // swap the pointers - for (auto & i : p_swap_vec) { - WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second); - std::swap(src[i.first].kv_self, src[i.second].kv_self); - } - - return true; -} - int whisper_full_with_state( struct whisper_context * ctx, struct whisper_state * state, @@ -4496,25 +5083,23 @@ int whisper_full_with_state( n_decoders = std::max(1, n_decoders); + if (n_decoders > WHISPER_MAX_DECODERS) { + WHISPER_LOG_ERROR("%s: too many decoders requested (%d), max = %d\n", __func__, n_decoders, WHISPER_MAX_DECODERS); + return -4; + } + // TAGS: WHISPER_DECODER_INIT for (int j = 1; j < n_decoders; j++) { auto & decoder = state->decoders[j]; - if (decoder.kv_self.ctx == nullptr) { - decoder.kv_self = state->decoders[0].kv_self; - if (!kv_cache_reinit(decoder.kv_self, ctx->backend)) { - WHISPER_LOG_ERROR("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j); - return -4; - } - - WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", 
__func__, j); + decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity()); - decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity()); + decoder.probs.resize (ctx->vocab.n_vocab); + decoder.logits.resize (ctx->vocab.n_vocab); + decoder.logprobs.resize(ctx->vocab.n_vocab); + decoder.logits_id.reserve(ctx->model.hparams.n_vocab); - decoder.probs.resize (ctx->vocab.n_vocab); - decoder.logits.resize (ctx->vocab.n_vocab); - decoder.logprobs.resize(ctx->vocab.n_vocab); - } + decoder.rng = std::mt19937(0); } // the accumulated text context so far @@ -4553,7 +5138,7 @@ int whisper_full_with_state( state->exp_n_audio_ctx = params.audio_ctx; // these tokens determine the task that will be performed - std::vector prompt_init = { whisper_token_sot(ctx) }; + std::vector prompt_init = { whisper_token_sot(ctx), }; if (whisper_is_multilingual(ctx)) { const int lang_id = whisper_lang_id(params.language); @@ -4566,17 +5151,19 @@ int whisper_full_with_state( } } + // distilled models require the "no_timestamps" token { const bool is_distil = ctx->model.hparams.n_text_layer == 2; - - // distilled models require the "no_timestamps" token - // TODO: add input parameter (#1229) - if (is_distil) { + if (is_distil && !params.no_timestamps) { WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__); - prompt_init.push_back(whisper_token_not(ctx)); + params.no_timestamps = true; } } + if (params.no_timestamps) { + prompt_init.push_back(whisper_token_not(ctx)); + } + int seek = seek_start; std::vector prompt; @@ -4589,8 +5176,10 @@ int whisper_full_with_state( bool has_ts; whisper_sequence sequence; + whisper_grammar grammar; }; + std::vector> bc_per_dec(n_decoders); std::vector beam_candidates; // main loop @@ -4652,14 +5241,12 @@ int whisper_full_with_state( n_decoders_cur = std::max(1, n_decoders_cur); - WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur); + WHISPER_PRINT_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur); // TAGS: WHISPER_DECODER_INIT for (int j = 0; j < n_decoders_cur; ++j) { auto & decoder = state->decoders[j]; - decoder.kv_self.n = 0; - decoder.sequence.tokens.clear(); decoder.sequence.result_len = 0; decoder.sequence.sum_logprobs_all = 0.0; @@ -4673,10 +5260,16 @@ int whisper_full_with_state( decoder.failed = false; decoder.completed = false; decoder.has_ts = false; + + if (params.grammar_rules != nullptr) { + decoder.grammar = whisper_grammar_init(params.grammar_rules, params.n_grammar_rules, params.i_start_rule); + } else { + decoder.grammar = {}; + } } // init prompt and kv cache for the current iteration - // run whisper_decoder() only for decoder 0 and copy the results for the other decoders + // TODO: do not recompute the prompt if it is the same as previous time { prompt.clear(); @@ -4698,7 +5291,11 @@ int whisper_full_with_state( } WHISPER_PRINT_DEBUG("\n\n"); - if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { + whisper_kv_cache_clear(state->kv_self); + + whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0); + + if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { WHISPER_LOG_ERROR("%s: failed to decode\n", __func__); return -7; } @@ -4706,20 +5303,14 @@ int 
whisper_full_with_state( { const int64_t t_start_sample_us = ggml_time_us(); - whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur); + state->decoders[0].i_batch = prompt.size() - 1; - state->decoders[0].kv_self.n += prompt.size(); + whisper_process_logits(*ctx, *state, state->decoders[0], params, t_cur); for (int j = 1; j < n_decoders_cur; ++j) { auto & decoder = state->decoders[j]; - // TODO: fix CUDA - //memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k)); - //memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v)); - ggml_backend_tensor_copy(state->decoders[0].kv_self.k, decoder.kv_self.k); - ggml_backend_tensor_copy(state->decoders[0].kv_self.v, decoder.kv_self.v); - - decoder.kv_self.n += prompt.size(); + whisper_kv_cache_seq_cp(state->kv_self, 0, j, -1, -1); memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); @@ -4734,41 +5325,81 @@ int whisper_full_with_state( const int64_t t_start_sample_us = ggml_time_us(); if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { - beam_candidates.clear(); + for (auto & bc : bc_per_dec) { + bc.clear(); + } } - // generate new sequence candidates for each decoder - for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = state->decoders[j]; + // sampling + // TODO: avoid memory allocations, optimize, avoid threads? + { + std::atomic j_cur(0); - if (decoder.completed || decoder.failed) { - continue; - } + auto process = [&]() { + while (true) { + const int j = j_cur.fetch_add(1); - switch (params.strategy) { - case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: - { - if (t_cur < 1e-6f) { - decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true)); - } else { - decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false)); - } + if (j >= n_decoders_cur) { + break; + } - decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog; - } break; - case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: - { - const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size); + auto & decoder = state->decoders[j]; - for (const auto & token : tokens_new) { - beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence }); - beam_candidates.back().sequence.tokens.push_back(token); - beam_candidates.back().sequence.sum_logprobs_all += token.plog; + if (decoder.completed || decoder.failed) { + continue; + } - //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all); - } - } break; + switch (params.strategy) { + case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: + { + if (t_cur < 1e-6f) { + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true)); + } else { + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false)); + } + + decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog; + } break; + case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + { + const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size); + + for (const auto & token : tokens_new) { + bc_per_dec[j].push_back({ j, 
decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, }); + bc_per_dec[j].back().sequence.tokens.push_back(token); + bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog; + } + } break; + }; + } }; + + const int n_threads = std::min(params.n_threads, n_decoders_cur); + + if (n_threads == 1) { + process(); + } else { + std::vector threads(n_threads - 1); + + for (int t = 0; t < n_threads - 1; ++t) { + threads[t] = std::thread(process); + } + + process(); + + for (int t = 0; t < n_threads - 1; ++t) { + threads[t].join(); + } + } + } + + beam_candidates.clear(); + for (const auto & bc : bc_per_dec) { + beam_candidates.insert(beam_candidates.end(), bc.begin(), bc.end()); + + if (!bc.empty()) { + state->n_sample += 1; + } } // for beam-search, choose the top candidates and update the KV caches @@ -4781,7 +5412,6 @@ int whisper_full_with_state( }); uint32_t cur_c = 0; - std::vector decoder_idx(n_decoders_cur, -1); for (int j = 0; j < n_decoders_cur; ++j) { auto & decoder = state->decoders[j]; @@ -4790,23 +5420,38 @@ int whisper_full_with_state( continue; } + if (cur_c >= beam_candidates.size()) { + cur_c = 0; + } + auto & cur = beam_candidates[cur_c++]; while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) { ++cur_c; } - decoder.sequence = cur.sequence; decoder.seek_delta = cur.seek_delta; decoder.has_ts = cur.has_ts; + decoder.sequence = cur.sequence; + decoder.grammar = cur.grammar; + + whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1); - decoder_idx[j] = cur.decoder_idx; WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n", __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all); } - // update KV caches - whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur); + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + whisper_kv_cache_seq_rm(state->kv_self, j, -1, -1); + whisper_kv_cache_seq_cp(state->kv_self, WHISPER_MAX_DECODERS + j, j, -1, -1); + whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j, -1, -1); + } } // update the decoder state @@ -4844,6 +5489,8 @@ int whisper_full_with_state( has_ts = true; } + whisper_grammar_accept_token(*ctx, decoder.grammar, token.id); + #ifdef WHISPER_DEBUG { const auto tt = token.pt > 0.10 ? 
ctx->vocab.id_to_token.at(token.tid) : "[?]"; @@ -4913,32 +5560,83 @@ int whisper_full_with_state( state->t_sample_us += ggml_time_us() - t_start_sample_us; // obtain logits for the next token - for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = state->decoders[j]; + { + auto & batch = state->batch; - if (decoder.failed || decoder.completed) { - continue; - } + batch.n_tokens = 0; - decoder.tokens_tmp.resize(1); - decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id; + const int n_past = prompt.size() + i; - //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta); + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; - if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { + if (decoder.failed || decoder.completed) { + continue; + } + + //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta); + + decoder.i_batch = batch.n_tokens; + + batch.token [batch.n_tokens] = decoder.sequence.tokens.back().id; + batch.pos [batch.n_tokens] = n_past; + batch.n_seq_id[batch.n_tokens] = 1; + batch.seq_id [batch.n_tokens][0] = j; + batch.logits [batch.n_tokens] = 1; + batch.n_tokens++; + } + + assert(batch.n_tokens > 0); + + if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { WHISPER_LOG_ERROR("%s: failed to decode\n", __func__); return -8; } + const int64_t t_start_sample_us = ggml_time_us(); + + // TODO: avoid memory allocations, optimize, avoid threads? { - const int64_t t_start_sample_us = ggml_time_us(); + std::atomic j_cur(0); + + auto process = [&]() { + while (true) { + const int j = j_cur.fetch_add(1); - whisper_process_logits(*ctx, *state, params, decoder, t_cur); + if (j >= n_decoders_cur) { + break; + } - ++decoder.kv_self.n; + auto & decoder = state->decoders[j]; + + if (decoder.failed || decoder.completed) { + continue; + } + + whisper_process_logits(*ctx, *state, decoder, params, t_cur); + } + }; - state->t_sample_us += ggml_time_us() - t_start_sample_us; + const int n_threads = std::min(params.n_threads, n_decoders_cur); + + if (n_threads == 1) { + process(); + } else { + std::vector threads(n_threads - 1); + + for (int t = 0; t < n_threads - 1; ++t) { + threads[t] = std::thread(process); + } + + process(); + + for (int t = 0; t < n_threads - 1; ++t) { + threads[t].join(); + } + } } + + state->t_sample_us += ggml_time_us() - t_start_sample_us; } } @@ -5235,11 +5933,13 @@ int whisper_full_parallel( ctx->state->t_sample_us += states[i]->t_sample_us; ctx->state->t_encode_us += states[i]->t_encode_us; ctx->state->t_decode_us += states[i]->t_decode_us; + ctx->state->t_batchd_us += states[i]->t_batchd_us; ctx->state->t_prompt_us += states[i]->t_prompt_us; ctx->state->n_sample += states[i]->n_sample; ctx->state->n_encode += states[i]->n_encode; ctx->state->n_decode += states[i]->n_decode; + ctx->state->n_batchd += states[i]->n_batchd; ctx->state->n_prompt += states[i]->n_prompt; whisper_free_state(states[i]); @@ -5372,8 +6072,8 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) { size_t n = 20; size_t arr = n_threads > 0 ? 
1024llu : n_threads; // trick to avoid compiler optimizations - // 1GB MB array - const size_t size = arr*1024llu*1024llu; + // 1GB array + const size_t size = arr*1e6; // single-thread { @@ -5399,7 +6099,7 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) { src[rand() % size] = rand() % 256; } - snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s (1 thread)\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu)); + snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s (1 thread)\n", (double) (n*size)/(tsum*1e9)); s += strbuf; // needed to prevent the compiler from optimizing the memcpy away diff --git a/examples/whisper/whisper.h b/examples/whisper/whisper.h index 0ea5237e..84540989 100644 --- a/examples/whisper/whisper.h +++ b/examples/whisper/whisper.h @@ -78,7 +78,9 @@ extern "C" { struct whisper_state; struct whisper_full_params; - typedef int whisper_token; + typedef int32_t whisper_pos; + typedef int32_t whisper_token; + typedef int32_t whisper_seq_id; struct whisper_context_params { bool use_gpu; @@ -109,6 +111,37 @@ extern "C" { void (*close)(void * ctx); } whisper_model_loader; + // grammar element type + enum whisper_gretype { + // end of rule definition + WHISPER_GRETYPE_END = 0, + + // start of alternate definition for rule + WHISPER_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + WHISPER_GRETYPE_RULE_REF = 2, + + // terminal element: character (code point) + WHISPER_GRETYPE_CHAR = 3, + + // inverse char(s) ([^a], [^a-b] [^abc]) + WHISPER_GRETYPE_CHAR_NOT = 4, + + // modifies a preceding WHISPER_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + WHISPER_GRETYPE_CHAR_RNG_UPPER = 5, + + // modifies a preceding WHISPER_GRETYPE_CHAR or + // WHISPER_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + WHISPER_GRETYPE_CHAR_ALT = 6, + }; + + typedef struct whisper_grammar_element { + enum whisper_gretype type; + uint32_t value; // Unicode code point or rule ID + } whisper_grammar_element; + // Various functions for loading a ggml whisper model. // Allocate (almost) all memory needed for the model. // Return NULL on failure @@ -402,6 +435,7 @@ extern "C" { bool translate; bool no_context; // do not use past transcription (if any) as initial prompt for the decoder + bool no_timestamps; // do not generate timestamps bool single_segment; // force single segment output (useful for streaming) bool print_special; // print special tokens (e.g. , , , etc.) 
bool print_progress; // print progress information @@ -479,6 +513,11 @@ extern "C" { // called by each decoder to filter obtained logits whisper_logits_filter_callback logits_filter_callback; void * logits_filter_callback_user_data; + + const whisper_grammar_element ** grammar_rules; + size_t n_grammar_rules; + size_t i_start_rule; + float grammar_penalty; }; // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d65a58e5..d250000d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -26,6 +26,15 @@ if (NOT UNAME_M) endif() #message(STATUS "UNAME_S: ${UNAME_S} UNAME_P: ${UNAME_P} UNAME_M: ${UNAME_M}") +# this version of Apple ld64 is buggy +execute_process( + COMMAND ${CMAKE_C_COMPILER} ${CMAKE_EXE_LINKER_FLAGS} -Wl,-v + ERROR_VARIABLE output +) +if (output MATCHES "dyld-1015\.7") + add_compile_definitions(HAVE_BUGGY_APPLE_LINKER) +endif() + # Mac OS + Arm can report x86_64 # ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789 if (UNAME_S MATCHES "Darwin") diff --git a/src/ggml-metal.m b/src/ggml-metal.m index 3d22b0b2..4fe9cc48 100644 --- a/src/ggml-metal.m +++ b/src/ggml-metal.m @@ -346,9 +346,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { } GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); - GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); + GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6); if (ctx->device.maxTransferRate != 0) { - GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0); + GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6); } else { GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__); } @@ -541,11 +541,11 @@ bool ggml_metal_add_buffer( ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; if (ctx->buffers[ctx->n_buffers].metal == nil) { - GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0); + GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1e6); return false; } - GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0); + GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1e6); ++ctx->n_buffers; } else { @@ -565,11 +565,11 @@ bool ggml_metal_add_buffer( ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; if (ctx->buffers[ctx->n_buffers].metal == nil) { - GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0); + GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1e6); return false; } - GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = 
%12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i); + GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1e6, i); if (i + size_step < size) { GGML_METAL_LOG_INFO("\n"); } @@ -580,8 +580,8 @@ bool ggml_metal_add_buffer( #if TARGET_OS_OSX GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)", - ctx->device.currentAllocatedSize / 1024.0 / 1024.0, - ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); + ctx->device.currentAllocatedSize / 1e6, + ctx->device.recommendedMaxWorkingSetSize / 1e6); if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) { GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__); @@ -589,7 +589,7 @@ bool ggml_metal_add_buffer( GGML_METAL_LOG_INFO("\n"); } #else - GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0); + GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1e6); #endif }