set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2 /D_CRT_SECURE_NO_WARNINGS=1")
else()
if (EMSCRIPTEN)
+ # we require support for WASM SIMD 128-bit
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread -msimd128")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
else()
ar rcs libwhisper.a ggml.o whisper.o
clean:
- rm -f *.o main libwhisper.a
+ rm -f *.o main stream libwhisper.a
#
# Examples
https://ggml.ggerganov.com
-For more details, see the conversion script [convert-pt-to-ggml.py](convert-pt-to-ggml.py) or the README in [models](models).
+For more details, see the conversion script [models/convert-pt-to-ggml.py](models/convert-pt-to-ggml.py) or the README in [models](models).
## Bindings
# whisper.wasm
-Live demo: https://whisper.ggerganov.com
+Inference of [OpenAI's Whisper ASR model](https://github.com/openai/whisper) inside the browser
+
+This example uses a WebAssembly (WASM) port of the [whisper.cpp](https://github.com/ggerganov/whisper.cpp)
+implementation of the transformer to run the inference inside a web page. The audio data does not leave your computer -
+it is processed locally on your machine. The performance is not great, but you should be able to achieve x2 or x3
+real-time for the `tiny` and `base` models on a modern CPU and browser (i.e. transcribe 60 seconds of audio in about
+20-30 seconds).
+
+This WASM port utilizes [WASM SIMD 128-bit intrinsics](https://emcc.zcopy.site/docs/porting/simd/) so you have to make
+sure that [your browser supports them](https://webassembly.org/roadmap/).
+
+The example is capable of running all models up to size `small` inclusive. Beyond that, the memory requirements and
+performance are unsatisfactory. The implementation currently supports only the `Greedy` sampling strategy. Both
+transcription and translation are supported.
+
+Since the model data is quite big (74MB for the `tiny` model), you need to manually load the model into the web page.
+
+The example supports both loading audio from a file and recording audio from the microphone. The maximum length of the
+audio is limited to 120 seconds.
+
+## Live demo
+
+Link: https://whisper.ggerganov.com
+
+
</tr>
</table>
- <br><br>
+ <br>
<!-- textarea with height filling the rest of the page -->
<textarea id="output" rows="20"></textarea>
return new type(buffer);
}
+ //
+ // load model
+ //
+
function loadFile(event, fname) {
var file = event.target.files[0] || null;
if (file == null) {
reader.readAsArrayBuffer(file);
}
+ //
+ // audio file
+ //
+
function loadAudio(event) {
if (!context) {
context = new AudioContext({sampleRate: 16000});
}
//
- // Microphone
+ // microphone
//
var mediaRecorder = null;
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large" )
for model in "${models[@]}"; do
- python3 convert-pt-to-ggml.py ~/.cache/whisper/$model.pt ../whisper models/
+ python3 models/convert-pt-to-ggml.py ~/.cache/whisper/$model.pt ../whisper models/
mv -v models/ggml-model.bin models/ggml-$model.bin
done
--- /dev/null
+# Convert Whisper transformer model from PyTorch to ggml format
+#
+# Usage: python convert-pt-to-ggml.py ~/.cache/whisper/medium.pt ~/path/to/repo/whisper/ ./models/whisper-medium
+#
+# You need to clone the original repo in ~/path/to/repo/whisper/
+#
+# git clone https://github.com/openai/whisper ~/path/to/repo/whisper/
+#
+# It is used to load various assets needed by the algorithm:
+#
+# - tokenizer
+# - mel filters
+#
+# Also, you need to have the original models in ~/.cache/whisper/
+# See the original repo for more details.
+#
+# This script loads the specified model and whisper assets and saves them in ggml format.
+# The output is a single binary file containing the following information:
+#
+# - hparams
+# - mel filters
+# - tokenizer vocab
+# - model variables
+#
+# For each variable, write the following:
+#
+# - Number of dimensions (int)
+# - Name length (int)
+# - Dimensions (int[n_dims])
+# - Name (char[name_length])
+# - Data (float[n_elements])
+#
+
+import io
+import os
+import sys
+import struct
+import json
+import code
+import torch
+import numpy as np
+
+from transformers import GPT2TokenizerFast
+
+# ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L10-L110
+LANGUAGES = {
+ "en": "english",
+ "zh": "chinese",
+ "de": "german",
+ "es": "spanish",
+ "ru": "russian",
+ "ko": "korean",
+ "fr": "french",
+ "ja": "japanese",
+ "pt": "portuguese",
+ "tr": "turkish",
+ "pl": "polish",
+ "ca": "catalan",
+ "nl": "dutch",
+ "ar": "arabic",
+ "sv": "swedish",
+ "it": "italian",
+ "id": "indonesian",
+ "hi": "hindi",
+ "fi": "finnish",
+ "vi": "vietnamese",
+ "iw": "hebrew",
+ "uk": "ukrainian",
+ "el": "greek",
+ "ms": "malay",
+ "cs": "czech",
+ "ro": "romanian",
+ "da": "danish",
+ "hu": "hungarian",
+ "ta": "tamil",
+ "no": "norwegian",
+ "th": "thai",
+ "ur": "urdu",
+ "hr": "croatian",
+ "bg": "bulgarian",
+ "lt": "lithuanian",
+ "la": "latin",
+ "mi": "maori",
+ "ml": "malayalam",
+ "cy": "welsh",
+ "sk": "slovak",
+ "te": "telugu",
+ "fa": "persian",
+ "lv": "latvian",
+ "bn": "bengali",
+ "sr": "serbian",
+ "az": "azerbaijani",
+ "sl": "slovenian",
+ "kn": "kannada",
+ "et": "estonian",
+ "mk": "macedonian",
+ "br": "breton",
+ "eu": "basque",
+ "is": "icelandic",
+ "hy": "armenian",
+ "ne": "nepali",
+ "mn": "mongolian",
+ "bs": "bosnian",
+ "kk": "kazakh",
+ "sq": "albanian",
+ "sw": "swahili",
+ "gl": "galician",
+ "mr": "marathi",
+ "pa": "punjabi",
+ "si": "sinhala",
+ "km": "khmer",
+ "sn": "shona",
+ "yo": "yoruba",
+ "so": "somali",
+ "af": "afrikaans",
+ "oc": "occitan",
+ "ka": "georgian",
+ "be": "belarusian",
+ "tg": "tajik",
+ "sd": "sindhi",
+ "gu": "gujarati",
+ "am": "amharic",
+ "yi": "yiddish",
+ "lo": "lao",
+ "uz": "uzbek",
+ "fo": "faroese",
+ "ht": "haitian creole",
+ "ps": "pashto",
+ "tk": "turkmen",
+ "nn": "nynorsk",
+ "mt": "maltese",
+ "sa": "sanskrit",
+ "lb": "luxembourgish",
+ "my": "myanmar",
+ "bo": "tibetan",
+ "tl": "tagalog",
+ "mg": "malagasy",
+ "as": "assamese",
+ "tt": "tatar",
+ "haw": "hawaiian",
+ "ln": "lingala",
+ "ha": "hausa",
+ "ba": "bashkir",
+ "jw": "javanese",
+ "su": "sundanese",
+}
+
+# ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L273-L292
+def build_tokenizer(path_to_whisper_repo: str, name: str = "gpt2"):
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ path = os.path.join(path_to_whisper_repo, "whisper/assets", name)
+ tokenizer = GPT2TokenizerFast.from_pretrained(path)
+
+ specials = [
+ "<|startoftranscript|>",
+ *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
+ "<|translate|>",
+ "<|transcribe|>",
+ "<|startoflm|>",
+ "<|startofprev|>",
+ "<|nocaptions|>",
+ "<|notimestamps|>",
+ ]
+
+ tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
+ return tokenizer
+
+# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
+def bytes_to_unicode():
+ """
+    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8+n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+if len(sys.argv) < 4:
+ print("Usage: convert-pt-to-ggml.py model.pt path-to-whisper-repo dir-output [use-f32]\n")
+ sys.exit(1)
+
+fname_inp = sys.argv[1]
+dir_whisper = sys.argv[2]
+dir_out = sys.argv[3]
+
+# try to load PyTorch binary data
+try:
+ model_bytes = open(fname_inp, "rb").read()
+ with io.BytesIO(model_bytes) as fp:
+ checkpoint = torch.load(fp, map_location="cpu")
+except Exception:
+ print("Error: failed to load PyTorch model file: %s" % fname_inp)
+ sys.exit(1)
+
+hparams = checkpoint["dims"]
+print("hparams:", hparams)
+
+list_vars = checkpoint["model_state_dict"]
+
+#print(list_vars['encoder.positional_embedding'])
+#print(list_vars['encoder.conv1.weight'])
+#print(list_vars['encoder.conv1.weight'].shape)
+
+# load mel filters
+n_mels = hparams["n_mels"]
+with np.load(os.path.join(dir_whisper, "whisper/assets", "mel_filters.npz")) as f:
+ filters = torch.from_numpy(f[f"mel_{n_mels}"])
+ #print (filters)
+
+#code.interact(local=locals())
+
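+# the multilingual models use a vocab of 51865 tokens; the English-only ones use 51864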
+multilingual = hparams["n_vocab"] == 51865
+tokenizer = build_tokenizer(dir_whisper, "multilingual" if multilingual else "gpt2")
+
+#print(tokenizer)
+#print(tokenizer.name_or_path)
+#print(len(tokenizer.additional_special_tokens))
+dir_tokenizer = tokenizer.name_or_path
+
+# output in the same directory as the model
+fname_out = dir_out + "/ggml-model.bin"
+
+with open(dir_tokenizer + "/vocab.json", "r") as f:
+ tokens = json.load(f)
+
+# use 16-bit or 32-bit floats
+use_f16 = True
+if len(sys.argv) > 4:
+ use_f16 = False
+ fname_out = dir_out + "/ggml-model-f32.bin"
+
+fout = open(fname_out, "wb")
+
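+# file header: magic, then the model hyper-parameters, then the use_f16 flag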
+fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
+fout.write(struct.pack("i", hparams["n_vocab"]))
+fout.write(struct.pack("i", hparams["n_audio_ctx"]))
+fout.write(struct.pack("i", hparams["n_audio_state"]))
+fout.write(struct.pack("i", hparams["n_audio_head"]))
+fout.write(struct.pack("i", hparams["n_audio_layer"]))
+fout.write(struct.pack("i", hparams["n_text_ctx"]))
+fout.write(struct.pack("i", hparams["n_text_state"]))
+fout.write(struct.pack("i", hparams["n_text_head"]))
+fout.write(struct.pack("i", hparams["n_text_layer"]))
+fout.write(struct.pack("i", hparams["n_mels"]))
+fout.write(struct.pack("i", use_f16))
+
+# write mel filters
+fout.write(struct.pack("i", filters.shape[0]))
+fout.write(struct.pack("i", filters.shape[1]))
+for i in range(filters.shape[0]):
+ for j in range(filters.shape[1]):
+ fout.write(struct.pack("f", filters[i][j]))
+
+byte_encoder = bytes_to_unicode()
+byte_decoder = {v:k for k, v in byte_encoder.items()}
+
+fout.write(struct.pack("i", len(tokens)))
+
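+# each token is written as its byte length (int32) followed by its raw bytes;
+# byte_decoder maps the printable unicode characters of the BPE vocab back to raw bytes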
+for key in tokens:
+ text = bytearray([byte_decoder[c] for c in key])
+ fout.write(struct.pack("i", len(text)))
+ fout.write(text)
+
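+# for each tensor: a header (n_dims, name length, ftype, dims in reverse order, name bytes)
+# followed by the raw tensor data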
+for name in list_vars.keys():
+ data = list_vars[name].squeeze().numpy()
+ print("Processing variable: " + name + " with shape: ", data.shape)
+
+ # reshape conv bias from [n] to [n, 1]
+ if name == "encoder.conv1.bias" or \
+ name == "encoder.conv2.bias":
+ data = data.reshape(data.shape[0], 1)
+ print(" Reshaped variable: " + name + " to shape: ", data.shape)
+
+    n_dims = len(data.shape)
+
+ # looks like the whisper models are in f16 by default
+ # so we need to convert the small tensors to f32 until we fully support f16 in ggml
+ # ftype == 0 -> float32, ftype == 1 -> float16
+    ftype = 1
+ if use_f16:
+ if n_dims < 2 or \
+ name == "encoder.conv1.bias" or \
+ name == "encoder.conv2.bias" or \
+ name == "encoder.positional_embedding" or \
+ name == "decoder.positional_embedding":
+            print("  Converting to float32")
+            data = data.astype(np.float32)
+            ftype = 0
+ else:
+ data = data.astype(np.float32)
+ ftype = 0
+
+ #if name.startswith("encoder"):
+ # if name.endswith("mlp.0.weight") or \
+ # name.endswith("mlp.2.weight"):
+ # print(" Transposing")
+ # data = data.transpose()
+
+ # header
+    name_bytes = name.encode('utf-8')
+    fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype))
+    for i in range(n_dims):
+        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+    fout.write(name_bytes)
+
+ # data
+ data.tofile(fout)
+
+fout.close()
+
+print("Done. Output file: " + fname_out)
+print("")