From: Georgi Gerganov Date: Tue, 5 Sep 2023 10:55:06 +0000 (+0300) Subject: whisper : sync (match OpenAI input, convert, new features) (#495) X-Git-Tag: upstream/0.0.1642~1252 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=755be569c970032c0a1b72cae251dadda58ca263;p=pkg%2Fggml%2Fsources%2Fggml whisper : sync (match OpenAI input, convert, new features) (#495) ggml-ci --- diff --git a/examples/whisper/convert-pt-to-ggml.py b/examples/whisper/convert-pt-to-ggml.py index 07752e75..9aa134b5 100644 --- a/examples/whisper/convert-pt-to-ggml.py +++ b/examples/whisper/convert-pt-to-ggml.py @@ -39,132 +39,133 @@ import json import code import torch import numpy as np - -from transformers import GPTJForCausalLM -from transformers import GPT2TokenizerFast +import base64 +from pathlib import Path +#from transformers import GPTJForCausalLM +#from transformers import GPT2TokenizerFast # ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L10-L110 -LANGUAGES = { - "en": "english", - "zh": "chinese", - "de": "german", - "es": "spanish", - "ru": "russian", - "ko": "korean", - "fr": "french", - "ja": "japanese", - "pt": "portuguese", - "tr": "turkish", - "pl": "polish", - "ca": "catalan", - "nl": "dutch", - "ar": "arabic", - "sv": "swedish", - "it": "italian", - "id": "indonesian", - "hi": "hindi", - "fi": "finnish", - "vi": "vietnamese", - "iw": "hebrew", - "uk": "ukrainian", - "el": "greek", - "ms": "malay", - "cs": "czech", - "ro": "romanian", - "da": "danish", - "hu": "hungarian", - "ta": "tamil", - "no": "norwegian", - "th": "thai", - "ur": "urdu", - "hr": "croatian", - "bg": "bulgarian", - "lt": "lithuanian", - "la": "latin", - "mi": "maori", - "ml": "malayalam", - "cy": "welsh", - "sk": "slovak", - "te": "telugu", - "fa": "persian", - "lv": "latvian", - "bn": "bengali", - "sr": "serbian", - "az": "azerbaijani", - "sl": "slovenian", - "kn": "kannada", - "et": "estonian", - "mk": "macedonian", - "br": "breton", - "eu": "basque", - "is": "icelandic", - "hy": "armenian", - "ne": "nepali", - "mn": "mongolian", - "bs": "bosnian", - "kk": "kazakh", - "sq": "albanian", - "sw": "swahili", - "gl": "galician", - "mr": "marathi", - "pa": "punjabi", - "si": "sinhala", - "km": "khmer", - "sn": "shona", - "yo": "yoruba", - "so": "somali", - "af": "afrikaans", - "oc": "occitan", - "ka": "georgian", - "be": "belarusian", - "tg": "tajik", - "sd": "sindhi", - "gu": "gujarati", - "am": "amharic", - "yi": "yiddish", - "lo": "lao", - "uz": "uzbek", - "fo": "faroese", - "ht": "haitian creole", - "ps": "pashto", - "tk": "turkmen", - "nn": "nynorsk", - "mt": "maltese", - "sa": "sanskrit", - "lb": "luxembourgish", - "my": "myanmar", - "bo": "tibetan", - "tl": "tagalog", - "mg": "malagasy", - "as": "assamese", - "tt": "tatar", - "haw": "hawaiian", - "ln": "lingala", - "ha": "hausa", - "ba": "bashkir", - "jw": "javanese", - "su": "sundanese", -} - -# ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L273-L292 -def build_tokenizer(path_to_whisper_repo: str, name: str = "gpt2"): - os.environ["TOKENIZERS_PARALLELISM"] = "false" - path = os.path.join(path_to_whisper_repo, "whisper/assets", name) - tokenizer = GPT2TokenizerFast.from_pretrained(path) - - specials = [ - "<|startoftranscript|>", - *[f"<|{lang}|>" for lang in LANGUAGES.keys()], - "<|translate|>", - "<|transcribe|>", - "<|startoflm|>", - "<|startofprev|>", - "<|nocaptions|>", - "<|notimestamps|>", - ] - - 
tokenizer.add_special_tokens(dict(additional_special_tokens=specials)) - return tokenizer +#LANGUAGES = { +# "en": "english", +# "zh": "chinese", +# "de": "german", +# "es": "spanish", +# "ru": "russian", +# "ko": "korean", +# "fr": "french", +# "ja": "japanese", +# "pt": "portuguese", +# "tr": "turkish", +# "pl": "polish", +# "ca": "catalan", +# "nl": "dutch", +# "ar": "arabic", +# "sv": "swedish", +# "it": "italian", +# "id": "indonesian", +# "hi": "hindi", +# "fi": "finnish", +# "vi": "vietnamese", +# "iw": "hebrew", +# "uk": "ukrainian", +# "el": "greek", +# "ms": "malay", +# "cs": "czech", +# "ro": "romanian", +# "da": "danish", +# "hu": "hungarian", +# "ta": "tamil", +# "no": "norwegian", +# "th": "thai", +# "ur": "urdu", +# "hr": "croatian", +# "bg": "bulgarian", +# "lt": "lithuanian", +# "la": "latin", +# "mi": "maori", +# "ml": "malayalam", +# "cy": "welsh", +# "sk": "slovak", +# "te": "telugu", +# "fa": "persian", +# "lv": "latvian", +# "bn": "bengali", +# "sr": "serbian", +# "az": "azerbaijani", +# "sl": "slovenian", +# "kn": "kannada", +# "et": "estonian", +# "mk": "macedonian", +# "br": "breton", +# "eu": "basque", +# "is": "icelandic", +# "hy": "armenian", +# "ne": "nepali", +# "mn": "mongolian", +# "bs": "bosnian", +# "kk": "kazakh", +# "sq": "albanian", +# "sw": "swahili", +# "gl": "galician", +# "mr": "marathi", +# "pa": "punjabi", +# "si": "sinhala", +# "km": "khmer", +# "sn": "shona", +# "yo": "yoruba", +# "so": "somali", +# "af": "afrikaans", +# "oc": "occitan", +# "ka": "georgian", +# "be": "belarusian", +# "tg": "tajik", +# "sd": "sindhi", +# "gu": "gujarati", +# "am": "amharic", +# "yi": "yiddish", +# "lo": "lao", +# "uz": "uzbek", +# "fo": "faroese", +# "ht": "haitian creole", +# "ps": "pashto", +# "tk": "turkmen", +# "nn": "nynorsk", +# "mt": "maltese", +# "sa": "sanskrit", +# "lb": "luxembourgish", +# "my": "myanmar", +# "bo": "tibetan", +# "tl": "tagalog", +# "mg": "malagasy", +# "as": "assamese", +# "tt": "tatar", +# "haw": "hawaiian", +# "ln": "lingala", +# "ha": "hausa", +# "ba": "bashkir", +# "jw": "javanese", +# "su": "sundanese", +#} + +## ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L273-L292 +#def build_tokenizer(path_to_whisper_repo: str, name: str = "gpt2"): +# os.environ["TOKENIZERS_PARALLELISM"] = "false" +# path = os.path.join(path_to_whisper_repo, "whisper/assets", name) +# tokenizer = GPT2TokenizerFast.from_pretrained(path) +# +# specials = [ +# "<|startoftranscript|>", +# *[f"<|{lang}|>" for lang in LANGUAGES.keys()], +# "<|translate|>", +# "<|transcribe|>", +# "<|startoflm|>", +# "<|startofprev|>", +# "<|nocaptions|>", +# "<|notimestamps|>", +# ] +# +# tokenizer.add_special_tokens(dict(additional_special_tokens=specials)) +# return tokenizer # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py def bytes_to_unicode(): @@ -193,17 +194,17 @@ if len(sys.argv) < 4: print("Usage: convert-pt-to-ggml.py model.pt path-to-whisper-repo dir-output [use-f32]\n") sys.exit(1) -fname_inp = sys.argv[1] -dir_whisper = sys.argv[2] -dir_out = sys.argv[3] +fname_inp = Path(sys.argv[1]) +dir_whisper = Path(sys.argv[2]) +dir_out = Path(sys.argv[3]) # try to load PyTorch binary data try: model_bytes = open(fname_inp, "rb").read() with io.BytesIO(model_bytes) as fp: checkpoint = torch.load(fp, map_location="cpu") -except: - print("Error: failed to load PyTorch model file: %s" % fname_inp) +except Exception: + print("Error: failed to load PyTorch model file:" , fname_inp) sys.exit(1) hparams = 
checkpoint["dims"] @@ -217,33 +218,52 @@ list_vars = checkpoint["model_state_dict"] # load mel filters n_mels = hparams["n_mels"] -with np.load(os.path.join(dir_whisper, "whisper/assets", "mel_filters.npz")) as f: +with np.load(dir_whisper / "whisper" / "assets" / "mel_filters.npz") as f: filters = torch.from_numpy(f[f"mel_{n_mels}"]) #print (filters) #code.interact(local=locals()) +# load tokenizer +# for backwards compatibility, also check for older hf_transformers format tokenizer files +# old format: dir_whisper/whisper/assets/[multilingual/gpt2]/vocab.json +# new format: dir_whisper/whisper/assets/[multilingual/gpt2].tiktoken multilingual = hparams["n_vocab"] == 51865 -tokenizer = build_tokenizer(dir_whisper, multilingual and "multilingual" or "gpt2") +tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken") +tokenizer_type = "tiktoken" +if not tokenizer.is_file(): + tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual" or "gpt2") / "vocab.json" + tokenizer_type = "hf_transformers" + if not tokenizer.is_file(): + print("Error: failed to find either tiktoken or hf_transformers tokenizer file:", tokenizer) + sys.exit(1) -#print(tokenizer) -#print(tokenizer.name_or_path) -#print(len(tokenizer.additional_special_tokens)) -dir_tokenizer = tokenizer.name_or_path +byte_encoder = bytes_to_unicode() +byte_decoder = {v:k for k, v in byte_encoder.items()} -# output in the same directory as the model -fname_out = dir_out + "/ggml-model.bin" +if tokenizer_type == "tiktoken": + with open(tokenizer, "rb") as f: + contents = f.read() + tokens = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)} +elif tokenizer_type == "hf_transformers": + with open(tokenizer, "r", encoding="utf8") as f: + _tokens_raw = json.load(f) + if '<|endoftext|>' in _tokens_raw: + # ensures exact same model as tokenizer_type == tiktoken + # details: https://github.com/ggerganov/whisper.cpp/pull/725 + del _tokens_raw['<|endoftext|>'] + tokens = {bytes([byte_decoder[c] for c in token]): int(idx) for token, idx in _tokens_raw.items()} -with open(dir_tokenizer + "/vocab.json", "r") as f: - tokens = json.load(f) +# output in the same directory as the model +fname_out = dir_out / "ggml-model.bin" # use 16-bit or 32-bit floats use_f16 = True if len(sys.argv) > 4: use_f16 = False - fname_out = dir_out + "/ggml-model-f32.bin" + fname_out = dir_out / "ggml-model-f32.bin" -fout = open(fname_out, "wb") +fout = fname_out.open("wb") fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex fout.write(struct.pack("i", hparams["n_vocab"])) @@ -265,32 +285,28 @@ for i in range(filters.shape[0]): for j in range(filters.shape[1]): fout.write(struct.pack("f", filters[i][j])) -byte_encoder = bytes_to_unicode() -byte_decoder = {v:k for k, v in byte_encoder.items()} - +# write tokenizer fout.write(struct.pack("i", len(tokens))) for key in tokens: - text = bytearray([byte_decoder[c] for c in key]) - fout.write(struct.pack("i", len(text))) - fout.write(text) + fout.write(struct.pack("i", len(key))) + fout.write(key) for name in list_vars.keys(): data = list_vars[name].squeeze().numpy() - print("Processing variable: " + name + " with shape: ", data.shape) + print("Processing variable: " , name , " with shape: ", data.shape) # reshape conv bias from [n] to [n, 1] - if name == "encoder.conv1.bias" or \ - name == "encoder.conv2.bias": + if name in ["encoder.conv1.bias", "encoder.conv2.bias"]: data = 
data.reshape(data.shape[0], 1) - print(" Reshaped variable: " + name + " to shape: ", data.shape) + print(f" Reshaped variable: {name} to shape: ", data.shape) - n_dims = len(data.shape); + n_dims = len(data.shape) # looks like the whisper models are in f16 by default # so we need to convert the small tensors to f32 until we fully support f16 in ggml # ftype == 0 -> float32, ftype == 1 -> float16 - ftype = 1; + ftype = 1 if use_f16: if n_dims < 2 or \ name == "encoder.conv1.bias" or \ @@ -301,9 +317,8 @@ for name in list_vars.keys(): data = data.astype(np.float32) ftype = 0 else: - if n_dims < 3 and data.dtype != np.float32: - data = data.astype(np.float32) - ftype = 0 + data = data.astype(np.float32) + ftype = 0 #if name.startswith("encoder"): # if name.endswith("mlp.0.weight") or \ @@ -312,16 +327,16 @@ for name in list_vars.keys(): # data = data.transpose() # header - str = name.encode('utf-8') - fout.write(struct.pack("iii", n_dims, len(str), ftype)) + str_ = name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(str_), ftype)) for i in range(n_dims): fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) - fout.write(str); + fout.write(str_) # data data.tofile(fout) fout.close() -print("Done. Output file: " + fname_out) +print("Done. Output file: " , fname_out) print("") diff --git a/examples/whisper/main.cpp b/examples/whisper/main.cpp index 8dd31d02..fa399c6d 100644 --- a/examples/whisper/main.cpp +++ b/examples/whisper/main.cpp @@ -59,6 +59,7 @@ struct whisper_params { int32_t offset_t_ms = 0; int32_t offset_n = 0; int32_t duration_ms = 0; + int32_t progress_step = 5; int32_t max_context = -1; int32_t max_len = 0; int32_t best_of = 2; @@ -69,6 +70,7 @@ struct whisper_params { float logprob_thold = -1.00f; bool speed_up = false; + bool debug_mode = false; bool translate = false; bool detect_language = false; bool diarize = false; @@ -86,6 +88,7 @@ struct whisper_params { bool print_colors = false; bool print_progress = false; bool no_timestamps = false; + bool log_score = false; std::string language = "en"; std::string prompt; @@ -133,7 +136,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } - else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } + // else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } + else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } @@ -158,6 +162,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); } else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; } + else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -187,7 +192,8 @@ void whisper_print_usage(int /*argc*/, char ** 
argv, const whisper_params & para fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); - fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); + // fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); + fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); @@ -211,6 +217,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); + fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false"); fprintf(stderr, "\n"); } @@ -218,6 +225,7 @@ struct whisper_print_user_data { const whisper_params * params; const std::vector> * pcmf32s; + int progress_prev; }; std::string estimate_diarization_speaker(std::vector> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) { @@ -252,6 +260,14 @@ std::string estimate_diarization_speaker(std::vector> pcmf32s return speaker; } +void whisper_print_progress_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int progress, void * user_data) { + int progress_step = ((whisper_print_user_data *) user_data)->params->progress_step; + int * progress_prev = &(((whisper_print_user_data *) user_data)->progress_prev); + if (progress >= *progress_prev + progress_step) { + *progress_prev += progress_step; + fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress); + } +} void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) { const auto & params = *((whisper_print_user_data *) user_data)->params; @@ -476,6 +492,25 @@ bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_ return true; } +bool output_score(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { + std::ofstream fout(fname); + fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + + const int n_segments = whisper_full_n_segments(ctx); + // fprintf(stderr,"segments: %d\n",n_segments); + for (int i = 0; i < n_segments; ++i) { + const int n_tokens = whisper_full_n_tokens(ctx, i); + // fprintf(stderr,"tokens: %d\n",n_tokens); + for (int j = 0; j < n_tokens; j++) { + auto token = whisper_full_get_token_text(ctx, i, j); + auto probability = whisper_full_get_token_p(ctx, i, j); + fout << token << '\t' << probability << std::endl; + // fprintf(stderr,"token: %s %f\n",token,probability); + } + } + return 
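// A sketch of the resulting .score.txt layout, as produced by the loop above
// (one tab-separated "token<TAB>probability" pair per line; the values shown
// here are illustrative, not real output):
//
//    And     0.977
//    so      0.989
//    my      0.995
//
// whisper_full_get_token_p() is the per-token probability from the decoder, so
// downstream tools can filter or aggregate low-confidence tokens from this file.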
true; +} + bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { std::ofstream fout(fname); int indent = 0; @@ -883,6 +918,7 @@ int main(int argc, char ** argv) { wparams.split_on_word = params.split_on_word; wparams.speed_up = params.speed_up; + wparams.debug_mode = params.debug_mode; wparams.tdrz_enable = params.tinydiarize; // [TDRZ] @@ -895,7 +931,7 @@ int main(int argc, char ** argv) { wparams.entropy_thold = params.entropy_thold; wparams.logprob_thold = params.logprob_thold; - whisper_print_user_data user_data = { ¶ms, &pcmf32s }; + whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 }; // this callback is called on each new segment if (!wparams.print_realtime) { @@ -903,6 +939,11 @@ int main(int argc, char ** argv) { wparams.new_segment_callback_user_data = &user_data; } + if (wparams.print_progress) { + wparams.progress_callback = whisper_print_progress_callback; + wparams.progress_callback_user_data = &user_data; + } + // example for abort mechanism // in this example, we do not abort the processing, but we could if the flag is set to true // the callback is called before every encoder run - if it returns false, the processing is aborted @@ -967,6 +1008,12 @@ int main(int argc, char ** argv) { const auto fname_lrc = fname_out + ".lrc"; output_lrc(ctx, fname_lrc.c_str(), params, pcmf32s); } + + // output to score file + if (params.log_score) { + const auto fname_score = fname_out + ".score.txt"; + output_score(ctx, fname_score.c_str(), params, pcmf32s); + } } } diff --git a/examples/whisper/quantize.cpp b/examples/whisper/quantize.cpp index 64e8f35c..b01d6143 100644 --- a/examples/whisper/quantize.cpp +++ b/examples/whisper/quantize.cpp @@ -138,7 +138,7 @@ bool whisper_model_quantize(const std::string & fname_inp, const std::string & f // return false; //} - char word[128]; + char word[129]; for (int i = 0; i < n_vocab; i++) { uint32_t len; diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp index cb124ec9..b50c86d0 100644 --- a/examples/whisper/whisper.cpp +++ b/examples/whisper/whisper.cpp @@ -14,6 +14,7 @@ #define _USE_MATH_DEFINES #include #include +#include #include #include #include @@ -81,7 +82,7 @@ static void byteswap_tensor(ggml_tensor * tensor) { } while (0) #define BYTESWAP_TENSOR(t) \ do { \ - byteswap_tensor(tensor); \ + byteswap_tensor(t); \ } while (0) #else #define BYTESWAP_VALUE(d) do {} while (0) @@ -92,7 +93,7 @@ static void byteswap_tensor(ggml_tensor * tensor) { #define WHISPER_ASSERT(x) \ do { \ if (!(x)) { \ - fprintf(stderr, "WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + log("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ abort(); \ } \ } while (0) @@ -589,7 +590,7 @@ struct whisper_model { struct whisper_sequence { std::vector tokens; - // the accumulated transcription in the current interation (used to truncate the tokens array) + // the accumulated transcription in the current iteration (used to truncate the tokens array) int result_len; double sum_logprobs_all; // the sum of the log probabilities of the tokens @@ -724,6 +725,22 @@ struct whisper_context { std::string path_model; // populated by whisper_init_from_file() }; +static void whisper_default_log(const char * text) { + fprintf(stderr, "%s", text); +} + +static whisper_log_callback whisper_log = whisper_default_log; + +// TODO: fix compile warning about "format string is not a string literal" +static void log(const char * fmt, ...) 
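// A minimal sketch of redirecting this logging from client code. It assumes the
// accompanying whisper.h exposes the whisper_log_callback type used above and a
// corresponding setter (named whisper_set_log_callback here for illustration):
//
//   static void my_log(const char * text) {
//       // forward whisper.cpp diagnostics to your own sink instead of stderr
//       fputs(text, g_logfile);
//   }
//
//   whisper_set_log_callback(my_log);  // hypothetical setter; default is whisper_default_log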
{ + if (!whisper_log) return; + char buf[1024]; + va_list args; + va_start(args, fmt); + vsnprintf(buf, sizeof(buf), fmt, args); + whisper_log(buf); +} + template static void read_safe(whisper_model_loader * loader, T & dest) { loader->read(loader->context, &dest, sizeof(T)); @@ -747,7 +764,7 @@ static bool kv_cache_init( cache.ctx = ggml_init(params); if (!cache.ctx) { - fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); + log("%s: failed to allocate memory for kv cache\n", __func__); return false; } @@ -783,7 +800,7 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache) { cache.ctx = ggml_init(params); if (!cache.ctx) { - fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); + log("%s: failed to allocate memory for kv cache\n", __func__); return false; } @@ -812,7 +829,7 @@ static void kv_cache_free(struct whisper_kv_cache & cache) { // see the convert-pt-to-ggml.py script for details // static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) { - fprintf(stderr, "%s: loading model\n", __func__); + log("%s: loading model\n", __func__); const int64_t t_start_us = ggml_time_us(); @@ -826,7 +843,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con uint32_t magic; read_safe(loader, magic); if (magic != GGML_FILE_MAGIC) { - fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__); + log("%s: invalid model data (bad magic)\n", __func__); return false; } } @@ -877,25 +894,25 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // in order to save memory and also to speed up the computation wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype)); if (wctx.wtype == GGML_TYPE_COUNT) { - fprintf(stderr, "%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype); + log("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype); return false; } const size_t scale = model.hparams.ftype ? 
1 : 2; - fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab); - fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); - fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); - fprintf(stderr, "%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); - fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); - fprintf(stderr, "%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx); - fprintf(stderr, "%s: n_text_state = %d\n", __func__, hparams.n_text_state); - fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head); - fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer); - fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels); - fprintf(stderr, "%s: ftype = %d\n", __func__, model.hparams.ftype); - fprintf(stderr, "%s: qntvr = %d\n", __func__, qntvr); - fprintf(stderr, "%s: type = %d\n", __func__, model.type); + log("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + log("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); + log("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + log("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); + log("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + log("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx); + log("%s: n_text_state = %d\n", __func__, hparams.n_text_state); + log("%s: n_text_head = %d\n", __func__, hparams.n_text_head); + log("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer); + log("%s: n_mels = %d\n", __func__, hparams.n_mels); + log("%s: ftype = %d\n", __func__, model.hparams.ftype); + log("%s: qntvr = %d\n", __func__, qntvr); + log("%s: type = %d\n", __func__, model.type); // print memory requirements { @@ -913,7 +930,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con const size_t mem_required_decoder = scale*MEM_REQ_KV_SELF.at(model.type); - fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__, + log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__, mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0); } @@ -945,7 +962,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con read_safe(loader, n_vocab); //if (n_vocab != model.hparams.n_vocab) { - // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n", + // log("%s: invalid model file '%s' (bad vocab size %d != %d)\n", // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab); // return false; //} @@ -965,7 +982,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con word.assign(&tmp[0], tmp.size()); } else { // seems like we have an empty-string token in multi-language models (i = 50256) - //fprintf(stderr, "%s: warning: empty-string token in vocab, i = %d\n", __func__, i); + //log("%s: warning: empty-string token in vocab, i = %d\n", __func__, i); word = ""; } @@ -989,7 +1006,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } if (n_vocab < model.hparams.n_vocab) { - fprintf(stderr, "%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab); + log("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab); for (int i = n_vocab; i < model.hparams.n_vocab; i++) { if (i > vocab.token_beg) { word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]"; @@ -1128,7 +1145,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con ctx_size += 
(15 + 15*n_audio_layer + 24*n_text_layer)*512; // object overhead - fprintf(stderr, "%s: model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); + log("%s: model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); } // create the ggml context @@ -1141,7 +1158,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con model.ctx = ggml_init(params); if (!model.ctx) { - fprintf(stderr, "%s: ggml_init() failed\n", __func__); + log("%s: ggml_init() failed\n", __func__); return false; } } @@ -1374,20 +1391,20 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con name.assign(&tmp[0], tmp.size()); if (model.tensors.find(name) == model.tensors.end()) { - fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); + log("%s: unknown tensor '%s' in model file\n", __func__, name.data()); return false; } auto tensor = model.tensors[name.data()]; if (ggml_nelements(tensor) != nelements) { - fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); - fprintf(stderr, "%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n", + log("%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + log("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n", __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]); return false; } if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) { - fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n", + log("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n", __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]); return false; } @@ -1395,7 +1412,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con const size_t bpe = ggml_type_size(ggml_type(ttype)); if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { - fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + log("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); return false; } @@ -1408,12 +1425,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con model.n_loaded++; } - fprintf(stderr, "%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0); + log("%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0); if (model.n_loaded == 0) { - fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); + log("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); } else if (model.n_loaded != (int) model.tensors.size()) { - fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); + log("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); return false; } } @@ -1784,7 +1801,7 @@ static bool whisper_encode_internal( { struct ggml_cgraph gf = {}; - ggml_build_forward_expand(&gf, cur); + ggml_build_forward_expand (&gf, cur); ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); //ggml_graph_print(&gf); @@ -2283,7 +2300,7 @@ static bool whisper_decode_internal( // run the computation { - ggml_build_forward_expand(&gf, logits); + 
ggml_build_forward_expand (&gf, logits);
         ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
     }
 
@@ -2329,6 +2346,23 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
     return std::string(buf);
 }
 
+#define SIN_COS_N_COUNT WHISPER_N_FFT
+static float sin_vals[SIN_COS_N_COUNT];
+static float cos_vals[SIN_COS_N_COUNT];
+
+// In FFT, we frequently use sine and cosine operations with the same values.
+// We can use precalculated values to speed up the process.
+static void fill_sin_cos_table() {
+    static bool is_filled = false;
+    if (is_filled) return;
+    for (int i = 0; i < SIN_COS_N_COUNT; i++) {
+        double theta = (2*M_PI*i)/SIN_COS_N_COUNT;
+        sin_vals[i] = sinf(theta);
+        cos_vals[i] = cosf(theta);
+    }
+    is_filled = true;
+}
+
 // naive Discrete Fourier Transform
 // input is real-valued
 // output is complex-valued
@@ -2336,15 +2370,16 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
     int N = in.size();
 
     out.resize(N*2);
 
+    const int sin_cos_step = SIN_COS_N_COUNT / N;
     for (int k = 0; k < N; k++) {
         float re = 0;
         float im = 0;
 
         for (int n = 0; n < N; n++) {
-            float angle = 2*M_PI*k*n/N;
-            re += in[n]*cos(angle);
-            im -= in[n]*sin(angle);
+            int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
+            re += in[n]*cos_vals[idx]; // cos(t)
+            im -= in[n]*sin_vals[idx]; // sin(t)
         }
 
         out[k*2 + 0] = re;
@@ -2392,11 +2427,11 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
     fft(even, even_fft);
     fft(odd, odd_fft);
 
+    const int sin_cos_step = SIN_COS_N_COUNT / N;
     for (int k = 0; k < N/2; k++) {
-        float theta = 2*M_PI*k/N;
-
-        float re = cos(theta);
-        float im = -sin(theta);
+        int idx = k * sin_cos_step; // t = 2*M_PI*k/N
+        float re = cos_vals[idx]; // cos(t)
+        float im = -sin_vals[idx]; // sin(t)
 
         float re_odd = odd_fft[2*k + 0];
         float im_odd = odd_fft[2*k + 1];
@@ -2409,41 +2444,51 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
     }
 }
 
-static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> &hann, const float *samples,
-                                              int n_samples, int fft_size, int fft_step, int n_threads,
-                                              const whisper_filters &filters, bool speed_up, whisper_mel &mel) {
-    std::vector<float> fft_in(fft_size, 0.0);
-    std::vector<float> fft_out(2 * fft_size);
-    int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2);
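// A worked sketch of the lookup-table identity used by dft() and fft() above,
// assuming N evenly divides SIN_COS_N_COUNT (which holds for the sizes reached
// by repeatedly halving WHISPER_N_FFT):
//
//   // angle = 2*M_PI*k*n/N
//   //       = (2*M_PI/SIN_COS_N_COUNT) * (k*n * (SIN_COS_N_COUNT/N))
//   const int sin_cos_step = SIN_COS_N_COUNT / N;
//   int idx = (k * n * sin_cos_step) % SIN_COS_N_COUNT;
//   // cos_vals[idx] == cosf(2*M_PI*k*n/N) and sin_vals[idx] == sinf(2*M_PI*k*n/N),
//   // up to float rounding, so each inner-loop trig call becomes an array read.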
+static bool hann_window(int length, bool periodic, std::vector<float> & output) {
+    if (output.size() < length) {
+        output.resize(length);
+    }
+    int offset = -1;
+    if (periodic) {
+        offset = 0;
+    }
+    for (int i = 0; i < length; i++) {
+        output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
+    }
 
-    for (int i = ith; i < mel.n_len; i += n_threads) {
-        const int offset = i * fft_step;
+    return true;
+}
 
-        // apply Hanning window
-        for (int j = 0; j < fft_size; j++) {
-            if (offset + j < n_samples) {
-                fft_in[j] = hann[j] * samples[offset + j];
-            } else {
-                fft_in[j] = 0.0;
-            }
+static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples,
+                                              int n_samples, int frame_size, int frame_step, int n_threads,
+                                              const whisper_filters & filters, whisper_mel & mel) {
+    std::vector<float> fft_in(frame_size, 0.0);
+    std::vector<float> fft_out(2 * frame_step);
+    // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
+    int n_fft = 1 + (frame_size / 2);
+    int i = ith;
+
+    // calculate FFT only when fft_in are not all zero
+    for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
+        const int offset = i * frame_step;
+
+        // apply Hanning window (~10% faster)
+        for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
+            fft_in[j] = hann[j] * samples[offset + j];
+        }
+        // fill the rest with zeros
+        if (n_samples - offset < frame_size) {
+            std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
         }
 
-        // FFT -> mag^2
+        // FFT
         fft(fft_in, fft_out);
 
-        for (int j = 0; j < fft_size; j++) {
+        // Calculate modulus^2 of complex numbers
+        // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
+        for (int j = 0; j < frame_size; j++) {
             fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
         }
-        for (int j = 1; j < fft_size / 2; j++) {
-            fft_out[j] += fft_out[fft_size - j];
-        }
-
-        if (speed_up) {
-            // scale down in the frequency domain results in a speed up in the time domain
-            for (int j = 0; j < n_fft; j++) {
-                fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1]);
-            }
-        }
 
         // mel spectrogram
         for (int j = 0; j < mel.n_mel; j++) {
@@ -2453,10 +2498,10 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
             int k = 0;
             for (k = 0; k < n_fft - 3; k += 4) {
                 sum +=
-                        fft_out[k + 0] * filters.data[j*n_fft + k + 0] +
-                        fft_out[k + 1] * filters.data[j*n_fft + k + 1] +
-                        fft_out[k + 2] * filters.data[j*n_fft + k + 2] +
-                        fft_out[k + 3] * filters.data[j*n_fft + k + 3];
+                        fft_out[k + 0] * filters.data[j * n_fft + k + 0] +
+                        fft_out[k + 1] * filters.data[j * n_fft + k + 1] +
+                        fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
+                        fft_out[k + 3] * filters.data[j * n_fft + k + 3];
             }
 
             // handle n_fft remainder
@@ -2469,68 +2514,73 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
             mel.data[j * mel.n_len + i] = sum;
         }
     }
+
+    // Otherwise fft_out are all zero
+    double sum = log10(1e-10);
+    for (; i < mel.n_len; i += n_threads) {
+        for (int j = 0; j < mel.n_mel; j++) {
+            mel.data[j * mel.n_len + i] = sum;
+        }
+    }
 }
 
-// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
+// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
 static bool log_mel_spectrogram(
-          whisper_state & wstate,
-          const float * samples,
+              whisper_state & wstate,
+              const float * samples,
           const int   n_samples,
           const int   /*sample_rate*/,
-          const int   fft_size,
-          const int 
fft_step, + const int frame_size, + const int frame_step, const int n_mel, const int n_threads, - const whisper_filters & filters, - const bool speed_up, - whisper_mel & mel) { + const whisper_filters & filters, + const bool debug, + whisper_mel & mel) { const int64_t t_start_us = ggml_time_us(); - // Hanning window + // Hanning window (Use cosf to eliminate difference) + // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html + // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 std::vector hann; - hann.resize(fft_size); - for (int i = 0; i < fft_size; i++) { - hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size))); - } + hann_window(frame_size, true, hann); - mel.n_mel = n_mel; - mel.n_len = n_samples/fft_step; - mel.n_len_org = mel.n_len; - std::vector samples_padded; - - // pad audio with at least one extra chunk of zeros - { - const int pad = (100*WHISPER_CHUNK_SIZE)/2; + // Calculate the length of padding + int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30; + int64_t stage_2_pad = frame_size / 2; - if (mel.n_len % pad != 0) { - mel.n_len = (mel.n_len/pad + 1)*pad; - } - mel.n_len += pad; + // Initialize a vector and copy data from C array to it. + std::vector samples_padded; + samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2); + std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad); - samples_padded.resize(mel.n_len*fft_step); - memcpy(samples_padded.data(), samples, n_samples*sizeof(float)); - memset(samples_padded.data() + n_samples, 0, (mel.n_len*fft_step - n_samples)*sizeof(float)); + // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio + std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0); - samples = samples_padded.data(); - } + // reflective pad 200 samples at the beginning of audio + std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin()); - mel.data.resize(mel.n_mel*mel.n_len); + mel.n_mel = n_mel; + // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936 + // Calculate number of frames + remove the last frame + mel.n_len = (samples_padded.size() - frame_size) / frame_step; + // Calculate semi-padded sample length to ensure compatibility + mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step; + mel.data.resize(mel.n_mel * mel.n_len); - //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len); - //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate); { std::vector workers(n_threads - 1); for (int iw = 0; iw < n_threads - 1; ++iw) { workers[iw] = std::thread( - log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples, - n_samples, fft_size, fft_step, n_threads, - std::cref(filters), speed_up, std::ref(mel)); + log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded, + n_samples + stage_2_pad, frame_size, frame_step, n_threads, + std::cref(filters), std::ref(mel)); } // main thread - log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel); + log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel); for (int iw = 0; iw < n_threads - 1; ++iw) { workers[iw].join(); @@ -2544,7 +2594,6 @@ static bool log_mel_spectrogram( mmax = mel.data[i]; } } - //printf("%s: max = %f\n", __func__, 
mmax); mmax -= 8.0; @@ -2558,7 +2607,16 @@ static bool log_mel_spectrogram( wstate.t_mel_us += ggml_time_us() - t_start_us; - //printf("mel.n_len() = %d, divided by 1500: %f, n_samples / fft_step: %d\n", mel.n_len, mel.n_len / 1500.0, n_samples / fft_step); + // Dump log_mel_spectrogram + if (debug) { + std::ofstream outFile("log_mel_spectrogram.json"); + outFile << "["; + for (uint64_t i = 0; i < mel.data.size() - 1; i++) { + outFile << mel.data[i] << ", "; + } + outFile << mel.data[mel.data.size() - 1] << "]"; + outFile.close(); + } return true; } @@ -2614,7 +2672,7 @@ static std::vector tokenize(const whisper_vocab & vocab, cons --j; } if (!found) { - fprintf(stderr, "unknown token \n"); + log("unknown token\n"); ++i; } } @@ -2676,46 +2734,47 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) { #endif struct whisper_state * whisper_init_state(whisper_context * ctx) { + fill_sin_cos_table(); whisper_state * state = new whisper_state; const size_t scale = ctx->model.hparams.ftype ? 1 : 2; if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) { - fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); + log("%s: kv_cache_init() failed for self-attention cache\n", __func__); delete state; return nullptr; } { const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v); - fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); } if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) { - fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__); + log("%s: kv_cache_init() failed for cross-attention cache\n", __func__); delete state; return nullptr; } { const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v); - fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); } #ifdef WHISPER_USE_COREML const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model); - fprintf(stderr, "%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str()); - fprintf(stderr, "%s: first run on a device may take a while ...\n", __func__); + log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str()); + log("%s: first run on a device may take a while ...\n", __func__); state->ctx_coreml = whisper_coreml_init(path_coreml.c_str()); if (!state->ctx_coreml) { - fprintf(stderr, "%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str()); + log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str()); #ifndef WHISPER_COREML_ALLOW_FALLBACK return nullptr; #endif } else { - fprintf(stderr, "%s: Core ML model loaded\n", __func__); + log("%s: Core ML model loaded\n", __func__); } #endif @@ -2755,7 +2814,7 @@ int whisper_ctx_init_openvino_encoder( return 1; #else if (!model_path && ctx->path_model.empty()) { - fprintf(stderr, "%s: model_path is nullptr, and ctx has no model_path set.\n", __func__); + log("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__); return 1; } @@ -2775,15 +2834,15 @@ int whisper_ctx_init_openvino_encoder( path_cache = 
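// For reference, a minimal sketch of enabling the OpenVINO encoder from client
// code via this function (device name and model file are illustrative):
//
//   struct whisper_context * ctx = whisper_init_from_file("ggml-base.en.bin");
//   // model_path == nullptr derives the encoder path from ctx->path_model, and
//   // cache_dir == nullptr lets whisper_openvino_get_path_cache() pick a default
//   if (whisper_ctx_init_openvino_encoder(ctx, nullptr, "CPU", nullptr) != 0) {
//       // non-zero: OpenVINO encoder not available; the ggml encoder is used instead
//   }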
cache_dir; } - fprintf(stderr, "%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str()); - fprintf(stderr, "%s: first run on a device may take a while ...\n", __func__); + log("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str()); + log("%s: first run on a device may take a while ...\n", __func__); ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str()); if (!ctx->state->ctx_openvino) { - fprintf(stderr, "%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str()); + log("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str()); return 1; } else { - fprintf(stderr, "%s: OpenVINO model loaded\n", __func__); + log("%s: OpenVINO model loaded\n", __func__); } return 0; @@ -2792,11 +2851,11 @@ int whisper_ctx_init_openvino_encoder( struct whisper_context * whisper_init_from_file_no_state(const char * path_model) { - fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model); + log("%s: loading model from '%s'\n", __func__, path_model); auto fin = std::ifstream(path_model, std::ios::binary); if (!fin) { - fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_model); + log("%s: failed to open '%s'\n", __func__, path_model); return nullptr; } @@ -2838,7 +2897,7 @@ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buf_context ctx = { reinterpret_cast(buffer), buffer_size, 0 }; - fprintf(stderr, "%s: loading model from buffer\n", __func__); + log("%s: loading model from buffer\n", __func__); whisper_model_loader loader = {}; @@ -2873,7 +2932,7 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa if (!whisper_model_load(loader, *ctx)) { loader->close(loader->context); - fprintf(stderr, "%s: failed to load model\n", __func__); + log("%s: failed to load model\n", __func__); delete ctx; return nullptr; } @@ -2978,7 +3037,7 @@ void whisper_free_params(struct whisper_full_params * params) { int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) { - fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); + log("%s: failed to compute mel spectrogram\n", __func__); return -1; } @@ -2989,21 +3048,30 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads); } -// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 +// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good) int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { - if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, state->mel)) { - fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); + if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) { + log("%s: failed to compute mel spectrogram\n", __func__); 
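// A minimal usage sketch of the public PCM -> log-mel path wrapped here
// (error handling shortened; the 16 kHz mono loader is hypothetical):
//
//   std::vector<float> pcm = load_wav_16khz("jfk.wav");  // hypothetical helper
//   if (whisper_pcm_to_mel(ctx, pcm.data(), (int) pcm.size(), /*n_threads=*/4) != 0) {
//       // -1: log-mel computation failed
//   }
//   // the spectrogram lands in ctx->state->mel and feeds whisper_encode()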
return -1; } return 0; } -// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 +// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good) int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads); } +// same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2 +// TODO + +// same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2 +// TODO + +// same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2 +// TODO + int whisper_set_mel_with_state( struct whisper_context * /*ctx*/, struct whisper_state * state, @@ -3011,7 +3079,7 @@ int whisper_set_mel_with_state( int n_len, int n_mel) { if (n_mel != WHISPER_N_MEL) { - fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL); + log("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL); return -1; } @@ -3035,7 +3103,7 @@ int whisper_set_mel( int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) { if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) { - fprintf(stderr, "%s: failed to eval\n", __func__); + log("%s: failed to eval\n", __func__); return -1; } @@ -3044,7 +3112,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) { - fprintf(stderr, "%s: failed to eval\n", __func__); + log("%s: failed to eval\n", __func__); return -1; } @@ -3055,7 +3123,7 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state const int selected_decoder_id = 0; if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { - fprintf(stderr, "%s: failed to eval\n", __func__); + log("%s: failed to eval\n", __func__); return 1; } @@ -3067,13 +3135,12 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i const int selected_decoder_id = 0; if (ctx->state == nullptr) { - fprintf(stderr, "%s: ERROR state was not loaded.\n", __func__); + log("%s: ERROR state was not loaded.\n", __func__); return false; } - if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { - fprintf(stderr, "%s: failed to eval\n", __func__); + log("%s: failed to eval\n", __func__); return 1; } @@ -3084,7 +3151,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to const auto res = tokenize(ctx->vocab, text); if (n_max_tokens < (int) res.size()) { - fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens); + log("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens); return -1; } @@ -3112,7 +3179,7 @@ int whisper_lang_id(const char * lang) { } } - fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang); + log("%s: unknown language '%s'\n", __func__, lang); return -1; } return g_lang.at(lang).first; @@ -3125,7 +3192,7 @@ const char * whisper_lang_str(int id) { } } - fprintf(stderr, "%s: unknown language id %d\n", __func__, id); + log("%s: unknown language id %d\n", 
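// Round-trip sketch for the language helpers above (ids index the internal
// g_lang table):
//
//   const int lang_id = whisper_lang_id("de");       // -1 if the code is unknown
//   const char * code = whisper_lang_str(lang_id);   // "de" again, or nullptr for a bad id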
__func__, id); return nullptr; } @@ -3138,25 +3205,25 @@ int whisper_lang_auto_detect_with_state( const int seek = offset_ms/10; if (seek < 0) { - fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms); + log("%s: offset %dms is before the start of the audio\n", __func__, offset_ms); return -1; } if (seek >= state->mel.n_len_org) { - fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10); + log("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10); return -2; } // run the encoder if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) { - fprintf(stderr, "%s: failed to encode\n", __func__); + log("%s: failed to encode\n", __func__); return -6; } const std::vector prompt = { whisper_token_sot(ctx) }; if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) { - fprintf(stderr, "%s: failed to decode\n", __func__); + log("%s: failed to decode\n", __func__); return -7; } @@ -3305,7 +3372,6 @@ float * whisper_get_logits(struct whisper_context * ctx) { return ctx->state->logits.data(); } - float * whisper_get_logits_from_state(struct whisper_state * state) { return state->logits.data(); } @@ -3357,21 +3423,21 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) { void whisper_print_timings(struct whisper_context * ctx) { const int64_t t_end_us = ggml_time_us(); - fprintf(stderr, "\n"); - fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); + log("\n"); + log("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); if (ctx->state != nullptr) { const int32_t n_sample = std::max(1, ctx->state->n_sample); const int32_t n_encode = std::max(1, ctx->state->n_encode); const int32_t n_decode = std::max(1, ctx->state->n_decode); - fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); - fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); - fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); - fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); - fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); + log("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); + log("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); + log("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); + log("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); + log("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); } - fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); + log("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); } void whisper_reset_timings(struct whisper_context * ctx) { @@ -3413,6 
+3479,7 @@ const char * whisper_print_system_info(void) { s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; + s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | "; s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; s += "COREML = " + std::to_string(whisper_has_coreml()) + " | "; s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | "; @@ -3455,6 +3522,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.max_tokens =*/ 0, /*.speed_up =*/ false, + /*.debug_mode =*/ false, /*.audio_ctx =*/ 0, /*.tdrz_enable =*/ false, @@ -3616,7 +3684,7 @@ static void whisper_process_logits( WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab); // extract the logits for the last token - // we will be mutating and therefore we don't want to use the ctx.logits buffer directly + // we will be mutating, and therefore we don't want to use the ctx.logits buffer directly auto & probs = decoder.probs; auto & logits = decoder.logits; auto & logprobs = decoder.logprobs; @@ -3695,7 +3763,7 @@ static void whisper_process_logits( const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg; const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg; - //fprintf(stderr, "last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp); + //log("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp); if (last_was_timestamp) { if (penultimate_was_timestamp) { @@ -3771,7 +3839,7 @@ static void whisper_process_logits( const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg); - //fprintf(stderr, "timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob); + //log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob); if (timestamp_logprob > max_text_token_logprob) { for (int i = 0; i < vocab.token_beg; ++i) { @@ -4017,16 +4085,17 @@ int whisper_full_with_state( result_all.clear(); - // compute log mel spectrogram - if (params.speed_up) { - if (whisper_pcm_to_mel_phase_vocoder_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { - fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); + if (n_samples > 0) { + // compute log mel spectrogram + if (params.speed_up) { + // TODO: Replace PV with more advanced algorithm + log("%s: failed to compute log mel spectrogram\n", __func__); return -1; - } - } else { - if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { - fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); - return -2; + } else { + if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { + log("%s: failed to compute log mel spectrogram\n", __func__); + return -2; + } } } @@ -4036,13 +4105,13 @@ int whisper_full_with_state( const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data()); if (lang_id < 0) { - fprintf(stderr, "%s: failed to auto-detect language\n", __func__); + log("%s: failed to auto-detect language\n", __func__); return -3; } state->lang_id = lang_id; params.language = whisper_lang_str(lang_id); - fprintf(stderr, "%s: auto-detected 
@@ -4052,14 +4121,16 @@ int whisper_full_with_state(
         state->t_beg    = 0;
         state->t_last   = 0;
         state->tid_last = 0;
-        state->energy = get_signal_energy(samples, n_samples, 32);
+        if (n_samples > 0) {
+            state->energy = get_signal_energy(samples, n_samples, 32);
+        }
     }

     const int seek_start = params.offset_ms/10;
     const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10;

-    // if length of spectrogram is less than 1s (100 samples), then return
-    // basically don't process anything that is less than 1s
+    // if length of spectrogram is less than 1.0s (100 frames), then return
+    // basically don't process anything that is less than 1.0s
     // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
     if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
         return 0;
     }
@@ -4099,7 +4170,7 @@ int whisper_full_with_state(
             if (decoder.kv_self.ctx == nullptr) {
                 decoder.kv_self = state->decoders[0].kv_self;
                 if (!kv_cache_reinit(decoder.kv_self)) {
-                    fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
+                    log("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
                     return -4;
                 }

@@ -4143,7 +4214,7 @@ int whisper_full_with_state(
     // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
     if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
-        fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
+        log("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
         return -5;
     }
     state->exp_n_audio_ctx = params.audio_ctx;
@@ -4161,9 +4232,6 @@ int whisper_full_with_state(
         }
     }

-    int progress_prev = 0;
-    int progress_step = 5;
-
     int seek = seek_start;

     std::vector<whisper_token> prompt;
@@ -4190,16 +4258,11 @@ int whisper_full_with_state(
     // main loop
     while (true) {
-        const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
-        while (progress_cur >= progress_prev + progress_step) {
-            progress_prev += progress_step;
-            if (params.print_progress) {
-                fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev);
-            }
-        }
         if (params.progress_callback) {
+            const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
+
             params.progress_callback(
-                ctx, ctx->state, progress_prev, params.progress_callback_user_data);
+                ctx, ctx->state, progress_cur, params.progress_callback_user_data);
         }

         // if only 1 second left, then stop
@@ -4209,14 +4272,14 @@ int whisper_full_with_state(
         if (params.encoder_begin_callback) {
             if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
-                fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
+                log("%s: encoder_begin_callback returned false - aborting\n", __func__);
                 break;
             }
         }

         // encode audio features starting at offset seek
         if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
-            fprintf(stderr, "%s: failed to encode\n", __func__);
+            log("%s: failed to encode\n", __func__);
             return -6;
         }

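With the internal progress prints removed in the hunk above, progress reporting is now entirely the caller's job via the existing progress_callback field of whisper_full_params. A short sketch, with an illustrative output format:

    #include <cstdio>
    #include "whisper.h"

    // matches the whisper_progress_callback signature from whisper.h
    static void on_progress(struct whisper_context * /*ctx*/,
                            struct whisper_state  * /*state*/,
                            int progress,
                            void * /*user_data*/) {
        fprintf(stderr, "progress = %3d%%\n", progress);
    }

    // when setting up a run:
    //     whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    //     wparams.progress_callback           = on_progress;
    //     wparams.progress_callback_user_data = nullptr;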
@@ -4299,7 +4362,7 @@ int whisper_full_with_state(
             WHISPER_PRINT_DEBUG("\n\n");

             if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
-                fprintf(stderr, "%s: failed to decode\n", __func__);
+                log("%s: failed to decode\n", __func__);
                 return -7;
             }

@@ -4537,7 +4600,7 @@ int whisper_full_with_state(
                     //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);

                     if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
-                        fprintf(stderr, "%s: failed to decode\n", __func__);
+                        log("%s: failed to decode\n", __func__);
                         return -8;
                     }

@@ -4752,7 +4815,6 @@ int whisper_full_with_state(
     return 0;
 }

-
 int whisper_full(
         struct whisper_context * ctx,
         struct whisper_full_params params,
@@ -4829,7 +4891,6 @@ int whisper_full_parallel(
         result.t0 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
         result.t1 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;

-
         // make sure that segments are not overlapping
         if (!ctx->state->result_all.empty()) {
             result.t0 = std::max(result.t0, ctx->state->result_all.back().t1);
@@ -4859,12 +4920,12 @@ int whisper_full_parallel(
     ctx->state->t_decode_us /= n_processors;

     // print information about the audio boundaries
-    fprintf(stderr, "\n");
-    fprintf(stderr, "%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
+    log("\n");
+    log("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
     for (int i = 0; i < n_processors - 1; ++i) {
-        fprintf(stderr, "%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
+        log("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
     }
-    fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__);
+    log("%s: the transcription quality may be degraded near these boundaries\n", __func__);

     return ret;
 }
@@ -5222,7 +5283,7 @@ static void whisper_exp_compute_token_level_timestamps(
     const int n_samples = state.energy.size();

     if (n_samples == 0) {
-        fprintf(stderr, "%s: no signal data available\n", __func__);
+        log("%s: no signal data available\n", __func__);
         return;
     }

@@ -5442,3 +5503,7 @@ static void whisper_exp_compute_token_level_timestamps(
     //    }
     //}
 }
+
+void whisper_set_log_callback(whisper_log_callback callback) {
+    whisper_log = callback;
+}
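The boundary messages converted to log() above come from whisper_full_parallel(), which splits the input across several worker states. A usage sketch, assuming ctx is an initialized context and pcmf32 already holds the full 16 kHz mono signal:

    // four processors => three split points; as the log lines above warn,
    // transcription quality may be degraded near the split boundaries
    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    if (whisper_full_parallel(ctx, wparams, pcmf32.data(), (int) pcmf32.size(), 4) != 0) {
        fprintf(stderr, "whisper_full_parallel() failed\n");
    }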
diff --git a/examples/whisper/whisper.h b/examples/whisper/whisper.h
index 83af11bd..73ab4d79 100644
--- a/examples/whisper/whisper.h
+++ b/examples/whisper/whisper.h
@@ -67,6 +67,7 @@ extern "C" {

     struct whisper_context;
     struct whisper_state;
+    struct whisper_full_params;

     typedef int whisper_token;

@@ -345,7 +346,7 @@ extern "C" {
         void * user_data);

     // Parameters for the whisper_full() function
-    // If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp:
+    // If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
     // whisper_full_default_params()
     struct whisper_full_params {
         enum whisper_sampling_strategy strategy;
@@ -374,6 +375,7 @@ extern "C" {
         // [EXPERIMENTAL] speed-up techniques
         // note: these can significantly reduce the quality of the output
         bool speed_up;          // speed-up the audio by 2x using Phase Vocoder
+        bool debug_mode;        // enable debug mode (e.g. dump the log mel spectrogram)

         int  audio_ctx;         // overwrite the audio context size (0 = use default)

         // [EXPERIMENTAL] [TDRZ] tinydiarize
@@ -517,6 +519,11 @@ extern "C" {
     WHISPER_API int          whisper_bench_ggml_mul_mat    (int n_threads);
     WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads);

+    // Control logging output; default behavior is to print to stderr
+
+    typedef void (*whisper_log_callback)(const char * line);
+    WHISPER_API void whisper_set_log_callback(whisper_log_callback callback);
+
 #ifdef __cplusplus
 }
 #endif
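The logging API declared above pairs with the whisper_set_log_callback() definition at the end of the whisper.cpp diff. A small sketch of routing all library diagnostics into a file; the file name is a placeholder, and note that the library's format strings already carry their own trailing newlines:

    #include <cstdio>
    #include "whisper.h"

    static FILE * g_logf = nullptr;

    static void log_to_file(const char * line) {
        if (g_logf != nullptr) {
            fputs(line, g_logf);
            fflush(g_logf);
        }
    }

    int main() {
        g_logf = fopen("whisper.log", "w");
        whisper_set_log_callback(log_to_file);

        // ... whisper_init_from_file(), whisper_full(), etc. now log to whisper.log ...

        return 0;
    }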