]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
whisper : sync (match OpenAI input, convert, new features) (#495)
authorGeorgi Gerganov <redacted>
Tue, 5 Sep 2023 10:55:06 +0000 (13:55 +0300)
committerGitHub <redacted>
Tue, 5 Sep 2023 10:55:06 +0000 (13:55 +0300)
ggml-ci

examples/whisper/convert-pt-to-ggml.py
examples/whisper/main.cpp
examples/whisper/quantize.cpp
examples/whisper/whisper.cpp
examples/whisper/whisper.h

index 07752e7556edfa124431f014a13913733862894f..9aa134b53f7d05c1f9d2be60759f28b14e87fdc6 100644 (file)
@@ -39,132 +39,133 @@ import json
 import code
 import torch
 import numpy as np
-
-from transformers import GPTJForCausalLM
-from transformers import GPT2TokenizerFast
+import base64
+from pathlib import Path
+#from transformers import GPTJForCausalLM
+#from transformers import GPT2TokenizerFast
 
 # ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L10-L110
-LANGUAGES = {
-    "en": "english",
-    "zh": "chinese",
-    "de": "german",
-    "es": "spanish",
-    "ru": "russian",
-    "ko": "korean",
-    "fr": "french",
-    "ja": "japanese",
-    "pt": "portuguese",
-    "tr": "turkish",
-    "pl": "polish",
-    "ca": "catalan",
-    "nl": "dutch",
-    "ar": "arabic",
-    "sv": "swedish",
-    "it": "italian",
-    "id": "indonesian",
-    "hi": "hindi",
-    "fi": "finnish",
-    "vi": "vietnamese",
-    "iw": "hebrew",
-    "uk": "ukrainian",
-    "el": "greek",
-    "ms": "malay",
-    "cs": "czech",
-    "ro": "romanian",
-    "da": "danish",
-    "hu": "hungarian",
-    "ta": "tamil",
-    "no": "norwegian",
-    "th": "thai",
-    "ur": "urdu",
-    "hr": "croatian",
-    "bg": "bulgarian",
-    "lt": "lithuanian",
-    "la": "latin",
-    "mi": "maori",
-    "ml": "malayalam",
-    "cy": "welsh",
-    "sk": "slovak",
-    "te": "telugu",
-    "fa": "persian",
-    "lv": "latvian",
-    "bn": "bengali",
-    "sr": "serbian",
-    "az": "azerbaijani",
-    "sl": "slovenian",
-    "kn": "kannada",
-    "et": "estonian",
-    "mk": "macedonian",
-    "br": "breton",
-    "eu": "basque",
-    "is": "icelandic",
-    "hy": "armenian",
-    "ne": "nepali",
-    "mn": "mongolian",
-    "bs": "bosnian",
-    "kk": "kazakh",
-    "sq": "albanian",
-    "sw": "swahili",
-    "gl": "galician",
-    "mr": "marathi",
-    "pa": "punjabi",
-    "si": "sinhala",
-    "km": "khmer",
-    "sn": "shona",
-    "yo": "yoruba",
-    "so": "somali",
-    "af": "afrikaans",
-    "oc": "occitan",
-    "ka": "georgian",
-    "be": "belarusian",
-    "tg": "tajik",
-    "sd": "sindhi",
-    "gu": "gujarati",
-    "am": "amharic",
-    "yi": "yiddish",
-    "lo": "lao",
-    "uz": "uzbek",
-    "fo": "faroese",
-    "ht": "haitian creole",
-    "ps": "pashto",
-    "tk": "turkmen",
-    "nn": "nynorsk",
-    "mt": "maltese",
-    "sa": "sanskrit",
-    "lb": "luxembourgish",
-    "my": "myanmar",
-    "bo": "tibetan",
-    "tl": "tagalog",
-    "mg": "malagasy",
-    "as": "assamese",
-    "tt": "tatar",
-    "haw": "hawaiian",
-    "ln": "lingala",
-    "ha": "hausa",
-    "ba": "bashkir",
-    "jw": "javanese",
-    "su": "sundanese",
-}
-
-# ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L273-L292
-def build_tokenizer(path_to_whisper_repo: str, name: str = "gpt2"):
-    os.environ["TOKENIZERS_PARALLELISM"] = "false"
-    path = os.path.join(path_to_whisper_repo, "whisper/assets", name)
-    tokenizer = GPT2TokenizerFast.from_pretrained(path)
-
-    specials = [
-        "<|startoftranscript|>",
-        *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
-        "<|translate|>",
-        "<|transcribe|>",
-        "<|startoflm|>",
-        "<|startofprev|>",
-        "<|nocaptions|>",
-        "<|notimestamps|>",
-    ]
-
-    tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
-    return tokenizer
+#LANGUAGES = {
+#    "en": "english",
+#    "zh": "chinese",
+#    "de": "german",
+#    "es": "spanish",
+#    "ru": "russian",
+#    "ko": "korean",
+#    "fr": "french",
+#    "ja": "japanese",
+#    "pt": "portuguese",
+#    "tr": "turkish",
+#    "pl": "polish",
+#    "ca": "catalan",
+#    "nl": "dutch",
+#    "ar": "arabic",
+#    "sv": "swedish",
+#    "it": "italian",
+#    "id": "indonesian",
+#    "hi": "hindi",
+#    "fi": "finnish",
+#    "vi": "vietnamese",
+#    "iw": "hebrew",
+#    "uk": "ukrainian",
+#    "el": "greek",
+#    "ms": "malay",
+#    "cs": "czech",
+#    "ro": "romanian",
+#    "da": "danish",
+#    "hu": "hungarian",
+#    "ta": "tamil",
+#    "no": "norwegian",
+#    "th": "thai",
+#    "ur": "urdu",
+#    "hr": "croatian",
+#    "bg": "bulgarian",
+#    "lt": "lithuanian",
+#    "la": "latin",
+#    "mi": "maori",
+#    "ml": "malayalam",
+#    "cy": "welsh",
+#    "sk": "slovak",
+#    "te": "telugu",
+#    "fa": "persian",
+#    "lv": "latvian",
+#    "bn": "bengali",
+#    "sr": "serbian",
+#    "az": "azerbaijani",
+#    "sl": "slovenian",
+#    "kn": "kannada",
+#    "et": "estonian",
+#    "mk": "macedonian",
+#    "br": "breton",
+#    "eu": "basque",
+#    "is": "icelandic",
+#    "hy": "armenian",
+#    "ne": "nepali",
+#    "mn": "mongolian",
+#    "bs": "bosnian",
+#    "kk": "kazakh",
+#    "sq": "albanian",
+#    "sw": "swahili",
+#    "gl": "galician",
+#    "mr": "marathi",
+#    "pa": "punjabi",
+#    "si": "sinhala",
+#    "km": "khmer",
+#    "sn": "shona",
+#    "yo": "yoruba",
+#    "so": "somali",
+#    "af": "afrikaans",
+#    "oc": "occitan",
+#    "ka": "georgian",
+#    "be": "belarusian",
+#    "tg": "tajik",
+#    "sd": "sindhi",
+#    "gu": "gujarati",
+#    "am": "amharic",
+#    "yi": "yiddish",
+#    "lo": "lao",
+#    "uz": "uzbek",
+#    "fo": "faroese",
+#    "ht": "haitian creole",
+#    "ps": "pashto",
+#    "tk": "turkmen",
+#    "nn": "nynorsk",
+#    "mt": "maltese",
+#    "sa": "sanskrit",
+#    "lb": "luxembourgish",
+#    "my": "myanmar",
+#    "bo": "tibetan",
+#    "tl": "tagalog",
+#    "mg": "malagasy",
+#    "as": "assamese",
+#    "tt": "tatar",
+#    "haw": "hawaiian",
+#    "ln": "lingala",
+#    "ha": "hausa",
+#    "ba": "bashkir",
+#    "jw": "javanese",
+#    "su": "sundanese",
+#}
+
+## ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L273-L292
+#def build_tokenizer(path_to_whisper_repo: str, name: str = "gpt2"):
+#    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+#    path = os.path.join(path_to_whisper_repo, "whisper/assets", name)
+#    tokenizer = GPT2TokenizerFast.from_pretrained(path)
+#
+#    specials = [
+#        "<|startoftranscript|>",
+#        *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
+#        "<|translate|>",
+#        "<|transcribe|>",
+#        "<|startoflm|>",
+#        "<|startofprev|>",
+#        "<|nocaptions|>",
+#        "<|notimestamps|>",
+#    ]
+#
+#    tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
+#    return tokenizer
 
 # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
 def bytes_to_unicode():
@@ -193,17 +194,17 @@ if len(sys.argv) < 4:
     print("Usage: convert-pt-to-ggml.py model.pt path-to-whisper-repo dir-output [use-f32]\n")
     sys.exit(1)
 
-fname_inp   = sys.argv[1]
-dir_whisper = sys.argv[2]
-dir_out     = sys.argv[3]
+fname_inp   = Path(sys.argv[1])
+dir_whisper = Path(sys.argv[2])
+dir_out     = Path(sys.argv[3])
 
 # try to load PyTorch binary data
 try:
     model_bytes = open(fname_inp, "rb").read()
     with io.BytesIO(model_bytes) as fp:
         checkpoint = torch.load(fp, map_location="cpu")
-except:
-    print("Error: failed to load PyTorch model file: %s" % fname_inp)
+except Exception:
+    print("Error: failed to load PyTorch model file:" , fname_inp)
     sys.exit(1)
 
 hparams = checkpoint["dims"]
@@ -217,33 +218,52 @@ list_vars = checkpoint["model_state_dict"]
 
 # load mel filters
 n_mels = hparams["n_mels"]
-with np.load(os.path.join(dir_whisper, "whisper/assets", "mel_filters.npz")) as f:
+with np.load(dir_whisper / "whisper" / "assets" / "mel_filters.npz") as f:
     filters = torch.from_numpy(f[f"mel_{n_mels}"])
     #print (filters)
 
 #code.interact(local=locals())
 
+# load tokenizer
+# for backwards compatibility, also check for older hf_transformers format tokenizer files
+# old format: dir_whisper/whisper/assets/[multilingual/gpt2]/vocab.json
+# new format: dir_whisper/whisper/assets/[multilingual/gpt2].tiktoken
 multilingual = hparams["n_vocab"] == 51865
-tokenizer = build_tokenizer(dir_whisper, multilingual and "multilingual" or "gpt2")
+tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken")
+tokenizer_type = "tiktoken"
+if not tokenizer.is_file():
+    tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual" or "gpt2") / "vocab.json"
+    tokenizer_type = "hf_transformers"
+    if not tokenizer.is_file():
+        print("Error: failed to find either tiktoken or hf_transformers tokenizer file:", tokenizer)
+        sys.exit(1)
 
-#print(tokenizer)
-#print(tokenizer.name_or_path)
-#print(len(tokenizer.additional_special_tokens))
-dir_tokenizer = tokenizer.name_or_path
+byte_encoder = bytes_to_unicode()
+byte_decoder = {v:k for k, v in byte_encoder.items()}
 
-# output in the same directory as the model
-fname_out = dir_out + "/ggml-model.bin"
+if tokenizer_type == "tiktoken":
+    with open(tokenizer, "rb") as f:
+        contents = f.read()
+        tokens = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)}
+elif tokenizer_type == "hf_transformers":
+    with open(tokenizer, "r", encoding="utf8") as f:
+        _tokens_raw = json.load(f)
+        if '<|endoftext|>' in _tokens_raw:
+            # ensures exact same model as tokenizer_type == tiktoken
+            # details: https://github.com/ggerganov/whisper.cpp/pull/725
+            del _tokens_raw['<|endoftext|>']
+        tokens = {bytes([byte_decoder[c] for c in token]): int(idx) for token, idx in _tokens_raw.items()}
 
-with open(dir_tokenizer + "/vocab.json", "r") as f:
-    tokens = json.load(f)
+# output in the same directory as the model
+fname_out = dir_out / "ggml-model.bin"
 
 # use 16-bit or 32-bit floats
 use_f16 = True
 if len(sys.argv) > 4:
     use_f16 = False
-    fname_out = dir_out + "/ggml-model-f32.bin"
+    fname_out = dir_out / "ggml-model-f32.bin"
 
-fout = open(fname_out, "wb")
+fout = fname_out.open("wb")
 
 fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
 fout.write(struct.pack("i", hparams["n_vocab"]))
@@ -265,32 +285,28 @@ for i in range(filters.shape[0]):
     for j in range(filters.shape[1]):
         fout.write(struct.pack("f", filters[i][j]))
 
-byte_encoder = bytes_to_unicode()
-byte_decoder = {v:k for k, v in byte_encoder.items()}
-
+# write tokenizer
 fout.write(struct.pack("i", len(tokens)))
 
 for key in tokens:
-    text = bytearray([byte_decoder[c] for c in key])
-    fout.write(struct.pack("i", len(text)))
-    fout.write(text)
+    fout.write(struct.pack("i", len(key)))
+    fout.write(key)
 
 for name in list_vars.keys():
     data = list_vars[name].squeeze().numpy()
-    print("Processing variable: " + name + " with shape: ", data.shape)
+    print("Processing variable: " , name ,  " with shape: ", data.shape)
 
     # reshape conv bias from [n] to [n, 1]
-    if name == "encoder.conv1.bias" or \
-       name == "encoder.conv2.bias":
+    if name in ["encoder.conv1.bias", "encoder.conv2.bias"]:
         data = data.reshape(data.shape[0], 1)
-        print("  Reshaped variable: " + name + " to shape: ", data.shape)
+        print(f"  Reshaped variable: {name} to shape: ", data.shape)
 
-    n_dims = len(data.shape);
+    n_dims = len(data.shape)
 
     # looks like the whisper models are in f16 by default
     # so we need to convert the small tensors to f32 until we fully support f16 in ggml
     # ftype == 0 -> float32, ftype == 1 -> float16
-    ftype = 1;
+    ftype = 1
     if use_f16:
         if n_dims < 2 or \
                 name == "encoder.conv1.bias"   or \
@@ -301,9 +317,8 @@ for name in list_vars.keys():
             data = data.astype(np.float32)
             ftype = 0
     else:
-        if n_dims < 3 and data.dtype != np.float32:
-            data = data.astype(np.float32)
-            ftype = 0
+        data = data.astype(np.float32)
+        ftype = 0
 
     #if name.startswith("encoder"):
     #    if name.endswith("mlp.0.weight") or \
@@ -312,16 +327,16 @@ for name in list_vars.keys():
     #        data = data.transpose()
 
     # header
-    str = name.encode('utf-8')
-    fout.write(struct.pack("iii", n_dims, len(str), ftype))
+    str_ = name.encode('utf-8')
+    fout.write(struct.pack("iii", n_dims, len(str_), ftype))
     for i in range(n_dims):
         fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
-    fout.write(str);
+    fout.write(str_)
 
     # data
     data.tofile(fout)
 
 fout.close()
 
-print("Done. Output file: " + fname_out)
+print("Done. Output file: " , fname_out)
 print("")
index 8dd31d028b1e732bf6572230dab0fd8cc835d721..fa399c6d78114a9d886e626b3c3bf4b48a68451a 100644 (file)
@@ -59,6 +59,7 @@ struct whisper_params {
     int32_t offset_t_ms  =  0;
     int32_t offset_n     =  0;
     int32_t duration_ms  =  0;
+    int32_t progress_step =  5;
     int32_t max_context  = -1;
     int32_t max_len      =  0;
     int32_t best_of      =  2;
@@ -69,6 +70,7 @@ struct whisper_params {
     float logprob_thold = -1.00f;
 
     bool speed_up        = false;
+    bool debug_mode      = false;
     bool translate       = false;
     bool detect_language = false;
     bool diarize         = false;
@@ -86,6 +88,7 @@ struct whisper_params {
     bool print_colors    = false;
     bool print_progress  = false;
     bool no_timestamps   = false;
+    bool log_score       = false;
 
     std::string language  = "en";
     std::string prompt;
@@ -133,7 +136,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-wt"   || arg == "--word-thold")      { params.word_thold      = std::stof(argv[++i]); }
         else if (arg == "-et"   || arg == "--entropy-thold")   { params.entropy_thold   = std::stof(argv[++i]); }
         else if (arg == "-lpt"  || arg == "--logprob-thold")   { params.logprob_thold   = std::stof(argv[++i]); }
-        else if (arg == "-su"   || arg == "--speed-up")        { params.speed_up        = true; }
+        // else if (arg == "-su"   || arg == "--speed-up")        { params.speed_up        = true; }
+        else if (arg == "-debug"|| arg == "--debug-mode")      { params.debug_mode      = true; }
         else if (arg == "-tr"   || arg == "--translate")       { params.translate       = true; }
         else if (arg == "-di"   || arg == "--diarize")         { params.diarize         = true; }
         else if (arg == "-tdrz" || arg == "--tinydiarize")     { params.tinydiarize     = true; }
@@ -158,6 +162,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-m"    || arg == "--model")           { params.model           = argv[++i]; }
         else if (arg == "-f"    || arg == "--file")            { params.fname_inp.emplace_back(argv[++i]); }
         else if (arg == "-oved" || arg == "--ov-e-device")     { params.openvino_encode_device = argv[++i]; }
+        else if (arg == "-ls"   || arg == "--log-score")       { params.log_score = true; }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             whisper_print_usage(argc, argv, params);
@@ -187,7 +192,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -wt N,     --word-thold N      [%-7.2f] word timestamp probability threshold\n",         params.word_thold);
     fprintf(stderr, "  -et N,     --entropy-thold N   [%-7.2f] entropy threshold for decoder fail\n",           params.entropy_thold);
     fprintf(stderr, "  -lpt N,    --logprob-thold N   [%-7.2f] log probability threshold for decoder fail\n",   params.logprob_thold);
-    fprintf(stderr, "  -su,       --speed-up          [%-7s] speed up audio by x2 (reduced accuracy)\n",        params.speed_up ? "true" : "false");
+    // fprintf(stderr, "  -su,       --speed-up          [%-7s] speed up audio by x2 (reduced accuracy)\n",        params.speed_up ? "true" : "false");
+    fprintf(stderr, "  -debug,    --debug-mode        [%-7s] enable debug mode (eg. dump log_mel)\n",           params.debug_mode ? "true" : "false");
     fprintf(stderr, "  -tr,       --translate         [%-7s] translate from source language to english\n",      params.translate ? "true" : "false");
     fprintf(stderr, "  -di,       --diarize           [%-7s] stereo audio diarization\n",                       params.diarize ? "true" : "false");
     fprintf(stderr, "  -tdrz,     --tinydiarize       [%-7s] enable tinydiarize (requires a tdrz model)\n",     params.tinydiarize ? "true" : "false");
@@ -211,6 +217,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -m FNAME,  --model FNAME       [%-7s] model path\n",                                     params.model.c_str());
     fprintf(stderr, "  -f FNAME,  --file FNAME        [%-7s] input WAV file path\n",                            "");
     fprintf(stderr, "  -oved D,   --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n",  params.openvino_encode_device.c_str());
+    fprintf(stderr, "  -ls,       --log-score         [%-7s] log best decoder scores of tokens\n",              params.log_score?"true":"false");
     fprintf(stderr, "\n");
 }
 
@@ -218,6 +225,7 @@ struct whisper_print_user_data {
     const whisper_params * params;
 
     const std::vector<std::vector<float>> * pcmf32s;
+    int progress_prev;
 };
 
 std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
@@ -252,6 +260,14 @@ std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s
 
     return speaker;
 }
+void whisper_print_progress_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int progress, void * user_data) {
+    int progress_step = ((whisper_print_user_data *) user_data)->params->progress_step;
+    int * progress_prev  = &(((whisper_print_user_data *) user_data)->progress_prev);
+    if (progress >= *progress_prev + progress_step) {
+        *progress_prev += progress_step;
+        fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress);
+    }
+}
 
 void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
     const auto & params  = *((whisper_print_user_data *) user_data)->params;
@@ -476,6 +492,25 @@ bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_
     return true;
 }
 
+bool output_score(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
+    std::ofstream fout(fname);
+    fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
+
+    const int n_segments = whisper_full_n_segments(ctx);
+    // fprintf(stderr,"segments: %d\n",n_segments);
+    for (int i = 0; i < n_segments; ++i) {
+        const int n_tokens = whisper_full_n_tokens(ctx, i);
+        // fprintf(stderr,"tokens: %d\n",n_tokens);
+        for (int j = 0; j < n_tokens; j++) {
+            auto token = whisper_full_get_token_text(ctx, i, j);
+            auto probability = whisper_full_get_token_p(ctx, i, j);
+            fout << token << '\t' << probability << std::endl;
+            // fprintf(stderr,"token: %s %f\n",token,probability);
+           }
+    }
+    return true;
+}
+
 bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
     std::ofstream fout(fname);
     int indent = 0;
@@ -883,6 +918,7 @@ int main(int argc, char ** argv) {
             wparams.split_on_word    = params.split_on_word;
 
             wparams.speed_up         = params.speed_up;
+            wparams.debug_mode       = params.debug_mode;
 
             wparams.tdrz_enable      = params.tinydiarize; // [TDRZ]
 
@@ -895,7 +931,7 @@ int main(int argc, char ** argv) {
             wparams.entropy_thold    = params.entropy_thold;
             wparams.logprob_thold    = params.logprob_thold;
 
-            whisper_print_user_data user_data = { &params, &pcmf32s };
+            whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
 
             // this callback is called on each new segment
             if (!wparams.print_realtime) {
@@ -903,6 +939,11 @@ int main(int argc, char ** argv) {
                 wparams.new_segment_callback_user_data = &user_data;
             }
 
+            if (wparams.print_progress) {
+                wparams.progress_callback           = whisper_print_progress_callback;
+                wparams.progress_callback_user_data = &user_data;
+            }
+
             // example for abort mechanism
             // in this example, we do not abort the processing, but we could if the flag is set to true
             // the callback is called before every encoder run - if it returns false, the processing is aborted
@@ -967,6 +1008,12 @@ int main(int argc, char ** argv) {
                 const auto fname_lrc = fname_out + ".lrc";
                 output_lrc(ctx, fname_lrc.c_str(), params, pcmf32s);
             }
+
+            // output to score file
+            if (params.log_score) {
+                const auto fname_score = fname_out + ".score.txt";
+                output_score(ctx, fname_score.c_str(), params, pcmf32s);
+            }
         }
     }
 
index 64e8f35c3863a6be0cfee2211c8be922f95699a7..b01d61431086eaab0fcdd73418e95077fa4750b6 100644 (file)
@@ -138,7 +138,7 @@ bool whisper_model_quantize(const std::string & fname_inp, const std::string & f
         //    return false;
         //}
 
-        char word[128];
+        char word[129];
 
         for (int i = 0; i < n_vocab; i++) {
             uint32_t len;
index cb124ec9b564b12ac03ec37d592a2c7ef98f92d8..b50c86d057fc207a213e678393312d299ac6a7de 100644 (file)
@@ -14,6 +14,7 @@
 #define _USE_MATH_DEFINES
 #include <cmath>
 #include <cstdio>
+#include <cstdarg>
 #include <cstring>
 #include <fstream>
 #include <map>
@@ -81,7 +82,7 @@ static void byteswap_tensor(ggml_tensor * tensor) {
     } while (0)
 #define BYTESWAP_TENSOR(t)       \
     do {                         \
-        byteswap_tensor(tensor); \
+        byteswap_tensor(t); \
     } while (0)
 #else
 #define BYTESWAP_VALUE(d) do {} while (0)
@@ -92,7 +93,7 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 #define WHISPER_ASSERT(x) \
     do { \
         if (!(x)) { \
-            fprintf(stderr, "WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+            log("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
             abort(); \
         } \
     } while (0)
@@ -589,7 +590,7 @@ struct whisper_model {
 struct whisper_sequence {
     std::vector<whisper_token_data> tokens;
 
-    // the accumulated transcription in the current interation (used to truncate the tokens array)
+    // the accumulated transcription in the current iteration (used to truncate the tokens array)
     int result_len;
 
     double sum_logprobs_all; // the sum of the log probabilities of the tokens
@@ -724,6 +725,22 @@ struct whisper_context {
     std::string path_model; // populated by whisper_init_from_file()
 };
 
+static void whisper_default_log(const char * text) {
+    fprintf(stderr, "%s", text);
+}
+
+static whisper_log_callback whisper_log = whisper_default_log;
+
+// TODO: fix compile warning about "format string is not a string literal"
+static void log(const char * fmt, ...) {
+    if (!whisper_log) return;
+    char buf[1024];
+    va_list args;
+    va_start(args, fmt);
+    vsnprintf(buf, sizeof(buf), fmt, args);
+    whisper_log(buf);
+}
+
 template<typename T>
 static void read_safe(whisper_model_loader * loader, T & dest) {
     loader->read(loader->context, &dest, sizeof(T));
@@ -747,7 +764,7 @@ static bool kv_cache_init(
     cache.ctx = ggml_init(params);
 
     if (!cache.ctx) {
-        fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
+        log("%s: failed to allocate memory for kv cache\n", __func__);
         return false;
     }
 
@@ -783,7 +800,7 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
     cache.ctx = ggml_init(params);
 
     if (!cache.ctx) {
-        fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
+        log("%s: failed to allocate memory for kv cache\n", __func__);
         return false;
     }
 
@@ -812,7 +829,7 @@ static void kv_cache_free(struct whisper_kv_cache & cache) {
 // see the convert-pt-to-ggml.py script for details
 //
 static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
-    fprintf(stderr, "%s: loading model\n", __func__);
+    log("%s: loading model\n", __func__);
 
     const int64_t t_start_us = ggml_time_us();
 
@@ -826,7 +843,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         uint32_t magic;
         read_safe(loader, magic);
         if (magic != GGML_FILE_MAGIC) {
-            fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__);
+            log("%s: invalid model data (bad magic)\n", __func__);
             return false;
         }
     }
@@ -877,25 +894,25 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         // in order to save memory and also to speed up the computation
         wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
         if (wctx.wtype == GGML_TYPE_COUNT) {
-            fprintf(stderr, "%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
+            log("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
             return false;
         }
 
         const size_t scale = model.hparams.ftype ? 1 : 2;
 
-        fprintf(stderr, "%s: n_vocab       = %d\n", __func__, hparams.n_vocab);
-        fprintf(stderr, "%s: n_audio_ctx   = %d\n", __func__, hparams.n_audio_ctx);
-        fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
-        fprintf(stderr, "%s: n_audio_head  = %d\n", __func__, hparams.n_audio_head);
-        fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
-        fprintf(stderr, "%s: n_text_ctx    = %d\n", __func__, hparams.n_text_ctx);
-        fprintf(stderr, "%s: n_text_state  = %d\n", __func__, hparams.n_text_state);
-        fprintf(stderr, "%s: n_text_head   = %d\n", __func__, hparams.n_text_head);
-        fprintf(stderr, "%s: n_text_layer  = %d\n", __func__, hparams.n_text_layer);
-        fprintf(stderr, "%s: n_mels        = %d\n", __func__, hparams.n_mels);
-        fprintf(stderr, "%s: ftype         = %d\n", __func__, model.hparams.ftype);
-        fprintf(stderr, "%s: qntvr         = %d\n", __func__, qntvr);
-        fprintf(stderr, "%s: type          = %d\n", __func__, model.type);
+        log("%s: n_vocab       = %d\n", __func__, hparams.n_vocab);
+        log("%s: n_audio_ctx   = %d\n", __func__, hparams.n_audio_ctx);
+        log("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
+        log("%s: n_audio_head  = %d\n", __func__, hparams.n_audio_head);
+        log("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
+        log("%s: n_text_ctx    = %d\n", __func__, hparams.n_text_ctx);
+        log("%s: n_text_state  = %d\n", __func__, hparams.n_text_state);
+        log("%s: n_text_head   = %d\n", __func__, hparams.n_text_head);
+        log("%s: n_text_layer  = %d\n", __func__, hparams.n_text_layer);
+        log("%s: n_mels        = %d\n", __func__, hparams.n_mels);
+        log("%s: ftype         = %d\n", __func__, model.hparams.ftype);
+        log("%s: qntvr         = %d\n", __func__, qntvr);
+        log("%s: type          = %d\n", __func__, model.type);
 
         // print memory requirements
         {
@@ -913,7 +930,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             const size_t mem_required_decoder =
                 scale*MEM_REQ_KV_SELF.at(model.type);
 
-            fprintf(stderr, "%s: mem required  = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
+            log("%s: mem required  = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
                     mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
         }
 
@@ -945,7 +962,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         read_safe(loader, n_vocab);
 
         //if (n_vocab != model.hparams.n_vocab) {
-        //    fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+        //    log("%s: invalid model file '%s' (bad vocab size %d != %d)\n",
         //            __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
         //    return false;
         //}
@@ -965,7 +982,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 word.assign(&tmp[0], tmp.size());
             } else {
                 // seems like we have an empty-string token in multi-language models (i = 50256)
-                //fprintf(stderr, "%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
+                //log("%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
                 word = "";
             }
 
@@ -989,7 +1006,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
 
         if (n_vocab < model.hparams.n_vocab) {
-            fprintf(stderr, "%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
+            log("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
             for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
                 if (i > vocab.token_beg) {
                     word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
@@ -1128,7 +1145,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
         ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*512; // object overhead
 
-        fprintf(stderr, "%s: model ctx     = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+        log("%s: model ctx     = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
     }
 
     // create the ggml context
@@ -1141,7 +1158,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
         model.ctx = ggml_init(params);
         if (!model.ctx) {
-            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+            log("%s: ggml_init() failed\n", __func__);
             return false;
         }
     }
@@ -1374,20 +1391,20 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             name.assign(&tmp[0], tmp.size());
 
             if (model.tensors.find(name) == model.tensors.end()) {
-                fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                log("%s: unknown tensor '%s' in model file\n", __func__, name.data());
                 return false;
             }
 
             auto tensor = model.tensors[name.data()];
             if (ggml_nelements(tensor) != nelements) {
-                fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
-                fprintf(stderr, "%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
+                log("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                log("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
                         __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
                 return false;
             }
 
             if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
-                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
+                log("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
                         __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
                 return false;
             }
@@ -1395,7 +1412,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             const size_t bpe = ggml_type_size(ggml_type(ttype));
 
             if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
-                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                log("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
                         __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
                 return false;
             }
@@ -1408,12 +1425,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             model.n_loaded++;
         }
 
-        fprintf(stderr, "%s: model size    = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
+        log("%s: model size    = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
 
         if (model.n_loaded == 0) {
-            fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
+            log("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
         } else if (model.n_loaded != (int) model.tensors.size()) {
-            fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
+            log("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
             return false;
         }
     }
@@ -1784,7 +1801,7 @@ static bool whisper_encode_internal(
         {
             struct ggml_cgraph gf = {};
 
-            ggml_build_forward_expand(&gf, cur);
+            ggml_build_forward_expand  (&gf, cur);
             ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
             //ggml_graph_print(&gf);
@@ -2283,7 +2300,7 @@ static bool whisper_decode_internal(
 
     // run the computation
     {
-        ggml_build_forward_expand(&gf, logits);
+        ggml_build_forward_expand  (&gf, logits);
         ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
     }
 
@@ -2329,6 +2346,23 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
     return std::string(buf);
 }
 
+#define SIN_COS_N_COUNT WHISPER_N_FFT
+static float sin_vals[SIN_COS_N_COUNT];
+static float cos_vals[SIN_COS_N_COUNT];
+
+// In FFT, we frequently use sine and cosine operations with the same values.
+// We can use precalculated values to speed up the process.
+static void fill_sin_cos_table() {
+    static bool is_filled = false;
+    if (is_filled) return;
+    for (int i = 0; i < SIN_COS_N_COUNT; i++) {
+        double theta = (2*M_PI*i)/SIN_COS_N_COUNT;
+        sin_vals[i] = sinf(theta);
+        cos_vals[i] = cosf(theta);
+    }
+    is_filled = true;
+}
+
 // naive Discrete Fourier Transform
 // input is real-valued
 // output is complex-valued
@@ -2336,15 +2370,16 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
     int N = in.size();
 
     out.resize(N*2);
+    const int sin_cos_step = SIN_COS_N_COUNT / N;
 
     for (int k = 0; k < N; k++) {
         float re = 0;
         float im = 0;
 
         for (int n = 0; n < N; n++) {
-            float angle = 2*M_PI*k*n/N;
-            re += in[n]*cos(angle);
-            im -= in[n]*sin(angle);
+            int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
+            re += in[n]*cos_vals[idx]; // cos(t)
+            im -= in[n]*sin_vals[idx]; // sin(t)
         }
 
         out[k*2 + 0] = re;
@@ -2392,11 +2427,11 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
     fft(even, even_fft);
     fft(odd, odd_fft);
 
+    const int sin_cos_step = SIN_COS_N_COUNT / N;
     for (int k = 0; k < N/2; k++) {
-        float theta = 2*M_PI*k/N;
-
-        float re = cos(theta);
-        float im = -sin(theta);
+        int idx = k * sin_cos_step; // t = 2*M_PI*k/N
+        float re = cos_vals[idx]; // cos(t)
+        float im = -sin_vals[idx]; // sin(t)
 
         float re_odd = odd_fft[2*k + 0];
         float im_odd = odd_fft[2*k + 1];
@@ -2409,41 +2444,51 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
     }
 }
 
-static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> &hann, const float *samples,
-                                              int n_samples, int fft_size, int fft_step, int n_threads,
-                                              const whisper_filters &filters, bool speed_up, whisper_mel &mel) {
-    std::vector<float> fft_in(fft_size, 0.0);
-    std::vector<float> fft_out(2 * fft_size);
-    int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2);
+static bool hann_window(int length, bool periodic, std::vector<float> & output) {
+    if (output.size() < length) {
+        output.resize(length);
+    }
+    int offset = -1;
+    if (periodic) {
+        offset = 0;
+    }
+    for (int i = 0; i < length; i++) {
+        output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
+    }
 
-    for (int i = ith; i < mel.n_len; i += n_threads) {
-        const int offset = i * fft_step;
+    return true;
+}
 
-        // apply Hanning window
-        for (int j = 0; j < fft_size; j++) {
-            if (offset + j < n_samples) {
-                fft_in[j] = hann[j] * samples[offset + j];
-            } else {
-                fft_in[j] = 0.0;
-            }
+static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples,
+                                              int n_samples, int frame_size, int frame_step, int n_threads,
+                                              const whisper_filters & filters, whisper_mel & mel) {
+    std::vector<float> fft_in(frame_size, 0.0);
+    std::vector<float> fft_out(2 * frame_step);
+    // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
+    int n_fft = 1 + (frame_size / 2);
+    int i = ith;
+
+    // calculate FFT only when fft_in are not all zero
+    for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
+        const int offset = i * frame_step;
+
+        // apply Hanning window (~10% faster)
+        for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
+            fft_in[j] = hann[j] * samples[offset + j];
+        }
+        // fill the rest with zeros
+        if (n_samples - offset < frame_size) {
+            std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
         }
 
-        // FFT -> mag^2
+        // FFT
         fft(fft_in, fft_out);
 
-        for (int j = 0; j < fft_size; j++) {
+        // Calculate modulus^2 of complex numbers
+        // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
+        for (int j = 0; j < frame_size; j++) {
             fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
         }
-        for (int j = 1; j < fft_size / 2; j++) {
-            fft_out[j] += fft_out[fft_size - j];
-        }
-
-        if (speed_up) {
-            // scale down in the frequency domain results in a speed up in the time domain
-            for (int j = 0; j < n_fft; j++) {
-                fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1]);
-            }
-        }
 
         // mel spectrogram
         for (int j = 0; j < mel.n_mel; j++) {
@@ -2453,10 +2498,10 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
             int k = 0;
             for (k = 0; k < n_fft - 3; k += 4) {
                 sum +=
-                    fft_out[k + 0] * filters.data[j*n_fft + k + 0] +
-                    fft_out[k + 1] * filters.data[j*n_fft + k + 1] +
-                    fft_out[k + 2] * filters.data[j*n_fft + k + 2] +
-                    fft_out[k + 3] * filters.data[j*n_fft + k + 3];
+                        fft_out[k + 0] * filters.data[j * n_fft + k + 0] +
+                        fft_out[k + 1] * filters.data[j * n_fft + k + 1] +
+                        fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
+                        fft_out[k + 3] * filters.data[j * n_fft + k + 3];
             }
 
             // handle n_fft remainder
@@ -2469,68 +2514,73 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
             mel.data[j * mel.n_len + i] = sum;
         }
     }
+
+    // Otherwise fft_out are all zero
+    double sum = log10(1e-10);
+    for (; i < mel.n_len; i += n_threads) {
+        for (int j = 0; j < mel.n_mel; j++) {
+            mel.data[j * mel.n_len + i] = sum;
+        }
+    }
 }
 
-// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
+// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
 static bool log_mel_spectrogram(
-          whisper_state & wstate,
-            const float * samples,
+              whisper_state & wstate,
+              const float * samples,
               const int   n_samples,
               const int   /*sample_rate*/,
-              const int   fft_size,
-              const int   fft_step,
+              const int   frame_size,
+              const int   frame_step,
               const int   n_mel,
               const int   n_threads,
-  const whisper_filters & filters,
-             const bool   speed_up,
-            whisper_mel & mel) {
+              const whisper_filters & filters,
+              const bool   debug,
+              whisper_mel & mel) {
     const int64_t t_start_us = ggml_time_us();
 
-    // Hanning window
+    // Hanning window (Use cosf to eliminate difference)
+    // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
+    // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
     std::vector<float> hann;
-    hann.resize(fft_size);
-    for (int i = 0; i < fft_size; i++) {
-        hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
-    }
+    hann_window(frame_size, true, hann);
 
-    mel.n_mel     = n_mel;
-    mel.n_len     = n_samples/fft_step;
-    mel.n_len_org = mel.n_len;
 
-    std::vector<float> samples_padded;
-
-    // pad audio with at least one extra chunk of zeros
-    {
-        const int pad = (100*WHISPER_CHUNK_SIZE)/2;
+    // Calculate the length of padding
+    int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
+    int64_t stage_2_pad = frame_size / 2;
 
-        if (mel.n_len % pad != 0) {
-            mel.n_len = (mel.n_len/pad + 1)*pad;
-        }
-        mel.n_len += pad;
+    // Initialize a vector and copy data from C array to it.
+    std::vector<float> samples_padded;
+    samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
+    std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
 
-        samples_padded.resize(mel.n_len*fft_step);
-        memcpy(samples_padded.data(), samples, n_samples*sizeof(float));
-        memset(samples_padded.data() + n_samples, 0, (mel.n_len*fft_step - n_samples)*sizeof(float));
+    // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
+    std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
 
-        samples = samples_padded.data();
-    }
+    // reflective pad 200 samples at the beginning of audio
+    std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
 
-    mel.data.resize(mel.n_mel*mel.n_len);
+    mel.n_mel     = n_mel;
+    // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
+    // Calculate number of frames + remove the last frame
+    mel.n_len     = (samples_padded.size() - frame_size) / frame_step;
+    // Calculate semi-padded sample length to ensure compatibility
+    mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
+    mel.data.resize(mel.n_mel * mel.n_len);
 
-    //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
-    //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
 
     {
         std::vector<std::thread> workers(n_threads - 1);
         for (int iw = 0; iw < n_threads - 1; ++iw) {
             workers[iw] = std::thread(
-                    log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples,
-                    n_samples, fft_size, fft_step, n_threads,
-                    std::cref(filters), speed_up, std::ref(mel));
+                    log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded,
+                    n_samples + stage_2_pad, frame_size, frame_step, n_threads,
+                    std::cref(filters), std::ref(mel));
         }
 
         // main thread
-        log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel);
+        log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel);
 
         for (int iw = 0; iw < n_threads - 1; ++iw) {
             workers[iw].join();
@@ -2544,7 +2594,6 @@ static bool log_mel_spectrogram(
             mmax = mel.data[i];
         }
     }
-    //printf("%s: max = %f\n", __func__, mmax);
 
     mmax -= 8.0;
 
@@ -2558,7 +2607,16 @@ static bool log_mel_spectrogram(
 
     wstate.t_mel_us += ggml_time_us() - t_start_us;
 
-    //printf("mel.n_len() = %d, divided by 1500: %f, n_samples / fft_step: %d\n", mel.n_len, mel.n_len / 1500.0, n_samples / fft_step);
+    // Dump log_mel_spectrogram
+    if (debug) {
+        std::ofstream outFile("log_mel_spectrogram.json");
+        outFile << "[";
+        for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
+            outFile << mel.data[i] << ", ";
+        }
+        outFile << mel.data[mel.data.size() - 1] << "]";
+        outFile.close();
+    }
 
     return true;
 }
@@ -2614,7 +2672,7 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
                 --j;
             }
             if (!found) {
-                fprintf(stderr, "unknown token \n");
+                log("unknown token\n");
                 ++i;
             }
         }
@@ -2676,46 +2734,47 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
 #endif
 
 struct whisper_state * whisper_init_state(whisper_context * ctx) {
+    fill_sin_cos_table();
     whisper_state * state = new whisper_state;
 
     const size_t scale = ctx->model.hparams.ftype ? 1 : 2;
 
     if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
-        fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
+        log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         delete state;
         return nullptr;
     }
 
     {
         const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
-        fprintf(stderr, "%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
+        log("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
 
     if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
-        fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
+        log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
         delete state;
         return nullptr;
     }
 
     {
         const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v);
-        fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
+        log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
 
 #ifdef WHISPER_USE_COREML
     const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
 
-    fprintf(stderr, "%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
-    fprintf(stderr, "%s: first run on a device may take a while ...\n", __func__);
+    log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
+    log("%s: first run on a device may take a while ...\n", __func__);
 
     state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
     if (!state->ctx_coreml) {
-        fprintf(stderr, "%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
+        log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
 #ifndef WHISPER_COREML_ALLOW_FALLBACK
         return nullptr;
 #endif
     } else {
-        fprintf(stderr, "%s: Core ML model loaded\n", __func__);
+        log("%s: Core ML model loaded\n", __func__);
     }
 #endif
 
@@ -2755,7 +2814,7 @@ int whisper_ctx_init_openvino_encoder(
     return 1;
 #else
     if (!model_path && ctx->path_model.empty()) {
-        fprintf(stderr, "%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
+        log("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
         return 1;
     }
 
@@ -2775,15 +2834,15 @@ int whisper_ctx_init_openvino_encoder(
         path_cache = cache_dir;
     }
 
-    fprintf(stderr, "%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
-    fprintf(stderr, "%s: first run on a device may take a while ...\n", __func__);
+    log("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
+    log("%s: first run on a device may take a while ...\n", __func__);
 
     ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
     if (!ctx->state->ctx_openvino) {
-        fprintf(stderr, "%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
+        log("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
         return 1;
     } else {
-        fprintf(stderr, "%s: OpenVINO model loaded\n", __func__);
+        log("%s: OpenVINO model loaded\n", __func__);
     }
 
     return 0;
@@ -2792,11 +2851,11 @@ int whisper_ctx_init_openvino_encoder(
 
 struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
 
-    fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
+    log("%s: loading model from '%s'\n", __func__, path_model);
 
     auto fin = std::ifstream(path_model, std::ios::binary);
     if (!fin) {
-        fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_model);
+        log("%s: failed to open '%s'\n", __func__, path_model);
         return nullptr;
     }
 
@@ -2838,7 +2897,7 @@ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t
 
     buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
 
-    fprintf(stderr, "%s: loading model from buffer\n", __func__);
+    log("%s: loading model from buffer\n", __func__);
 
     whisper_model_loader loader = {};
 
@@ -2873,7 +2932,7 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
 
     if (!whisper_model_load(loader, *ctx)) {
         loader->close(loader->context);
-        fprintf(stderr, "%s: failed to load model\n", __func__);
+        log("%s: failed to load model\n", __func__);
         delete ctx;
         return nullptr;
     }
@@ -2978,7 +3037,7 @@ void whisper_free_params(struct whisper_full_params * params) {
 
 int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
     if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
-        fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
+        log("%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }
 
@@ -2989,21 +3048,30 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
     return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
 }
 
-// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
+// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
 int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
-    if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, state->mel)) {
-        fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
+    if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
+        log("%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }
 
     return 0;
 }
 
-// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
+// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
 int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
     return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
 }
 
+// same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2
+// TODO
+
+// same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2
+// TODO
+
+// same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2
+// TODO
+
 int whisper_set_mel_with_state(
         struct whisper_context * /*ctx*/,
           struct whisper_state * state,
@@ -3011,7 +3079,7 @@ int whisper_set_mel_with_state(
                            int   n_len,
                            int   n_mel) {
     if (n_mel != WHISPER_N_MEL) {
-        fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
+        log("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
         return -1;
     }
 
@@ -3035,7 +3103,7 @@ int whisper_set_mel(
 
 int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
     if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
-        fprintf(stderr, "%s: failed to eval\n", __func__);
+        log("%s: failed to eval\n", __func__);
         return -1;
     }
 
@@ -3044,7 +3112,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
 
 int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
     if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
-        fprintf(stderr, "%s: failed to eval\n", __func__);
+        log("%s: failed to eval\n", __func__);
         return -1;
     }
 
@@ -3055,7 +3123,7 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
     const int selected_decoder_id = 0;
 
     if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
-        fprintf(stderr, "%s: failed to eval\n", __func__);
+        log("%s: failed to eval\n", __func__);
         return 1;
     }
 
@@ -3067,13 +3135,12 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
     const int selected_decoder_id = 0;
 
     if (ctx->state == nullptr) {
-        fprintf(stderr, "%s: ERROR state was not loaded.\n", __func__);
+        log("%s: ERROR state was not loaded.\n", __func__);
         return false;
     }
 
-
     if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
-        fprintf(stderr, "%s: failed to eval\n", __func__);
+        log("%s: failed to eval\n", __func__);
         return 1;
     }
 
@@ -3084,7 +3151,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
     const auto res = tokenize(ctx->vocab, text);
 
     if (n_max_tokens < (int) res.size()) {
-        fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
+        log("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
         return -1;
     }
 
@@ -3112,7 +3179,7 @@ int whisper_lang_id(const char * lang) {
             }
         }
 
-        fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
+        log("%s: unknown language '%s'\n", __func__, lang);
         return -1;
     }
     return g_lang.at(lang).first;
@@ -3125,7 +3192,7 @@ const char * whisper_lang_str(int id) {
         }
     }
 
-    fprintf(stderr, "%s: unknown language id %d\n", __func__, id);
+    log("%s: unknown language id %d\n", __func__, id);
     return nullptr;
 }
 
@@ -3138,25 +3205,25 @@ int whisper_lang_auto_detect_with_state(
     const int seek = offset_ms/10;
 
     if (seek < 0) {
-        fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
+        log("%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
         return -1;
     }
 
     if (seek >= state->mel.n_len_org) {
-        fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
+        log("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
         return -2;
     }
 
     // run the encoder
     if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) {
-        fprintf(stderr, "%s: failed to encode\n", __func__);
+        log("%s: failed to encode\n", __func__);
         return -6;
     }
 
     const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
 
     if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
-        fprintf(stderr, "%s: failed to decode\n", __func__);
+        log("%s: failed to decode\n", __func__);
         return -7;
     }
 
@@ -3305,7 +3372,6 @@ float * whisper_get_logits(struct whisper_context * ctx) {
     return ctx->state->logits.data();
 }
 
-
 float * whisper_get_logits_from_state(struct whisper_state * state) {
     return state->logits.data();
 }
@@ -3357,21 +3423,21 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
 void whisper_print_timings(struct whisper_context * ctx) {
     const int64_t t_end_us = ggml_time_us();
 
-    fprintf(stderr, "\n");
-    fprintf(stderr, "%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
+    log("\n");
+    log("%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
     if (ctx->state != nullptr) {
 
         const int32_t n_sample = std::max(1, ctx->state->n_sample);
         const int32_t n_encode = std::max(1, ctx->state->n_encode);
         const int32_t n_decode = std::max(1, ctx->state->n_decode);
 
-        fprintf(stderr, "%s:     fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
-        fprintf(stderr, "%s:      mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
-        fprintf(stderr, "%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
-        fprintf(stderr, "%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
-        fprintf(stderr, "%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
+        log("%s:     fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
+        log("%s:      mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
+        log("%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
+        log("%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
+        log("%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
     }
-    fprintf(stderr, "%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
+    log("%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
 }
 
 void whisper_reset_timings(struct whisper_context * ctx) {
@@ -3413,6 +3479,7 @@ const char * whisper_print_system_info(void) {
     s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
     s += "BLAS = "      + std::to_string(ggml_cpu_has_blas())      + " | ";
     s += "SSE3 = "      + std::to_string(ggml_cpu_has_sse3())      + " | ";
+    s += "SSSE3 = "     + std::to_string(ggml_cpu_has_ssse3())     + " | ";
     s += "VSX = "       + std::to_string(ggml_cpu_has_vsx())       + " | ";
     s += "COREML = "    + std::to_string(whisper_has_coreml())     + " | ";
     s += "OPENVINO = "  + std::to_string(whisper_has_openvino())   + " | ";
@@ -3455,6 +3522,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.max_tokens        =*/ 0,
 
         /*.speed_up          =*/ false,
+        /*.debug_mode        =*/ false,
         /*.audio_ctx         =*/ 0,
 
         /*.tdrz_enable       =*/ false,
@@ -3616,7 +3684,7 @@ static void whisper_process_logits(
     WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab);
 
     // extract the logits for the last token
-    // we will be mutating and therefore we don't want to use the ctx.logits buffer directly
+    // we will be mutating, and therefore we don't want to use the ctx.logits buffer directly
     auto & probs    = decoder.probs;
     auto & logits   = decoder.logits;
     auto & logprobs = decoder.logprobs;
@@ -3695,7 +3763,7 @@ static void whisper_process_logits(
             const bool last_was_timestamp        = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
             const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
 
-            //fprintf(stderr, "last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
+            //log("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
 
             if (last_was_timestamp) {
                 if (penultimate_was_timestamp) {
@@ -3771,7 +3839,7 @@ static void whisper_process_logits(
 
             const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
 
-            //fprintf(stderr, "timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
+            //log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
 
             if (timestamp_logprob > max_text_token_logprob) {
                 for (int i = 0; i < vocab.token_beg; ++i) {
@@ -4017,16 +4085,17 @@ int whisper_full_with_state(
 
     result_all.clear();
 
-    // compute log mel spectrogram
-    if (params.speed_up) {
-        if (whisper_pcm_to_mel_phase_vocoder_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
-            fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
+    if (n_samples > 0) {
+        // compute log mel spectrogram
+        if (params.speed_up) {
+            // TODO: Replace PV with more advanced algorithm
+            log("%s: failed to compute log mel spectrogram\n", __func__);
             return -1;
-        }
-    } else {
-        if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
-            fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
-            return -2;
+        } else {
+            if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
+                log("%s: failed to compute log mel spectrogram\n", __func__);
+                return -2;
+            }
         }
     }
 
@@ -4036,13 +4105,13 @@ int whisper_full_with_state(
 
         const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
         if (lang_id < 0) {
-            fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
+            log("%s: failed to auto-detect language\n", __func__);
             return -3;
         }
         state->lang_id = lang_id;
         params.language = whisper_lang_str(lang_id);
 
-        fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
+        log("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
         if (params.detect_language) {
             return 0;
         }
@@ -4052,14 +4121,16 @@ int whisper_full_with_state(
         state->t_beg    = 0;
         state->t_last   = 0;
         state->tid_last = 0;
-        state->energy = get_signal_energy(samples, n_samples, 32);
+        if (n_samples > 0) {
+            state->energy = get_signal_energy(samples, n_samples, 32);
+        }
     }
 
     const int seek_start = params.offset_ms/10;
     const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10;
 
-    // if length of spectrogram is less than 1s (100 samples), then return
-    // basically don't process anything that is less than 1s
+    // if length of spectrogram is less than 1.0s (100 frames), then return
+    // basically don't process anything that is less than 1.0s
     // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
     if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
         return 0;
@@ -4099,7 +4170,7 @@ int whisper_full_with_state(
         if (decoder.kv_self.ctx == nullptr) {
             decoder.kv_self = state->decoders[0].kv_self;
             if (!kv_cache_reinit(decoder.kv_self)) {
-                fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
+                log("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
                 return -4;
             }
 
@@ -4143,7 +4214,7 @@ int whisper_full_with_state(
 
     // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
     if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
-        fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
+        log("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
         return -5;
     }
     state->exp_n_audio_ctx = params.audio_ctx;
@@ -4161,9 +4232,6 @@ int whisper_full_with_state(
         }
     }
 
-    int progress_prev = 0;
-    int progress_step = 5;
-
     int seek = seek_start;
 
     std::vector<whisper_token> prompt;
@@ -4190,16 +4258,11 @@ int whisper_full_with_state(
 
     // main loop
     while (true) {
-        const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
-        while (progress_cur >= progress_prev + progress_step) {
-            progress_prev += progress_step;
-            if (params.print_progress) {
-                fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev);
-            }
-        }
         if (params.progress_callback) {
+            const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
+
             params.progress_callback(
-                ctx, ctx->state, progress_prev, params.progress_callback_user_data);
+                ctx, ctx->state, progress_cur, params.progress_callback_user_data);
         }
 
         // of only 1 second left, then stop
@@ -4209,14 +4272,14 @@ int whisper_full_with_state(
 
         if (params.encoder_begin_callback) {
             if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
-                fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
+                log("%s: encoder_begin_callback returned false - aborting\n", __func__);
                 break;
             }
         }
 
         // encode audio features starting at offset seek
         if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
-            fprintf(stderr, "%s: failed to encode\n", __func__);
+            log("%s: failed to encode\n", __func__);
             return -6;
         }
 
@@ -4299,7 +4362,7 @@ int whisper_full_with_state(
                 WHISPER_PRINT_DEBUG("\n\n");
 
                 if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
-                    fprintf(stderr, "%s: failed to decode\n", __func__);
+                    log("%s: failed to decode\n", __func__);
                     return -7;
                 }
 
@@ -4537,7 +4600,7 @@ int whisper_full_with_state(
                     //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
 
                     if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
-                        fprintf(stderr, "%s: failed to decode\n", __func__);
+                        log("%s: failed to decode\n", __func__);
                         return -8;
                     }
 
@@ -4752,7 +4815,6 @@ int whisper_full_with_state(
     return 0;
 }
 
-
 int whisper_full(
         struct whisper_context * ctx,
     struct whisper_full_params   params,
@@ -4829,7 +4891,6 @@ int whisper_full_parallel(
             result.t0 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
             result.t1 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
 
-
             // make sure that segments are not overlapping
             if (!ctx->state->result_all.empty()) {
                 result.t0 = std::max(result.t0, ctx->state->result_all.back().t1);
@@ -4859,12 +4920,12 @@ int whisper_full_parallel(
     ctx->state->t_decode_us /= n_processors;
 
     // print information about the audio boundaries
-    fprintf(stderr, "\n");
-    fprintf(stderr, "%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
+    log("\n");
+    log("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
     for (int i = 0; i < n_processors - 1; ++i) {
-        fprintf(stderr, "%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
+        log("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
     }
-    fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__);
+    log("%s: the transcription quality may be degraded near these boundaries\n", __func__);
 
     return ret;
 }
@@ -5222,7 +5283,7 @@ static void whisper_exp_compute_token_level_timestamps(
     const int n_samples = state.energy.size();
 
     if (n_samples == 0) {
-        fprintf(stderr, "%s: no signal data available\n", __func__);
+        log("%s: no signal data available\n", __func__);
         return;
     }
 
@@ -5442,3 +5503,7 @@ static void whisper_exp_compute_token_level_timestamps(
     //    }
     //}
 }
+
+void whisper_set_log_callback(whisper_log_callback callback) {
+    whisper_log = callback;
+}
index 83af11bd848c66ff1fd6b5c9e97eec6f1484165d..73ab4d799a23ad73bf9a6406fd9e503116bb8917 100644 (file)
@@ -67,6 +67,7 @@ extern "C" {
 
     struct whisper_context;
     struct whisper_state;
+    struct whisper_full_params;
 
     typedef int whisper_token;
 
@@ -345,7 +346,7 @@ extern "C" {
                               void * user_data);
 
     // Parameters for the whisper_full() function
-    // If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp:
+    // If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
     // whisper_full_default_params()
     struct whisper_full_params {
         enum whisper_sampling_strategy strategy;
@@ -374,6 +375,7 @@ extern "C" {
         // [EXPERIMENTAL] speed-up techniques
         // note: these can significantly reduce the quality of the output
         bool speed_up;          // speed-up the audio by 2x using Phase Vocoder
+        bool debug_mode;        // enable debug_mode provides extra info (eg. Dump log_mel)
         int  audio_ctx;         // overwrite the audio context size (0 = use default)
 
         // [EXPERIMENTAL] [TDRZ] tinydiarize
@@ -517,6 +519,11 @@ extern "C" {
     WHISPER_API int          whisper_bench_ggml_mul_mat    (int n_threads);
     WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads);
 
+    // Control logging output; default behavior is to print to stderr
+
+    typedef void (*whisper_log_callback)(const char * line);
+    WHISPER_API void whisper_set_log_callback(whisper_log_callback callback);
+
 #ifdef __cplusplus
 }
 #endif