]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
py : make convert-pt-to-ggml.py backwards compatible with older vocab.json tokenizer...
authorAkash Mahajan <redacted>
Sun, 25 Jun 2023 10:50:14 +0000 (03:50 -0700)
committerGitHub <redacted>
Sun, 25 Jun 2023 10:50:14 +0000 (13:50 +0300)
* patch checkpoint convert script to keep compatibility with older hf_transformers whisper tokenizer

* typo fix

models/convert-pt-to-ggml.py

index f5aa6bd37b5417977d0097219879e054038c357b..9aa134b53f7d05c1f9d2be60759f28b14e87fdc6 100644 (file)
@@ -224,16 +224,39 @@ with np.load(dir_whisper / "whisper" / "assets" / "mel_filters.npz") as f:
 
 #code.interact(local=locals())
 
+# load tokenizer
+# for backwards compatibility, also check for older hf_transformers format tokenizer files
+# old format: dir_whisper/whisper/assets/[multilingual/gpt2]/vocab.json
+# new format: dir_whisper/whisper/assets/[multilingual/gpt2].tiktoken
 multilingual = hparams["n_vocab"] == 51865
 tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken")
+tokenizer_type = "tiktoken"
+if not tokenizer.is_file():
+    tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual" or "gpt2") / "vocab.json"
+    tokenizer_type = "hf_transformers"
+    if not tokenizer.is_file():
+        print("Error: failed to find either tiktoken or hf_transformers tokenizer file:", tokenizer)
+        sys.exit(1)
+
+byte_encoder = bytes_to_unicode()
+byte_decoder = {v:k for k, v in byte_encoder.items()}
+
+if tokenizer_type == "tiktoken":
+    with open(tokenizer, "rb") as f:
+        contents = f.read()
+        tokens = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)}
+elif tokenizer_type == "hf_transformers":
+    with open(tokenizer, "r", encoding="utf8") as f:
+        _tokens_raw = json.load(f)
+        if '<|endoftext|>' in _tokens_raw:
+            # ensures exact same model as tokenizer_type == tiktoken
+            # details: https://github.com/ggerganov/whisper.cpp/pull/725
+            del _tokens_raw['<|endoftext|>']
+        tokens = {bytes([byte_decoder[c] for c in token]): int(idx) for token, idx in _tokens_raw.items()}
 
 # output in the same directory as the model
 fname_out = dir_out / "ggml-model.bin"
 
-with open(tokenizer, "rb") as f:
-    contents = f.read()
-    tokens = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)}
-
 # use 16-bit or 32-bit floats
 use_f16 = True
 if len(sys.argv) > 4:
@@ -262,9 +285,7 @@ for i in range(filters.shape[0]):
     for j in range(filters.shape[1]):
         fout.write(struct.pack("f", filters[i][j]))
 
-byte_encoder = bytes_to_unicode()
-byte_decoder = {v:k for k, v in byte_encoder.items()}
-
+# write tokenizer
 fout.write(struct.pack("i", len(tokens)))
 
 for key in tokens: