#code.interact(local=locals())
+# load tokenizer
+# for backwards compatibility, also check for older hf_transformers format tokenizer files
+# old format: dir_whisper/whisper/assets/[multilingual/gpt2]/vocab.json
+# new format: dir_whisper/whisper/assets/[multilingual/gpt2].tiktoken
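+# (multilingual models use a 51865-token vocab; the English-only models use 51864)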
multilingual = hparams["n_vocab"] == 51865
tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken")
+tokenizer_type = "tiktoken"
+if not tokenizer.is_file():
+    tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual" or "gpt2") / "vocab.json"
+    tokenizer_type = "hf_transformers"
+    if not tokenizer.is_file():
+        print("Error: failed to find either tiktoken or hf_transformers tokenizer file:", tokenizer)
+        sys.exit(1)
+
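+# bytes_to_unicode() is GPT-2's reversible byte <-> printable-unicode mapping;
+# its inverse is needed to turn the escaped token strings in vocab.json back into raw bytes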
+byte_encoder = bytes_to_unicode()
+byte_decoder = {v:k for k, v in byte_encoder.items()}
+
+if tokenizer_type == "tiktoken":
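+    # each non-empty line of the .tiktoken file is "<base64-encoded token bytes> <rank>";
+    # e.g. a line b"IQ== 0" would be decoded here as the token b"!" with rank 0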
+    with open(tokenizer, "rb") as f:
+        contents = f.read()
+        tokens = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)}
+elif tokenizer_type == "hf_transformers":
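+    # vocab.json maps token strings (in GPT-2's byte-to-unicode escaping) to integer ids;
+    # byte_decoder turns them back into raw bytes so both branches yield identical token keys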
+    with open(tokenizer, "r", encoding="utf8") as f:
+        _tokens_raw = json.load(f)
+        if '<|endoftext|>' in _tokens_raw:
+            # dropping '<|endoftext|>' ensures the exact same model as with tokenizer_type == "tiktoken"
+            # details: https://github.com/ggerganov/whisper.cpp/pull/725
+            del _tokens_raw['<|endoftext|>']
+        tokens = {bytes([byte_decoder[c] for c in token]): int(idx) for token, idx in _tokens_raw.items()}
# output in the same directory as the model
fname_out = dir_out / "ggml-model.bin"
-with open(tokenizer, "rb") as f:
-    contents = f.read()
-    tokens = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)}
-
# use 16-bit or 32-bit floats
use_f16 = True
if len(sys.argv) > 4:
    for j in range(filters.shape[1]):
        fout.write(struct.pack("f", filters[i][j]))
-byte_encoder = bytes_to_unicode()
-byte_decoder = {v:k for k, v in byte_encoder.items()}
-
+# write tokenizer
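+# (format: int32 token count, followed by each token as a length-prefixed byte string)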
fout.write(struct.pack("i", len(tokens)))
for key in tokens: