import signal
import struct
import sys
+import textwrap
import time
import zipfile
-from abc import ABCMeta, abstractmethod
+from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar
+from typing import IO, TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Protocol, TypeVar, runtime_checkable
import numpy as np
from sentencepiece import SentencePieceProcessor
DEFAULT_CONCURRENCY = 8
+ADDED_TOKENS_FILE = 'added_tokens.json'
+FAST_TOKENIZER_FILE = 'tokenizer.json'
+
#
# data types
#
n_layer = next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model)
if n_layer < 1:
- raise Exception("failed to guess 'n_layer'. This model is unknown or unsupported.\n"
- "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
+ msg = """\
+ failed to guess 'n_layer'. This model is unknown or unsupported.
+            Suggestion: provide 'config.json' of the model in the same directory as the model files."""
+ raise KeyError(textwrap.dedent(msg))
n_head = n_embd // 128 # guessed
n_mult = 256 # guessed
@staticmethod
def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
- config = json.load(open(config_path))
+ with open(config_path) as f:
+ config = json.load(f)
rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None
rope_scaling = config.get("rope_scaling")
elif "max_position_embeddings" in config:
n_ctx = config["max_position_embeddings"]
else:
- raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n"
- "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
+ msg = """\
+ failed to guess 'n_ctx'. This model is unknown or unsupported.
+            Suggestion: provide 'config.json' of the model in the same directory as the model files."""
+ raise KeyError(textwrap.dedent(msg))
n_experts = None
n_experts_used = None
# {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1}
@staticmethod
def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
- config = json.load(open(config_path))
+ with open(config_path) as f:
+ config = json.load(f)
n_experts = None
n_experts_used = None
# vocab
#
-class BpeVocab:
+@runtime_checkable
+class BaseVocab(Protocol):
+ tokenizer_model: ClassVar[str]
+ name: ClassVar[str]
+
+
+class NoVocab(BaseVocab):
+ tokenizer_model = "no_vocab"
+ name = "no_vocab"
+
+ def __repr__(self) -> str:
+ return "<NoVocab for a model without integrated vocabulary>"
+
+
+@runtime_checkable
+class Vocab(BaseVocab, Protocol):
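+    # runtime_checkable so the isinstance(vocab, Vocab) checks below can tell real vocabs apart from NoVocab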
+ vocab_size: int
+ added_tokens_dict: dict[str, int]
+ added_tokens_list: list[str]
+ fname_tokenizer: Path
+
+ def __init__(self, base_path: Path): ...
+ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...
+
+
+class BpeVocab(Vocab):
tokenizer_model = "gpt2"
name = "bpe"
- def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
- self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
- if isinstance(self.bpe_tokenizer.get('model'), dict):
- self.vocab = self.bpe_tokenizer["model"]["vocab"]
- else:
- self.vocab = self.bpe_tokenizer
- added_tokens: dict[str, int]
- if fname_added_tokens is not None:
- # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
- added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
+ def __init__(self, base_path: Path):
+ added_tokens: dict[str, int] = {}
+
+ if (fname_tokenizer := base_path / 'vocab.json').exists():
+ # "slow" tokenizer
+ with open(fname_tokenizer, encoding="utf-8") as f:
+ self.vocab = json.load(f)
+
+ try:
+ # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
+ with open(base_path / ADDED_TOKENS_FILE, encoding="utf-8") as f:
+ added_tokens = json.load(f)
+ except FileNotFoundError:
+ pass
else:
- # Fall back to trying to find the added tokens in tokenizer.json
- tokenizer_json_file = fname_tokenizer.parent / 'tokenizer.json'
- if not tokenizer_json_file.is_file():
- added_tokens = {}
- else:
- tokenizer_json = json.load(open(tokenizer_json_file, encoding="utf-8"))
- added_tokens = dict(
- (item['content'], item['id'])
- for item in tokenizer_json.get('added_tokens', [])
- # Added tokens here can be duplicates of the main vocabulary.
- if item['content'] not in self.bpe_tokenizer)
-
- vocab_size: int = len(self.vocab)
- expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
- actual_ids = sorted(added_tokens.values())
+ # "fast" tokenizer
+ fname_tokenizer = base_path / FAST_TOKENIZER_FILE
+
+ # if this fails, FileNotFoundError propagates to caller
+ with open(fname_tokenizer, encoding="utf-8") as f:
+ tokenizer_json = json.load(f)
+
+ tokenizer_model: dict[str, Any] = tokenizer_json['model']
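+            # accept only a GPT-2-style byte-level BPE tokenizer here; anything else raises
+            # FileNotFoundError so VocabFactory can fall through to the next vocab class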
+ if (
+ tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False)
+ or tokenizer_json['decoder']['type'] != 'ByteLevel'
+ ):
+ raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer')
+
+ self.vocab = tokenizer_model["vocab"]
+
+ if (added := tokenizer_json.get('added_tokens')) is not None:
+ # Added tokens here can be duplicates of the main vocabulary.
+ added_tokens = {item['content']: item['id']
+ for item in added
+ if item['content'] not in self.vocab}
+
+ vocab_size = len(self.vocab)
+ expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
+ actual_ids = sorted(added_tokens.values())
if expected_ids != actual_ids:
expected_end_id = vocab_size + len(actual_ids) - 1
- raise Exception(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}")
+ raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range "
+ f"{vocab_size} - {expected_end_id}; got {actual_ids}")
items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
self.added_tokens_dict = added_tokens
self.added_tokens_list = [text for (text, idx) in items]
- self.vocab_size_base: int = vocab_size
- self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
+ self.vocab_size_base = vocab_size
+ self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
self.fname_tokenizer = fname_tokenizer
- self.fname_added_tokens = fname_added_tokens
def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()}
return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
-class SentencePieceVocab:
+class SentencePieceVocab(Vocab):
tokenizer_model = "llama"
name = "spm"
- def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
- self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
- added_tokens: dict[str, int]
- if fname_added_tokens is not None:
- added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
- else:
- added_tokens = {}
+ def __init__(self, base_path: Path):
+ added_tokens: dict[str, int] = {}
+ if (fname_tokenizer := base_path / 'tokenizer.model').exists():
+ # normal location
+ try:
+ with open(base_path / ADDED_TOKENS_FILE, encoding="utf-8") as f:
+ added_tokens = json.load(f)
+ except FileNotFoundError:
+ pass
+ elif not (fname_tokenizer := base_path.parent / 'tokenizer.model').exists():
+ # not found in alternate location either
+ raise FileNotFoundError('Cannot find tokenizer.model')
- vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
+ self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
+ vocab_size = self.sentencepiece_tokenizer.vocab_size()
new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
# Token pieces that were added to the base vocabulary.
- self.added_tokens_dict = added_tokens
+ self.added_tokens_dict = added_tokens
self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
self.vocab_size_base = vocab_size
self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
self.fname_tokenizer = fname_tokenizer
- self.fname_added_tokens = fname_added_tokens
def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
tokenizer = self.sentencepiece_tokenizer
for i in range(tokenizer.vocab_size()):
piece = tokenizer.id_to_piece(i)
- text: bytes = piece.encode("utf-8")
+ text = piece.encode("utf-8")
score: float = tokenizer.get_score(i)
toktype = gguf.TokenType.NORMAL
return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
-class HfVocab:
+class LlamaHfVocab(Vocab):
tokenizer_model = "llama"
name = "hfft"
- def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None = None) -> None:
+ def __init__(self, base_path: Path, ignore_nonllama: bool = False):
+ fname_tokenizer = base_path / FAST_TOKENIZER_FILE
+ # if this fails, FileNotFoundError propagates to caller
+ with open(fname_tokenizer, encoding='utf-8') as f:
+ tokenizer_json = json.load(f)
+
+ # pre-check so we know if we need transformers
+ tokenizer_model: dict[str, Any] = tokenizer_json['model']
+ if ignore_nonllama:
+ pass # workaround incorrect use of this class for WordPiece
+ elif (
+ tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
+ or tokenizer_json['decoder']['type'] != 'Sequence'
+ ):
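+            # not a Llama-style BPE tokenizer with byte fallback; raise FileNotFoundError
+            # so VocabFactory can fall through to another vocab class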
+ raise FileNotFoundError('Cannot find Llama BPE tokenizer')
+
try:
from transformers import AutoTokenizer
except ImportError as e:
raise ImportError(
- "To use HfVocab, please install the `transformers` package. "
+ "To use LlamaHfVocab, please install the `transformers` package. "
"You can install it with `pip install transformers`."
) from e
- print("fname_tokenizer:", fname_tokenizer)
# Allow the tokenizer to default to slow or fast versions.
# Explicitly set tokenizer to use local paths.
self.tokenizer = AutoTokenizer.from_pretrained(
- fname_tokenizer,
- cache_dir=fname_tokenizer,
+ base_path,
+ cache_dir=base_path,
local_files_only=True,
)
+ assert self.tokenizer.is_fast # assume tokenizer.json is used
# Initialize lists and dictionaries for added tokens
self.added_tokens_list = []
self.vocab_size_base = self.tokenizer.vocab_size
self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
- self.fname_tokenizer = fname_tokenizer
- self.fname_added_tokens = fname_added_tokens
+ self.fname_tokenizer = fname_tokenizer
def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
reverse_vocab = {
yield from self.added_tokens()
def __repr__(self) -> str:
- return f"<HfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
-
-
-class NoVocab:
- tokenizer_model = "no_vocab"
- name = "no_vocab"
-
- def __repr__(self) -> str:
- return "<NoVocab for a model without integrated vocabulary>"
-
-
-Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab | NoVocab"
+ return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
#
.reshape(weights.shape))
-class Tensor(metaclass=ABCMeta):
+class Tensor(ABC):
data_type: DataType
@abstractmethod
class UnquantizedTensor(Tensor):
- def __init__(self, ndarray: NDArray) -> None:
+ def __init__(self, ndarray: NDArray):
assert isinstance(ndarray, np.ndarray)
self.ndarray = ndarray
self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
model: LazyModel
paths: list[Path] # Where this was read from.
format: Literal['ggml', 'torch', 'safetensors', 'none']
- vocab: Vocab | None # For GGML models (which have vocab built in), the vocab.
+ vocab: BaseVocab | None # For GGML models (which have vocab built in), the vocab.
def merge_sharded(models: list[LazyModel]) -> LazyModel:
names = {name: None for model in models for name in model}
def convert(name: str) -> LazyTensor:
- lazy_tensors: list[LazyTensor] = [model[name] for model in models]
+ lazy_tensors = [model[name] for model in models]
if len(lazy_tensors) == 1:
# only one file; don't go through this procedure since there might
# be quantized tensors
def load() -> UnquantizedTensor:
ndarrays = [load_unquantized(tensor) for tensor in lazy_tensors]
- concatenated: NDArray = np.concatenate(ndarrays, axis=axis)
+ concatenated = np.concatenate(ndarrays, axis=axis)
return UnquantizedTensor(concatenated)
description = 'concatenated[[' + '] | ['.join(lt.description for lt in lazy_tensors) + ']]'
return LazyTensor(load, concatenated_shape, lazy_tensors[0].data_type, description)
def load(offset: int, elm_count: int) -> NDArray:
dtype = data_type.dtype
- fp = self.zip_file.open(info)
- fp.seek(offset * dtype.itemsize)
- size = elm_count * dtype.itemsize
- data = fp.read(size)
+ with self.zip_file.open(info) as fp:
+ fp.seek(offset * dtype.itemsize)
+ size = elm_count * dtype.itemsize
+ data = fp.read(size)
assert len(data) == size
return np.frombuffer(data, dtype)
description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}'
def rebuild_from_type_v2(func, new_type, args, state):
return func(*args)
- CLASSES: dict[tuple[str, str], Any] = {
+ CLASSES = {
# getattr used here as a workaround for mypy not being smart enough to determine
# the staticmethods have a __func__ attribute.
('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
def must_read(fp: IO[bytes], length: int) -> bytes:
ret = fp.read(length)
if len(ret) < length:
- raise Exception("unexpectedly reached end of file")
+ raise EOFError("unexpectedly reached end of file")
return ret
yield result
-def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None:
+def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False) -> None:
# Handle special case where the model's vocab size is not set
if params.n_vocab == -1:
raise ValueError(
- f"The model's vocab size is set to -1 in params.json. Please update it manually.{f' Maybe {vocab.vocab_size}?' if hasattr(vocab, 'vocab_size') else ''}"
+ "The model's vocab size is set to -1 in params.json. Please update it manually."
+ + (f" Maybe {vocab.vocab_size}?" if isinstance(vocab, Vocab) else ""),
)
- if isinstance(vocab, NoVocab):
+ if not isinstance(vocab, Vocab):
return # model has no vocab
# Check for a vocab size mismatch
if vocab.vocab_size < params.n_vocab:
msg += " Add the --pad-vocab option and try again."
- raise Exception(msg)
+ raise ValueError(msg)
class OutputFile:
- def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None:
+    def __init__(self, fname_out: Path, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
def add_meta_arch(self, params: Params) -> None:
self.gguf.add_file_type(params.ftype)
def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]:
- assert not isinstance(vocab, NoVocab)
-
tokens = []
scores = []
toktypes = []
@staticmethod
def write_all(
- fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab,
+ fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
pad_vocab: bool = False,
) -> None:
# meta data
of.add_meta_arch(params)
- if isinstance(vocab, NoVocab):
- of.gguf.add_tokenizer_model(vocab.tokenizer_model)
- else:
+ if isinstance(vocab, Vocab):
of.add_meta_vocab(vocab)
of.add_meta_special_vocab(svocab)
+ else: # NoVocab
+ of.gguf.add_tokenizer_model(vocab.tokenizer_model)
# tensor info
for name, lazy_tensor in model.items():
name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
- raise Exception(f"Unexpected combination of types: {name_to_type}")
+ raise ValueError(f"Unexpected combination of types: {name_to_type}")
def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) -> LazyModel:
tmap = gguf.TensorNameMap(ARCH, params.n_layer)
- should_skip: set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))
+ should_skip = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))
tmp = model
if skip_unknown:
print(f"Unexpected tensor name: {name} - skipping")
continue
- else:
- raise Exception(f"Unexpected tensor name: {name}. Use --skip-unknown to ignore it (e.g. LLaVA)")
+ raise ValueError(f"Unexpected tensor name: {name}. Use --skip-unknown to ignore it (e.g. LLaVA)")
if tensor_type in should_skip:
print(f"skipping tensor {name_new}")
the nth path in the model.
'''
# Support the following patterns:
- patterns: list[tuple[str, str]] = [
+ patterns = [
# - x.00.pth, x.01.pth, etc.
(r'\.[0-9]{2}\.pth$', f'.{n:02}.pth'),
# - x-00001-of-00002.bin, x-00002-of-00002.bin, etc.
globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"]
files = [file for glob in globs for file in path.glob(glob)]
if not files:
- raise Exception(f"Can't find model in directory {path}")
+ raise FileNotFoundError(f"Can't find model in directory {path}")
if len(files) > 1:
- raise Exception(f"Found multiple models in {path}, not sure which to pick: {files}")
+ raise ValueError(f"Found multiple models in {path}, not sure which to pick: {files}")
path = files[0]
paths = find_multifile_paths(path)
class VocabFactory:
- _FILES = {"spm": "tokenizer.model", "bpe": "vocab.json", "hfft": "tokenizer.json"}
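+    # each vocab class locates its own tokenizer files under the model path (see _create_vocab_by_path)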
+ _VOCAB_CLASSES: list[type[Vocab]] = [SentencePieceVocab, BpeVocab, LlamaHfVocab]
def __init__(self, path: Path):
self.path = path
- self.file_paths = self._detect_files()
- print(f"Found vocab files: {self.file_paths}")
-
- def _detect_files(self) -> dict[str, Path | None]:
- def locate(file: str) -> Path | None:
- if (path := self.path / file).exists():
- return path
- if (path := self.path.parent / file).exists():
- return path
- return None
-
- return {vt: locate(f) for vt, f in self._FILES.items()}
-
- def _select_file(self, vocab_types: list[str]) -> tuple[str, Path]:
- for vtype in vocab_types:
- try:
- path = self.file_paths[vtype]
- except KeyError:
- raise ValueError(f"Unsupported vocabulary type {vtype}") from None
- if path is not None:
- return vtype, path
- raise FileNotFoundError(f"Could not find any of {[self._FILES[vt] for vt in vocab_types]}")
- def _create_special_vocab(self, vocab: Vocab, model_parent_path: Path) -> gguf.SpecialVocab:
+ def _create_special_vocab(self, vocab: BaseVocab, model_parent_path: Path) -> gguf.SpecialVocab:
load_merges = vocab.name == "bpe"
- n_vocab = vocab.vocab_size if hasattr(vocab, "vocab_size") else None
+ n_vocab = vocab.vocab_size if isinstance(vocab, Vocab) else None
return gguf.SpecialVocab(
model_parent_path,
load_merges=load_merges,
)
def _create_vocab_by_path(self, vocab_types: list[str]) -> Vocab:
- vocab_type, path = self._select_file(vocab_types)
- print(f"Loading vocab file {path!r}, type {vocab_type!r}")
+ vocab_classes: dict[str, type[Vocab]] = {cls.name: cls for cls in self._VOCAB_CLASSES}
+ selected_vocabs: dict[str, type[Vocab]] = {}
+ for vtype in vocab_types:
+ try:
+ selected_vocabs[vtype] = vocab_classes[vtype]
+ except KeyError:
+ raise ValueError(f"Unsupported vocabulary type {vtype}") from None
- added_tokens_path = path.parent / "added_tokens.json"
- if vocab_type == "bpe":
- return BpeVocab(
- path, added_tokens_path if added_tokens_path.exists() else None
- )
- if vocab_type == "spm":
- return SentencePieceVocab(
- path, added_tokens_path if added_tokens_path.exists() else None
- )
- if vocab_type == "hfft":
- return HfVocab(
- path.parent, added_tokens_path if added_tokens_path.exists() else None
- )
- raise ValueError(vocab_type)
+ for vtype, cls in selected_vocabs.items():
+ try:
+ vocab = cls(self.path)
+ break
+ except FileNotFoundError:
+ pass # ignore unavailable tokenizers
+ else:
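+            # reached only when no vocab class above succeeded (for/else)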
+ raise FileNotFoundError(f"Could not find a tokenizer matching any of {vocab_types}")
+
+ print(f"Loaded vocab file {vocab.fname_tokenizer!r}, type {vocab.name!r}")
+ return vocab
- def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]:
- vocab: Vocab
- if len(vocab_types) == 1 and "no_vocab" in vocab_types:
+ def load_vocab(self, vocab_types: list[str] | None, model_parent_path: Path) -> tuple[BaseVocab, gguf.SpecialVocab]:
+ vocab: BaseVocab
+ if vocab_types is None:
vocab = NoVocab()
else:
vocab = self._create_vocab_by_path(vocab_types)
parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
args = parser.parse_args(args_in)
- if args.no_vocab:
- if args.vocab_only:
- raise ValueError("no need to specify --vocab-only if using --no-vocab")
- args.vocab_type = "no_vocab"
+ if args.no_vocab and args.vocab_only:
+ raise ValueError("--vocab-only does not make sense with --no-vocab")
if args.dump_single:
model_plus = lazy_load_file(args.model)
params = Params.load(model_plus)
if params.n_ctx == -1:
if args.ctx is None:
- raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n"
- "Please specify one with --ctx:\n"
- " - LLaMA v1: --ctx 2048\n"
- " - LLaMA v2: --ctx 4096\n")
+ msg = """\
+ The model doesn't have a context size, and you didn't specify one with --ctx
+ Please specify one with --ctx:
+ - LLaMA v1: --ctx 2048
+ - LLaMA v2: --ctx 4096"""
+ parser.error(textwrap.dedent(msg))
params.n_ctx = args.ctx
if args.outtype:
model_parent_path = model_plus.paths[0].parent
vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
vocab_factory = VocabFactory(vocab_path)
- vocab, special_vocab = vocab_factory.load_vocab(args.vocab_type.split(","), model_parent_path)
+ vocab_types = None if args.no_vocab else args.vocab_type.split(",")
+ vocab, special_vocab = vocab_factory.load_vocab(vocab_types, model_parent_path)
if args.vocab_only:
+ assert isinstance(vocab, Vocab)
if not args.outfile:
raise ValueError("need --outfile if using --vocab-only")
outfile = args.outfile