TENSORS_SET = set(TENSORS_LIST)
+def find_n_mult(n_ff: int, n_embd: int) -> int:
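+    # n_ff is derived from n_embd by rounding (8 * n_embd) // 3 up to a
+    # multiple of n_mult; e.g. for LLaMA-7B, (8 * 4096) // 3 = 10922 rounds
+    # up to n_ff = 11008 with n_mult = 256, so 256 is recovered here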
+    # search candidate multipliers from 256 downward and return the first
+    # one that reproduces n_ff exactly
+    for n_mult in range(256, 1, -1):
+        calc_ff = (((8 * n_embd) // 3 + n_mult - 1) // n_mult) * n_mult
+        if calc_ff == n_ff:
+            return n_mult
+    return 1
+
@dataclass
class Params:
    n_vocab: int
    n_embd: int
    n_mult: int
    n_head: int
    n_layer: int
-    file_type: GGMLFileType
    @staticmethod
-    def guessed(model: 'LazyModel', file_type: GGMLFileType) -> 'Params':
-        n_vocab, n_embd = model["tok_embeddings.weight"].shape
+    def guessed(model: 'LazyModel') -> 'Params':
+        # prefer the HF transformers tensor naming, fall back to the original
+        if "model.embed_tokens.weight" in model:
+            n_vocab, n_embd = model["model.embed_tokens.weight"].shape
+        else:
+            n_vocab, n_embd = model["tok_embeddings.weight"].shape
+
+        # likewise prefer the transformer naming when counting layers
+        if "model.layers.0.self_attn.q_proj.weight" in model:
+            n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.q_proj.weight" not in model)
+        else:
+            n_layer = next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model)
+
+        n_head = n_embd // 128  # guessed: assumes a head dimension of 128
        return Params(
            n_vocab=n_vocab,
            n_embd=n_embd,
            n_mult=256,
-            n_head=n_embd // 128,
-            n_layer=next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model),
-            file_type=file_type,
+            n_head=n_head,
+            n_layer=n_layer,
        )
+
+    @staticmethod
+    def loadHFTransformerJson(model: 'LazyModel', config_path: Path) -> 'Params':
+        with open(config_path) as f:
+            config = json.load(f)
+
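+        # for reference, a LLaMA-7B config.json carries vocab_size=32000,
+        # hidden_size=4096, num_attention_heads=32, num_hidden_layers=32,
+        # intermediate_size=11008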
+        n_vocab = config["vocab_size"]
+        n_embd = config["hidden_size"]
+        n_head = config["num_attention_heads"]
+        n_layer = config["num_hidden_layers"]
+        n_ff = config["intermediate_size"]
+
+        n_mult = find_n_mult(n_ff, n_embd)
+
+        return Params(
+            n_vocab=n_vocab,
+            n_embd=n_embd,
+            n_mult=n_mult,
+            n_head=n_head,
+            n_layer=n_layer,
+        )
+
+    @staticmethod
+    def load(model_plus: 'ModelPlus') -> 'Params':
+        orig_config_path = model_plus.paths[0].parent / "params.json"
+        hf_transformer_config_path = model_plus.paths[0].parent / "config.json"
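+        # HF transformers exports ship a config.json; original LLaMA
+        # checkpoints ship a params.json instead, so fall back to guessing
+        # hyperparameters from tensor shapes in that case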
+
+        if hf_transformer_config_path.exists():
+            params = Params.loadHFTransformerJson(model_plus.model, hf_transformer_config_path)
+        else:
+            params = Params.guessed(model_plus.model)
+
+        print(f'params: n_vocab:{params.n_vocab} n_embd:{params.n_embd} n_mult:{params.n_mult} n_head:{params.n_head} n_layer:{params.n_layer}')
+        return params
+
class SentencePieceVocab:
    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
    return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)
-def convert_transformers_to_orig(model: LazyModel) -> LazyModel:
+def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
    out: LazyModel = {}
    out["tok_embeddings.weight"] = model["model.embed_tokens.weight"]
    out["norm.weight"] = model["model.norm.weight"]
    out["output.weight"] = model["lm_head.weight"]
-    n_head = model["model.layers.0.self_attn.q_proj.weight"].shape[1] // 128
    for i in itertools.count():
        if f"model.layers.{i}.self_attn.q_proj.weight" not in model:
            break
- out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], n_head)
- out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], n_head)
+ out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
+ out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head)
out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
out[f"layers.{i}.attention.wo.weight"] = model[f"model.layers.{i}.self_attn.o_proj.weight"]
    def __init__(self, fname_out: Path) -> None:
        self.fout = open(fname_out, "wb")
-    def write_file_header(self, params: Params) -> None:
+    def write_file_header(self, params: Params, file_type: GGMLFileType) -> None:
        self.fout.write(b"ggjt"[::-1])  # magic: bytes reversed so the little-endian uint32 header field reads as 'ggjt'
        values = [
            1,  # file version
            params.n_vocab,
            params.n_embd,
            params.n_mult,
            params.n_head,
            params.n_layer,
            params.n_embd // params.n_head,  # rot (obsolete)
-            params.file_type.value,
+            file_type.value,
        ]
        self.fout.write(struct.pack("i" * len(values), *values))
        of.fout.close()

    @staticmethod
-    def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
+    def write_all(fname_out: Path, params: Params, file_type: GGMLFileType, model: LazyModel, vocab: Vocab) -> None:
        check_vocab_size(params, vocab)
        of = OutputFile(fname_out)
-        of.write_file_header(params)
+        of.write_file_header(params, file_type)
        print("Writing vocab...")
        of.write_vocab(vocab)
    raise Exception(f"Unexpected combination of types: {name_to_type}")
-def do_necessary_conversions(model: LazyModel) -> LazyModel:
+def do_necessary_conversions(model: LazyModel, params: Params) -> LazyModel:
    model = handle_quantization(model)
    if "lm_head.weight" in model:
-        model = convert_transformers_to_orig(model)
+        model = convert_transformers_to_orig(model, params)
    model = filter_and_sort_tensors(model)
    return model
    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
-def default_outfile(model_paths: List[Path], params: Params) -> Path:
+def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
    namestr = {
        GGMLFileType.AllF32: "f32",
        GGMLFileType.MostlyF16: "f16",
        GGMLFileType.MostlyQ4_0: "q4_0",
        GGMLFileType.MostlyQ4_1: "q4_1",
        GGMLFileType.PerLayerIsQ4_1: "q4_1",
-    }[params.file_type]
+    }[file_type]
    ret = model_paths[0].parent / f"ggml-model-{namestr}.bin"
    if ret in model_paths:
        sys.stderr.write(
        else:
            vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
            vocab = load_vocab(vocab_dir)
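+        # params are now loaded before the tensor conversions, since
+        # convert_transformers_to_orig() takes n_head from them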
+        params = Params.load(model_plus)
        model = model_plus.model
-        model = do_necessary_conversions(model)
+        model = do_necessary_conversions(model, params)
        output_type = pick_output_type(model, args.outtype)
        model = convert_to_output_type(model, output_type)
-        params = Params.guessed(model, output_type)
-        outfile = args.outfile or default_outfile(model_plus.paths, params)
-        OutputFile.write_all(outfile, params, model, vocab)
+        outfile = args.outfile or default_outfile(model_plus.paths, output_type)
+        OutputFile.write_all(outfile, params, output_type, model, vocab)
        print(f"Wrote {outfile}")