-#include "whisper.h"
+#include "common.h"
-// third-party utilities
-// use your favorite implementations
-#define DR_WAV_IMPLEMENTATION
-#include "dr_wav.h"
+#include "whisper.h"
#include <cmath>
#include <fstream>
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
- int32_t n_processors = 1;
- int32_t offset_t_ms = 0;
- int32_t offset_n = 0;
- int32_t duration_ms = 0;
+ int32_t n_processors = 1;
+ int32_t offset_t_ms = 0;
+ int32_t offset_n = 0;
+ int32_t duration_ms = 0;
int32_t max_context = -1;
- int32_t max_len = 0;
- int32_t best_of = 5;
+ int32_t max_len = 0;
+ int32_t best_of = 5;
int32_t beam_size = -1;
- float word_thold = 0.01f;
- float entropy_thold = 2.4f;
- float logprob_thold = -1.0f;
+ float word_thold = 0.01f;
+ float entropy_thold = 2.40f;
+ float logprob_thold = -1.00f;
bool speed_up = false;
bool translate = false;
bool diarize = false;
+ bool split_on_word = false;
+ bool no_fallback = false;
bool output_txt = false;
bool output_vtt = false;
bool output_srt = false;
std::string model = "models/ggml-base.en.bin";
std::vector<std::string> fname_inp = {};
+ std::vector<std::string> fname_out = {};
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
-
+
+ if (arg == "-") {
+ params.fname_inp.push_back(arg);
+ continue;
+ }
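+ // e.g. a hypothetical pipeline (the WAV checks require 16 kHz, 16-bit audio):
+ //   ffmpeg -i in.mp3 -f wav -ar 16000 -ac 1 - | ./main -m models/ggml-base.en.bin -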
+
if (arg[0] != '-') {
params.fname_inp.push_back(arg);
continue;
}
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
+ else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
+ else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
+ else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
- fprintf(stderr, " -h, --help [default] show this help message and exit\n");
- fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
- fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
- fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
- fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
- fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
- fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
- fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
- fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
- fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
- fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
- fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
- fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
- fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
- fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
- fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
- fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
- fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
- fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
- fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
- fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
- fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
- fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
- fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
- fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
- fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
- fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
- fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
- fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
+ fprintf(stderr, " -h, --help [default] show this help message and exit\n");
+ fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
+ fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
+ fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
+ fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
+ fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
+ fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
+ fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
+ fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
+ fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
+ fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
+ fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
+ fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
+ fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
+ fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
+ fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
+ fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
+ fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
+ fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
+ fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
+ fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
+ fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
+ fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
+ fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", "");
+ fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
+ fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
+ fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
+ fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
+ fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
+ fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
+ fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
+ fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, "\n");
}
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
- if (text[0] == ' ') {
- text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
- }
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
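+ // fall back to the input path as the output stem when no matching --output-file was given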
+ const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
- std::vector<float> pcmf32; // mono-channel F32 PCM
+ std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
- // WAV input
- {
- drwav wav;
- std::vector<uint8_t> wav_data; // used for pipe input from stdin
-
- if (fname_inp == "-") {
- {
- uint8_t buf[1024];
- while (true)
- {
- const size_t n = fread(buf, 1, sizeof(buf), stdin);
- if (n == 0) {
- break;
- }
- wav_data.insert(wav_data.end(), buf, buf + n);
- }
- }
-
- if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
- fprintf(stderr, "error: failed to open WAV file from stdin\n");
- return 4;
- }
-
- fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
- }
- else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
- fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
- return 5;
- }
-
- if (wav.channels != 1 && wav.channels != 2) {
- fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
- return 6;
- }
-
- if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
- fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str());
- return 6;
- }
-
- if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
- fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", argv[0], fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
- return 8;
- }
-
- if (wav.bitsPerSample != 16) {
- fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
- return 9;
- }
-
- const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
-
- std::vector<int16_t> pcm16;
- pcm16.resize(n*wav.channels);
- drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
- drwav_uninit(&wav);
-
- // convert to mono, float
- pcmf32.resize(n);
- if (wav.channels == 1) {
- for (uint64_t i = 0; i < n; i++) {
- pcmf32[i] = float(pcm16[i])/32768.0f;
- }
- } else {
- for (uint64_t i = 0; i < n; i++) {
- pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
- }
- }
-
- if (params.diarize) {
- // convert to stereo, float
- pcmf32s.resize(2);
-
- pcmf32s[0].resize(n);
- pcmf32s[1].resize(n);
- for (uint64_t i = 0; i < n; i++) {
- pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
- pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
- }
- }
+ if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) {
+ fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
+ continue;
}
// print system information
wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
- wparams.entropy_thold = params.entropy_thold;
- wparams.logprob_thold = params.logprob_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
+ wparams.split_on_word = params.split_on_word;
wparams.speed_up = params.speed_up;
+ wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
+ wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
+
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
- wparams.temperature_inc = -1;
- wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
- wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
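+ // a zero increment disables the temperature fallback: decoding is attempted once
+ // at the base temperature instead of being retried at increasing temperatures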
+ wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
+ wparams.entropy_thold = params.entropy_thold;
+ wparams.logprob_thold = params.logprob_thold;
whisper_print_user_data user_data = { &params, &pcmf32s };
// output to text file
if (params.output_txt) {
- const auto fname_txt = fname_inp + ".txt";
+ const auto fname_txt = fname_out + ".txt";
output_txt(ctx, fname_txt.c_str());
}
// output to VTT file
if (params.output_vtt) {
- const auto fname_vtt = fname_inp + ".vtt";
+ const auto fname_vtt = fname_out + ".vtt";
output_vtt(ctx, fname_vtt.c_str());
}
// output to SRT file
if (params.output_srt) {
- const auto fname_srt = fname_inp + ".srt";
+ const auto fname_srt = fname_out + ".srt";
output_srt(ctx, fname_srt.c_str(), params);
}
// output to WTS file
if (params.output_wts) {
- const auto fname_wts = fname_inp + ".wts";
+ const auto fname_wts = fname_out + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
}
- // output to CSV file
+ // output to CSV file
if (params.output_csv) {
- const auto fname_csv = fname_inp + ".csv";
+ const auto fname_csv = fname_out + ".csv";
output_csv(ctx, fname_csv.c_str());
}
-
}
}
#include <regex>
#include <random>
+#if defined(GGML_BIG_ENDIAN)
+#include <bit>
+
+template<typename T>
+static T byteswap(T value) {
+ return std::byteswap(value);
+}
+
+template<>
+float byteswap(float value) {
+ return std::bit_cast<float>(byteswap(std::bit_cast<std::uint32_t>(value)));
+}
+
+template<typename T>
+static void byteswap_tensor_data(ggml_tensor * tensor) {
+ T * datum = reinterpret_cast<T *>(tensor->data);
+ for (int i = 0; i < ggml_nelements(tensor); i++) {
+ datum[i] = byteswap(datum[i]);
+ }
+}
+
+static void byteswap_tensor(ggml_tensor * tensor) {
+ switch (tensor->type) {
+ case GGML_TYPE_I16: {
+ byteswap_tensor_data<int16_t>(tensor);
+ break;
+ }
+ case GGML_TYPE_F16: {
+ byteswap_tensor_data<ggml_fp16_t>(tensor);
+ break;
+ }
+ case GGML_TYPE_I32: {
+ byteswap_tensor_data<int32_t>(tensor);
+ break;
+ }
+ case GGML_TYPE_F32: {
+ byteswap_tensor_data<float>(tensor);
+ break;
+ }
+ default: { // GGML_TYPE_I8
+ break;
+ }
+ }
+}
+
+#define BYTESWAP_VALUE(d) d = byteswap(d)
+#define BYTESWAP_FILTERS(f) \
+ do { \
+ for (auto & datum : f.data) { \
+ datum = byteswap(datum); \
+ } \
+ } while (0)
+#define BYTESWAP_TENSOR(t) \
+ do { \
+ byteswap_tensor(t); \
+ } while (0)
+#else
+#define BYTESWAP_VALUE(d) do {} while (0)
+#define BYTESWAP_FILTERS(f) do {} while (0)
+#define BYTESWAP_TENSOR(t) do {} while (0)
+#endif
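+// On big-endian hosts every multi-byte value read from the (little-endian) model
+// file is byte-swapped in place: read_safe() below applies BYTESWAP_VALUE to each
+// scalar it loads, and the mel filters and tensors are swapped after their bulk reads.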
+
#define WHISPER_ASSERT(x) \
do { \
if (!(x)) { \
//#define WHISPER_USE_FLASH_FF
#define WHISPER_MAX_DECODERS 16
+#define WHISPER_USE_SCRATCH
+#define WHISPER_MAX_SCRATCH_BUFFERS 16
+
// available whisper models
enum e_model {
MODEL_UNKNOWN,
static const size_t MB = 1024*1024;
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
+ { MODEL_TINY, 12ull*MB },
+ { MODEL_BASE, 15ull*MB },
+ { MODEL_SMALL, 23ull*MB },
+ { MODEL_MEDIUM, 31ull*MB },
+ { MODEL_LARGE, 38ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
+ { MODEL_TINY, 18ull*MB },
+ { MODEL_BASE, 24ull*MB },
+ { MODEL_SMALL, 36ull*MB },
+ { MODEL_MEDIUM, 48ull*MB },
+ { MODEL_LARGE, 60ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
+ { MODEL_TINY, 4ull*MB },
+ { MODEL_BASE, 4ull*MB },
+ { MODEL_SMALL, 6ull*MB },
+ { MODEL_MEDIUM, 7ull*MB },
+ { MODEL_LARGE, 9ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
+ { MODEL_TINY, 4ull*MB },
+ { MODEL_BASE, 4ull*MB },
+ { MODEL_SMALL, 6ull*MB },
+ { MODEL_MEDIUM, 7ull*MB },
+ { MODEL_LARGE, 9ull*MB },
+};
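+// (use_buf() below records the high-water mark of each scratch buffer in
+// buf_max_size, so these per-model sizes can be checked against actual usage
+// via the commented used_mem printfs in the encoder and decoder)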
+
static const std::map<e_model, size_t> MEM_REQ_MODEL = {
{ MODEL_TINY, 74ull*MB },
{ MODEL_BASE, 142ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
- { MODEL_TINY, 80ull*MB },
- { MODEL_BASE, 128ull*MB },
- { MODEL_SMALL, 300ull*MB },
- { MODEL_MEDIUM, 680ull*MB },
- { MODEL_LARGE, 1100ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
- { MODEL_TINY, 104ull*MB },
- { MODEL_BASE, 138ull*MB },
- { MODEL_SMALL, 208ull*MB },
- { MODEL_MEDIUM, 280ull*MB },
- { MODEL_LARGE, 354ull*MB },
+ { MODEL_TINY, 6ull*MB },
+ { MODEL_BASE, 8ull*MB },
+ { MODEL_SMALL, 13ull*MB },
+ { MODEL_MEDIUM, 22ull*MB },
+ { MODEL_LARGE, 33ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_DECODE = {
- { MODEL_TINY, 200ull*MB },
- { MODEL_BASE, 202ull*MB },
- { MODEL_SMALL, 204ull*MB },
- { MODEL_MEDIUM, 206ull*MB },
- { MODEL_LARGE, 208ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
- { MODEL_TINY, 32ull*MB },
- { MODEL_BASE, 44ull*MB },
- { MODEL_SMALL, 64ull*MB },
- { MODEL_MEDIUM, 84ull*MB },
- { MODEL_LARGE, 110ull*MB },
+ { MODEL_TINY, 3ull*MB },
+ { MODEL_BASE, 5ull*MB },
+ { MODEL_SMALL, 10ull*MB },
+ { MODEL_MEDIUM, 18ull*MB },
+ { MODEL_LARGE, 27ull*MB },
};
struct whisper_mel {
int64_t t_decode_us = 0;
int64_t t_start_us = 0;
+ int32_t n_sample = 0; // number of tokens sampled
+ int32_t n_encode = 0; // number of encoder calls
+ int32_t n_decode = 0; // number of decoder calls
+ int32_t n_fail_p = 0; // number of logprob threshold failures
+ int32_t n_fail_h = 0; // number of entropy threshold failures
+
ggml_type wtype; // weight type (FP32 or FP16)
whisper_mel mel;
// memory buffers used by encode / decode contexts
std::vector<uint8_t> buf_compute;
- std::vector<uint8_t> buf_compute_layer;
+ std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
+
+ int buf_last = 0;
+ size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 };
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
mutable std::mt19937 rng; // used for sampling at t > 0.0
+ int lang_id;
+
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg;
int64_t t_last;
// [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx; // 0 - use default
+
+ void use_buf(struct ggml_context * ctx, int i) {
+#if defined(WHISPER_USE_SCRATCH)
+ size_t last_size = 0;
+
+ if (i == -1) {
+ last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
+ } else {
+ auto & buf = buf_scratch[i];
+ last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
+ }
+
+ if (buf_last >= 0) {
+ buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
+ }
+
+ buf_last = i;
+#else
+ (void) i;
+ (void) ctx;
+#endif
+ }
+
+ size_t get_buf_max_mem(int i) const {
+#if defined(WHISPER_USE_SCRATCH)
+ return buf_max_size[i];
+#else
+ (void) i;
+ return 0;
+#endif
+ }
};
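
// A minimal sketch of the intended use_buf() ping-pong during graph construction
// (illustrative only: "w" stands in for any weight tensor; the real pattern is in
// the encoder/decoder code below):
//
//   wctx.use_buf(ctx0, 0);   // subsequent tensors are allocated in scratch buffer 0
//   cur = ggml_norm(ctx0, inpL);
//   wctx.use_buf(ctx0, 1);   // switch to scratch 1 so the result above stays alive
//   cur = ggml_mul_mat(ctx0, w, cur);
//   wctx.use_buf(ctx0, -1);  // -1 restores the main compute buffer for graph outputs
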
template<typename T>
static void read_safe(whisper_model_loader * loader, T & dest) {
loader->read(loader->context, &dest, sizeof(T));
+ BYTESWAP_VALUE(dest);
}
static bool kv_cache_init(
{
// this is the total memory required to run the inference
const size_t mem_required =
- scale*MEM_REQ_MODEL.at (model.type) +
- scale*MEM_REQ_KV_CROSS.at (model.type) +
- scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)) +
- scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type));
+ MEM_REQ_SCRATCH0.at (model.type) +
+ MEM_REQ_SCRATCH1.at (model.type) +
+ MEM_REQ_SCRATCH2.at (model.type) +
+ MEM_REQ_SCRATCH3.at (model.type) +
+ scale*MEM_REQ_MODEL.at (model.type) +
+ scale*MEM_REQ_KV_CROSS.at(model.type) +
+ scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
// this is the memory required by one decoder
const size_t mem_required_decoder =
fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
}
- wctx.buf_compute.resize (scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
- wctx.buf_compute_layer.resize(scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
+ wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
+
+ wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type));
+ wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type));
+ wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type));
+ wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type));
}
// load mel filters
filters.data.resize(filters.n_mel * filters.n_fft);
loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float));
+ BYTESWAP_FILTERS(filters);
}
// load vocab
}
loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
+ BYTESWAP_TENSOR(tensor);
//printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
total_size += ggml_nbytes(tensor);
struct ggml_context * ctx0 = ggml_init(params);
+ wctx.use_buf(ctx0, 0);
+
struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
assert(mel->type == GGML_TYPE_F32);
{
// convolution + gelu
{
+ wctx.use_buf(ctx0, 1);
+
cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
cur = ggml_add(ctx0,
ggml_repeat(ctx0,
cur = ggml_gelu(ctx0, cur);
+ wctx.use_buf(ctx0, 0);
+
cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0,
cur = ggml_gelu(ctx0, cur);
}
+ wctx.use_buf(ctx0, 3);
+
// ===================================================================
// NOTE: experimenting with partial evaluation of the encoder (ignore)
//static int iter = -1;
struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
+
// ===================================================================
// original:
for (int il = 0; il < n_layer; ++il) {
const auto & layer = model.layers_encoder[il];
- // create separate context for each layer to reduce memory usage
-
- struct ggml_init_params paramsL;
- paramsL.mem_size = wctx.buf_compute_layer.size();
- paramsL.mem_buffer = wctx.buf_compute_layer.data();
-
- struct ggml_context * ctxL = ggml_init(paramsL);
-
// norm
{
- cur = ggml_norm(ctxL, inpL);
+ wctx.use_buf(ctx0, 0);
+
+ cur = ggml_norm(ctx0, inpL);
// cur = ln_0_w*cur + ln_0_b
- cur = ggml_add(ctxL,
- ggml_mul(ctxL,
- ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0,
+ ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
cur),
- ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
+ ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
}
// self-attention
{
- struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
+ wctx.use_buf(ctx0, 1);
+
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
layer.attn_q_w,
cur);
- Qcur = ggml_add(ctxL,
- ggml_repeat(ctxL,
+ Qcur = ggml_add(ctx0,
+ ggml_repeat(ctx0,
layer.attn_q_b,
Qcur),
Qcur);
- //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+ //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
// note: no bias for Key
- struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
layer.attn_k_w,
cur);
- //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+ //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
- struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
layer.attn_v_w,
cur);
- Vcur = ggml_add(ctxL,
- ggml_repeat(ctxL,
+ Vcur = ggml_add(ctx0,
+ ggml_repeat(ctx0,
layer.attn_v_b,
Vcur),
Vcur);
// ------
+ wctx.use_buf(ctx0, 0);
+
#ifdef WHISPER_USE_FLASH_ATTN
struct ggml_tensor * Q =
- ggml_permute(ctxL,
- ggml_cpy(ctxL,
+ ggml_permute(ctx0,
+ ggml_cpy(ctx0,
Qcur,
- ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+ ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
struct ggml_tensor * K =
- ggml_permute(ctxL,
- ggml_cpy(ctxL,
+ ggml_permute(ctx0,
+ ggml_cpy(ctx0,
Kcur,
- ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+ ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
struct ggml_tensor * V =
- ggml_cpy(ctxL,
- ggml_permute(ctxL,
- ggml_reshape_3d(ctxL,
+ ggml_cpy(ctx0,
+ ggml_permute(ctx0,
+ ggml_reshape_3d(ctx0,
Vcur,
n_state/n_head, n_head, n_ctx),
1, 2, 0, 3),
- ggml_new_tensor_3d(ctxL, wctx.wtype, n_ctx, n_state/n_head, n_head)
+ ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head)
);
- struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
+ struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
#else
struct ggml_tensor * Q =
- ggml_permute(ctxL,
- ggml_cpy(ctxL,
+ ggml_permute(ctx0,
+ ggml_cpy(ctx0,
Qcur,
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
+ ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
struct ggml_tensor * K =
- ggml_permute(ctxL,
- ggml_cpy(ctxL,
+ ggml_permute(ctx0,
+ ggml_cpy(ctx0,
Kcur,
- ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+ ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
// K * Q
- struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
struct ggml_tensor * KQ_scaled =
- ggml_scale(ctxL,
+ ggml_scale(ctx0,
KQ,
- ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+ ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
);
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_scaled);
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
//struct ggml_tensor * V_trans =
- // ggml_permute(ctxL,
- // ggml_cpy(ctxL,
+ // ggml_permute(ctx0,
+ // ggml_cpy(ctx0,
// Vcur,
- // ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+ // ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
// 1, 2, 0, 3);
- //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+ //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
struct ggml_tensor * V =
- ggml_cpy(ctxL,
- ggml_permute(ctxL,
- ggml_reshape_3d(ctxL,
+ ggml_cpy(ctx0,
+ ggml_permute(ctx0,
+ ggml_reshape_3d(ctx0,
Vcur,
n_state/n_head, n_head, n_ctx),
0, 2, 1, 3),
- ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_ctx, n_head)
+ ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head)
);
- struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max);
#endif
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
- struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
+ wctx.use_buf(ctx0, 1);
- cur = ggml_cpy(ctxL,
+ cur = ggml_cpy(ctx0,
KQV_merged,
- ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, n_ctx));
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
}
// projection
{
- cur = ggml_mul_mat(ctxL,
+ wctx.use_buf(ctx0, 0);
+
+ cur = ggml_mul_mat(ctx0,
layer.attn_ln_1_w,
cur);
- cur = ggml_add(ctxL,
- ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
+ wctx.use_buf(ctx0, 1);
+
+ cur = ggml_add(ctx0,
+ ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
cur);
}
+ wctx.use_buf(ctx0, 2);
+
// add the input
- cur = ggml_add(ctxL, cur, inpL);
+ cur = ggml_add(ctx0, cur, inpL);
struct ggml_tensor * inpFF = cur;
{
// norm
{
- cur = ggml_norm(ctxL, inpFF);
+ wctx.use_buf(ctx0, 0);
+
+ cur = ggml_norm(ctx0, inpFF);
+
+ wctx.use_buf(ctx0, 1);
// cur = mlp_ln_w*cur + mlp_ln_b
- cur = ggml_add(ctxL,
- ggml_mul(ctxL,
- ggml_repeat(ctxL, layer.mlp_ln_w, cur),
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0,
+ ggml_repeat(ctx0, layer.mlp_ln_w, cur),
cur),
- ggml_repeat(ctxL, layer.mlp_ln_b, cur));
+ ggml_repeat(ctx0, layer.mlp_ln_b, cur));
}
#ifdef WHISPER_USE_FLASH_FF
- cur = ggml_flash_ff(ctxL,
- ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)),
+ wctx.use_buf(ctx0, 0);
+
+ cur = ggml_flash_ff(ctx0,
+ ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)),
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
#else
+ wctx.use_buf(ctx0, 0);
+
// fully connected
- cur = ggml_mul_mat(ctxL,
+ cur = ggml_mul_mat(ctx0,
layer.mlp_0_w,
cur);
- cur = ggml_add(ctxL,
- ggml_repeat(ctxL, layer.mlp_0_b, cur),
+ wctx.use_buf(ctx0, 1);
+
+ cur = ggml_add(ctx0,
+ ggml_repeat(ctx0, layer.mlp_0_b, cur),
cur);
+ wctx.use_buf(ctx0, 0);
+
// GELU activation
- cur = ggml_gelu(ctxL, cur);
+ cur = ggml_gelu(ctx0, cur);
+
+ wctx.use_buf(ctx0, 1);
// projection
- cur = ggml_mul_mat(ctxL,
+ cur = ggml_mul_mat(ctx0,
layer.mlp_1_w,
cur);
- cur = ggml_add(ctxL,
- ggml_repeat(ctxL, layer.mlp_1_b, cur),
+ wctx.use_buf(ctx0, 0);
+
+ cur = ggml_add(ctx0,
+ ggml_repeat(ctx0, layer.mlp_1_b, cur),
cur);
#endif
}
- // output from this layer
- struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
-
- {
- struct ggml_cgraph gf = {};
- gf.n_threads = n_threads;
-
- ggml_build_forward_expand(&gf, inpO);
- ggml_graph_compute (ctxL, &gf);
-
- //ggml_graph_print(&gf);
- }
-
- // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
- // input for next layer (inpO -> inpL)
- memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
- inpL->op = GGML_OP_NONE;
- inpL->src0 = nullptr;
- inpL->src1 = nullptr;
-
- //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
+ wctx.use_buf(ctx0, 3);
- ggml_free(ctxL);
+ inpL = ggml_add(ctx0, cur, inpFF);
}
cur = inpL;
// norm
{
+ wctx.use_buf(ctx0, 0);
+
cur = ggml_norm(ctx0, cur);
+ wctx.use_buf(ctx0, 1);
+
// cur = ln_f_g*cur + ln_f_b
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.e_ln_b, cur));
}
+ wctx.use_buf(ctx0, -1);
+
// run the computation
{
struct ggml_cgraph gf = {};
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
auto & layer = model.layers_decoder[il];
+ wctx.use_buf(ctx0, 0);
+
struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
layer.cross_attn_k_w,
cur);
Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+ wctx.use_buf(ctx0, 1);
+
struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
layer.cross_attn_v_w,
cur);
Vcross),
Vcross);
+ wctx.use_buf(ctx0, -1);
+
//struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
//struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx));
////////////////////////////////////////////////////////////////////////////
- //printf("%s: used_mem = %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0);
+ //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
+ // ggml_used_mem(ctx0)/1024.0/1024.0,
+ // wctx.get_buf_max_mem(0)/1024.0/1024.0,
+ // wctx.get_buf_max_mem(1)/1024.0/1024.0,
+ // wctx.get_buf_max_mem(2)/1024.0/1024.0,
+ // wctx.get_buf_max_mem(3)/1024.0/1024.0);
ggml_free(ctx0);
wctx.t_encode_us += ggml_time_us() - t_start_us;
+ wctx.n_encode++;
return true;
}
// evaluate the decoder
//
-// given text prompt + audio features -> predicts the probabilities for the next token
+// given text prompt + audio features -> computes the logits for the next token
//
// - model: the model
// - n_threads: number of threads to use
struct ggml_context * ctx0 = ggml_init(params);
+ struct ggml_cgraph gf = {};
+ gf.n_threads = n_threads;
+
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, tokens, N*ggml_element_size(embd));
((int32_t *) position->data)[i] = n_past + i;
}
+ wctx.use_buf(ctx0, 3);
+
// token encoding + position encoding
struct ggml_tensor * cur =
ggml_add(ctx0,
for (int il = 0; il < n_layer; ++il) {
const auto & layer = model.layers_decoder[il];
- struct ggml_init_params paramsL;
- paramsL.mem_size = wctx.buf_compute_layer.size();
- paramsL.mem_buffer = wctx.buf_compute_layer.data();
-
- struct ggml_context * ctxL = ggml_init(paramsL);
- struct ggml_cgraph gf = {};
- gf.n_threads = n_threads;
-
// norm
{
- cur = ggml_norm(ctxL, inpL);
+ wctx.use_buf(ctx0, 0);
+
+ cur = ggml_norm(ctx0, inpL);
// cur = ln_0_w*cur + ln_0_b
- cur = ggml_add(ctxL,
- ggml_mul(ctxL,
- ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0,
+ ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
cur),
- ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
+ ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
}
// self-attention
{
- struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
+ wctx.use_buf(ctx0, 1);
+
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
layer.attn_q_w,
cur);
- Qcur = ggml_add(ctxL,
- ggml_repeat(ctxL,
+ Qcur = ggml_add(ctx0,
+ ggml_repeat(ctx0,
layer.attn_q_b,
Qcur),
Qcur);
- Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+ Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
// note: no bias for Key
- struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
layer.attn_k_w,
cur);
- Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+ Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
- struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
layer.attn_v_w,
cur);
- Vcur = ggml_add(ctxL,
- ggml_repeat(ctxL,
+ Vcur = ggml_add(ctx0,
+ ggml_repeat(ctx0,
layer.attn_v_b,
Vcur),
Vcur);
// store key and value to memory
{
- struct ggml_tensor * k = ggml_view_1d(ctxL, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
- struct ggml_tensor * v = ggml_view_1d(ctxL, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past));
+ struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
+ struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past));
- ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k));
- ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v));
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
}
// ------
+ wctx.use_buf(ctx0, 0);
+
struct ggml_tensor * Q =
- ggml_permute(ctxL,
- ggml_cpy(ctxL,
+ ggml_permute(ctx0,
+ ggml_cpy(ctx0,
Qcur,
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+ ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
0, 2, 1, 3);
struct ggml_tensor * K =
- ggml_permute(ctxL,
- ggml_reshape_3d(ctxL,
- ggml_view_1d(ctxL, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state),
+ ggml_permute(ctx0,
+ ggml_reshape_3d(ctx0,
+ ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state),
n_state/n_head, n_head, n_past + N),
0, 2, 1, 3);
+ wctx.use_buf(ctx0, 1);
+
// K * Q
- struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+ wctx.use_buf(ctx0, 0);
//struct ggml_tensor * KQ_scaled =
- // ggml_scale(ctxL,
+ // ggml_scale(ctx0,
// KQ,
- // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+ // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
// );
- struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);
+ struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
+
+ wctx.use_buf(ctx0, 1);
+
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
+ wctx.use_buf(ctx0, 0);
struct ggml_tensor * V_trans =
- ggml_permute(ctxL,
- ggml_reshape_3d(ctxL,
- ggml_view_1d(ctxL, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
+ ggml_permute(ctx0,
+ ggml_reshape_3d(ctx0,
+ ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
n_state/n_head, n_head, n_past + N),
1, 2, 0, 3);
- struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+ wctx.use_buf(ctx0, 1);
- struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
- cur = ggml_cpy(ctxL,
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+ cur = ggml_cpy(ctx0,
KQV_merged,
- ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
}
+ // projection
{
- cur = ggml_mul_mat(ctxL,
+ wctx.use_buf(ctx0, 0);
+
+ cur = ggml_mul_mat(ctx0,
layer.attn_ln_1_w,
cur);
- cur = ggml_add(ctxL,
- ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
+ wctx.use_buf(ctx0, 1);
+
+ cur = ggml_add(ctx0,
+ ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
cur);
}
+ wctx.use_buf(ctx0, 2);
+
// add the input
- struct ggml_tensor * inpCA = ggml_add(ctxL, cur, inpL);
+ struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
// norm
{
- cur = ggml_norm(ctxL, inpCA); // note: we use inpCA here
+ wctx.use_buf(ctx0, 0);
+
+ cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
+
+ wctx.use_buf(ctx0, 1);
// cur = ln_0_w*cur + ln_0_b
- cur = ggml_add(ctxL,
- ggml_mul(ctxL,
- ggml_repeat(ctxL, layer.cross_attn_ln_0_w, cur),
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0,
+ ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur),
cur),
- ggml_repeat(ctxL, layer.cross_attn_ln_0_b, cur));
+ ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur));
}
// cross-attention
{
- struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
+ wctx.use_buf(ctx0, 0);
+
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
layer.cross_attn_q_w,
cur);
- Qcur = ggml_add(ctxL,
- ggml_repeat(ctxL,
+ Qcur = ggml_add(ctx0,
+ ggml_repeat(ctx0,
layer.cross_attn_q_b,
Qcur),
Qcur);
- Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+ Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
// Kcross is already scaled
struct ggml_tensor * Kcross =
- ggml_reshape_3d(ctxL,
- ggml_view_1d(ctxL, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state),
+ ggml_reshape_3d(ctx0,
+ ggml_view_1d(ctx0, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state),
n_state/n_head, n_head, M);
struct ggml_tensor * Vcross =
- ggml_reshape_3d(ctxL,
- ggml_view_1d(ctxL, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state),
+ ggml_reshape_3d(ctx0,
+ ggml_view_1d(ctx0, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state),
n_state/n_head, n_head, M);
+ struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3);
+
// ------
+ wctx.use_buf(ctx0, 1);
+
struct ggml_tensor * Q =
- ggml_permute(ctxL,
- ggml_cpy(ctxL,
+ ggml_permute(ctx0,
+ ggml_cpy(ctx0,
Qcur,
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+ ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
0, 2, 1, 3);
- struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
+ struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
+
+ wctx.use_buf(ctx0, 0);
// K * Q
- struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
//struct ggml_tensor * KQ_scaled =
- // ggml_scale(ctxL,
+ // ggml_scale(ctx0,
// KQ,
- // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+ // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
// );
// no masking for cross-attention
- //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
+ //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ);
+ wctx.use_buf(ctx0, 1);
- struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
- struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+ wctx.use_buf(ctx0, 0);
- struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+
+ wctx.use_buf(ctx0, 1);
+
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
// cur = KQV_merged.contiguous().view(n_state, N)
- cur = ggml_cpy(ctxL,
+ cur = ggml_cpy(ctx0,
KQV_merged,
- ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
}
// projection
{
- cur = ggml_mul_mat(ctxL,
+ wctx.use_buf(ctx0, 0);
+
+ cur = ggml_mul_mat(ctx0,
layer.cross_attn_ln_1_w,
cur);
- cur = ggml_add(ctxL,
- ggml_repeat(ctxL, layer.cross_attn_ln_1_b, cur),
+ wctx.use_buf(ctx0, 1);
+
+ cur = ggml_add(ctx0,
+ ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
cur);
}
+ wctx.use_buf(ctx0, 2);
+
// add the input
- cur = ggml_add(ctxL, cur, inpCA);
+ cur = ggml_add(ctx0, cur, inpCA);
struct ggml_tensor * inpFF = cur;
{
// norm
{
- cur = ggml_norm(ctxL, inpFF);
+ wctx.use_buf(ctx0, 0);
+
+ cur = ggml_norm(ctx0, inpFF);
+
+ wctx.use_buf(ctx0, 1);
// cur = mlp_ln_w*cur + mlp_ln_b
- cur = ggml_add(ctxL,
- ggml_mul(ctxL,
- ggml_repeat(ctxL, layer.mlp_ln_w, cur),
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0,
+ ggml_repeat(ctx0, layer.mlp_ln_w, cur),
cur),
- ggml_repeat(ctxL, layer.mlp_ln_b, cur));
+ ggml_repeat(ctx0, layer.mlp_ln_b, cur));
}
+ wctx.use_buf(ctx0, 0);
+
// fully connected
- cur = ggml_mul_mat(ctxL,
+ cur = ggml_mul_mat(ctx0,
layer.mlp_0_w,
cur);
- cur = ggml_add(ctxL,
- ggml_repeat(ctxL, layer.mlp_0_b, cur),
+ wctx.use_buf(ctx0, 1);
+
+ cur = ggml_add(ctx0,
+ ggml_repeat(ctx0, layer.mlp_0_b, cur),
cur);
+ wctx.use_buf(ctx0, 0);
+
// GELU activation
- cur = ggml_gelu(ctxL, cur);
+ cur = ggml_gelu(ctx0, cur);
+
+ wctx.use_buf(ctx0, 1);
// projection
- cur = ggml_mul_mat(ctxL,
+ cur = ggml_mul_mat(ctx0,
layer.mlp_1_w,
cur);
- cur = ggml_add(ctxL,
- ggml_repeat(ctxL, layer.mlp_1_b, cur),
- cur);
- }
+ wctx.use_buf(ctx0, 0);
- // output from this layer
- struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
-
- {
- ggml_build_forward_expand(&gf, inpO);
- ggml_graph_compute (ctxL, &gf);
-
- //ggml_graph_print(&gf);
+ cur = ggml_add(ctx0,
+ ggml_repeat(ctx0, layer.mlp_1_b, cur),
+ cur);
}
- // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
- // input for next layer (inpO -> inpL)
- memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
- inpL->op = GGML_OP_NONE;
- inpL->src0 = nullptr;
- inpL->src1 = nullptr;
+ wctx.use_buf(ctx0, 3);
- if (N > 1) {
- //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
- }
-
- ggml_free(ctxL);
+ inpL = ggml_add(ctx0, cur, inpFF);
}
cur = inpL;
// norm
{
+ wctx.use_buf(ctx0, 0);
+
cur = ggml_norm(ctx0, cur);
+ wctx.use_buf(ctx0, 1);
+
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.d_ln_w, cur),
ggml_repeat(ctx0, model.d_ln_b, cur));
}
+ wctx.use_buf(ctx0, 0);
+
+ // compute logits only for the last token
+ // comment this line to compute logits for all N tokens
+ // might be useful in the future
+ cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
+
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
+ wctx.use_buf(ctx0, -1);
+
// run the computation
{
- struct ggml_cgraph gf = {};
- gf.n_threads = n_threads;
-
ggml_build_forward_expand(&gf, logits);
ggml_graph_compute (ctx0, &gf);
}
- logits_out.resize(N*n_vocab);
- memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
+ // extract logits for all N tokens
+ //logits_out.resize(N*n_vocab);
+ //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
+
+ // extract logits only for the last token
+ logits_out.resize(n_vocab);
+ memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
if (N > 1) {
- //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
- //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
- //printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx);
+ //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
+ // ggml_used_mem(ctx0)/1024.0/1024.0,
+ // wctx.get_buf_max_mem(0)/1024.0/1024.0,
+ // wctx.get_buf_max_mem(1)/1024.0/1024.0,
+ // wctx.get_buf_max_mem(2)/1024.0/1024.0,
+ // wctx.get_buf_max_mem(3)/1024.0/1024.0);
}
ggml_free(ctx0);
wctx.t_decode_us += ggml_time_us() - t_start_us;
+ wctx.n_decode++;
return true;
}
void whisper_print_timings(struct whisper_context * ctx) {
const int64_t t_end_us = ggml_time_us();
+ const int32_t n_sample = std::max(1, ctx->n_sample);
+ const int32_t n_encode = std::max(1, ctx->n_encode);
+ const int32_t n_decode = std::max(1, ctx->n_decode);
+
fprintf(stderr, "\n");
+ fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->n_fail_p, ctx->n_fail_h);
fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f);
fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f);
- fprintf(stderr, "%s: sample time = %8.2f ms\n", __func__, ctx->t_sample_us/1000.0f);
- fprintf(stderr, "%s: encode time = %8.2f ms / %.2f ms per layer\n", __func__, ctx->t_encode_us/1000.0f, ctx->t_encode_us/1000.0f/ctx->model.hparams.n_audio_layer);
- fprintf(stderr, "%s: decode time = %8.2f ms / %.2f ms per layer\n", __func__, ctx->t_decode_us/1000.0f, ctx->t_decode_us/1000.0f/ctx->model.hparams.n_text_layer);
+ fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_sample_us, n_sample, 1e-3f*ctx->t_sample_us/n_sample);
+ fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_encode_us, n_encode, 1e-3f*ctx->t_encode_us/n_encode);
+ fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_decode_us, n_decode, 1e-3f*ctx->t_decode_us/n_decode);
fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
}
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
struct whisper_full_params result = {
- /*.strategy =*/ WHISPER_SAMPLING_GREEDY,
+ /*.strategy =*/ strategy,
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384,
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
+ /*.split_on_word =*/ false,
/*.max_tokens =*/ 0,
/*.speed_up =*/ false,
/*.language =*/ "en",
/*.suppress_blank =*/ true,
+ /*.suppress_non_speech_tokens =*/ true,
/*.temperature =*/ 0.0f,
/*.max_initial_ts =*/ 1.0f,
float thold_pt,
float thold_ptsum);
+// trim from start (in place)
+static inline void ltrim(std::string &s) {
+ s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
+ return !std::isspace(ch);
+ }));
+}
+
+// trim from end (in place)
+static inline void rtrim(std::string &s) {
+ s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) {
+ return !std::isspace(ch);
+ }).base(), s.end());
+}
+
+// trim from both ends (in place)
+static inline void trim(std::string &s) {
+ rtrim(s);
+ ltrim(s);
+}
+
+static inline bool should_split_on_word(const char * txt, bool split_on_word) {
+ if (!split_on_word) return true;
+
+ return txt[0] == ' ';
+}
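+
+// e.g. with --split-on-word, " hello" is an acceptable split point while "llo" is
+// not: in the GPT-2-style BPE vocabulary, tokens that begin a new word start with a space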
+
// wrap the last segment to max_len characters
// returns the number of new segments
-static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
+static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) {
auto segment = ctx.result_all.back();
int res = 1;
}
const auto txt = whisper_token_to_str(&ctx, token.id);
-
const int cur = strlen(txt);
- if (acc + cur > max_len && i > 0) {
+ if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) {
// split here
+ if (split_on_word) {
+ trim(text);
+ }
+
ctx.result_all.back().text = std::move(text);
ctx.result_all.back().t1 = token.t0;
ctx.result_all.back().tokens.resize(i);
}
}
+ if (split_on_word) {
+ trim(text);
+ }
ctx.result_all.back().text = std::move(text);
return res;
}
+static const std::vector<std::string> non_speech_tokens = {
+ "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
+ "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
+ "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
+ "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
+};
+
// process the logits for the selected decoder
// - applies logit filters
// - computes logprobs and probs
logits[vocab.token_sot] = -INFINITY;
logits[vocab.token_solm] = -INFINITY;
+ // suppress task tokens
+ logits[vocab.token_translate] = -INFINITY;
+ logits[vocab.token_transcribe] = -INFINITY;
+
+ // suppress non-speech tokens
+ // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
+    if (params.suppress_non_speech_tokens) {
+        for (const std::string &token : non_speech_tokens) {
+            std::string suppress_tokens[] = {token, " " + token};
+            for (const std::string &suppress_token : suppress_tokens) {
+                if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) {
+                    logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
+                }
+            }
+        }
+
+        // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
+        if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) {
+            logits[vocab.token_to_id.at(" -")] = -INFINITY;
+        }
+        if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) {
+            logits[vocab.token_to_id.at(" '")] = -INFINITY;
+        }
+    }
+
// timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
// https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
{
}
}
+ // condition timestamp tokens to be increasing
+ // ref: https://github.com/openai/whisper/pull/831#issuecomment-1385910556
+ if (decoder.has_ts) {
+ const int tid0 = decoder.seek_delta/2;
+
+ for (int i = vocab.token_beg; i < vocab.token_beg + tid0; ++i) {
+ logits[i] = -INFINITY;
+ }
+ }
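+ // e.g. if the last sampled timestamp was 1.00 s into the window (seek_delta == 100
+ // in 10 ms units, so tid0 == 50), all timestamp tokens below token_beg + 50 are
+ // masked and timestamps can only move forward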
+
// populate the logprobs array (log_softmax)
{
const float logit_max = *std::max_element(logits.begin(), logits.end());
}
static whisper_token_data whisper_sample_token(
- const whisper_context & ctx,
+ whisper_context & ctx,
const whisper_decoder & decoder,
bool best) {
whisper_token_data result = {
result.pt = result.p;
}
+ ctx.n_sample++;
+
return result;
}
std::vector<whisper_token_data> result;
result.reserve(k);
- whisper_token tid;
+ whisper_token tid = vocab.token_beg;
- float pt;
- float ptsum;
+ float pt = 0.0f;
+ float ptsum = 0.0f;
{
double sum_ts = 0.0;
}
}
+ ctx.n_sample++;
+
return result;
}
fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
return -3;
}
-
+ ctx->lang_id = lang_id;
params.language = whisper_lang_str(lang_id);
fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
if (whisper_is_multilingual(ctx)) {
const int lang_id = whisper_lang_id(params.language);
+ ctx->lang_id = lang_id;
prompt_init.push_back(whisper_token_lang(ctx, lang_id));
if (params.translate) {
prompt_init.push_back(whisper_token_translate());
prompt.clear();
// if we have already generated some text, use it as a prompt to condition the next generation
- if (!prompt_past.empty() && t_cur > 0.5f) {
+ if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
prompt = { whisper_token_prev(ctx) };
prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
// print the prompt
- //WHISPER_PRINT_DEBUG("\n\n");
- //for (int i = 0; i < (int) prompt.size(); i++) {
- // WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
- //}
- //WHISPER_PRINT_DEBUG("\n\n");
+ WHISPER_PRINT_DEBUG("\n\n");
+ for (int i = 0; i < (int) prompt.size(); i++) {
+ WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
+ }
+ WHISPER_PRINT_DEBUG("\n\n");
if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
fprintf(stderr, "%s: failed to decode\n", __func__);
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
});
- int cur_c = 0;
+ unsigned int cur_c = 0;
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
auto & cur = beam_candidates[cur_c++];
- while (beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
+ while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
++cur_c;
}
WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
__func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
- if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_thold) {
+ if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) {
WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
__func__, j, decoder.sequence.entropy, params.entropy_thold);
decoder.failed = true;
+ ctx->n_fail_h++;
continue;
}
if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
success = false;
+ ctx->n_fail_p++;
}
if (success) {
if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
+
if (!text.empty()) {
const auto tt0 = params.speed_up ? 2*t0 : t0;
const auto tt1 = params.speed_up ? 2*t1 : t1;
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) {
- n_new = whisper_wrap_segment(*ctx, params.max_len);
+ n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
}
}
if (params.new_segment_callback) {
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) {
- n_new = whisper_wrap_segment(*ctx, params.max_len);
+ n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
}
}
if (params.new_segment_callback) {
return ctx->result_all.size();
}
+int whisper_full_lang_id(struct whisper_context * ctx) {
+ return ctx->lang_id;
+}
+
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
return ctx->result_all[i_segment].t0;
}
// =================================================================================================
+//
+// Temporary interface needed for exposing the ggml API
+// Will be removed in the future when ggml becomes a separate library
+//
+
+WHISPER_API int whisper_bench_memcpy(int n_threads) {
+ ggml_time_init();
+
+ size_t n = 50;
+ size_t arr = n_threads > 0 ? 1024 : n_threads; // trick to avoid compiler optimizations
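+ // (using the runtime value of n_threads keeps the buffer size opaque to the
+ // optimizer, so the memcpy below cannot be elided or constant-folded)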
+
+ // 1 GB array
+ const size_t size = arr*1024llu*1024llu;
+
+ char * src = (char *) malloc(size);
+ char * dst = (char *) malloc(size);
+
+ for (size_t i = 0; i < size; i++) src[i] = i;
+
+ memcpy(dst, src, size); // heat-up
+
+ double tsum = 0.0;
+
+ for (size_t i = 0; i < n; i++) {
+ const int64_t t0 = ggml_time_us();
+
+ memcpy(dst, src, size);
+
+ const int64_t t1 = ggml_time_us();
+
+ tsum += (t1 - t0)*1e-6;
+
+ src[0] = rand();
+ }
+
+ fprintf(stderr, "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
+
+ // needed to prevent the compiler from optimizing the memcpy away
+ {
+ double sum = 0.0;
+
+ for (size_t i = 0; i < size; i++) sum += dst[i];
+
+ fprintf(stderr, "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum);
+ }
+
+ free(src);
+ free(dst);
+
+ return 0;
+}
+
+WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
+ ggml_time_init();
+
+ const int n_max = 128;
+
+ const std::vector<size_t> sizes = {
+ 64, 128, 256, 512, 1024, 2048, 4096,
+ };
+
+ const size_t N_max = sizes.back();
+
+ // a: N*N*sizeof(float)
+ // b: N*N*sizeof(float)
+ // c: N*N*sizeof(float)
+ // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
+ std::vector<char> buf(4llu*N_max*N_max*sizeof(float) + 4*256);
+
+ for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
+
+ for (int j = 0; j < (int) sizes.size(); j++) {
+ int n_fp16 = 0;
+ int n_fp32 = 0;
+
+ // GFLOPS/s
+ double s_fp16 = 0.0;
+ double s_fp32 = 0.0;
+
+ const size_t N = sizes[j];
+
+ for (int k = 0; k < 2; ++k) {
+ const ggml_type wtype = k == 0 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+ double & s = k == 0 ? s_fp16 : s_fp32;
+ int & n = k == 0 ? n_fp16 : n_fp32;
+
+ struct ggml_init_params gparams = {
+ /*.mem_size =*/ buf.size(),
+ /*.mem_buffer =*/ buf.data(),
+ };
+
+ struct ggml_context * ctx0 = ggml_init(gparams);
+
+ struct ggml_tensor * a = ggml_new_tensor_2d(ctx0, wtype, N, N);
+ struct ggml_tensor * b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, N);
+
+ struct ggml_tensor * c = ggml_mul_mat(ctx0, a, b);
+
+ struct ggml_cgraph gf = ggml_build_forward(c);
+
+ gf.n_threads = n_threads;
+
+ double tsum = 0.0;
+
+ // heat-up
+ ggml_graph_compute(ctx0, &gf);
+
+ for (int i = 0; i < n_max; ++i) {
+ const int64_t t0 = ggml_time_us();
+
+ ggml_graph_compute(ctx0, &gf);
+
+ const int64_t t1 = ggml_time_us();
+
+ tsum += (t1 - t0)*1e-6;
+ n++;
+
+ if (tsum > 1.0 && n >= 3) {
+ break;
+ }
+ }
+
+ ggml_free(ctx0);
+
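+ // an N x N by N x N matmul performs ~2*N^3 floating-point ops
+ // (N multiplies and N-1 adds per output element), so GFLOPS = 2*N^3*n/(tsum*1e9)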
+ s = ((2.0*N*N*N*n)/tsum)*1e-9;
+ }
+
+ fprintf(stderr, "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
+ N, N, s_fp16, n_fp16, s_fp32, n_fp32);
+ }
+
+ return 0;
+}
+
+// =================================================================================================
+
//
// Experimental stuff below
//