]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
convert : refactor vocab selection logic (#6355)
authorJared Van Bortel <redacted>
Thu, 28 Mar 2024 15:44:36 +0000 (11:44 -0400)
committerGitHub <redacted>
Thu, 28 Mar 2024 15:44:36 +0000 (11:44 -0400)
convert-hf-to-gguf.py
convert-persimmon-to-gguf.py
convert.py
llama.h

index c5d2d0b7813d1c517e6b76642972ee10fd57fd89..918a90e584b57a11eb3597ee89bb66fadfd5fc67 100755 (executable)
@@ -23,7 +23,7 @@ if 'NO_LOCAL_GGUF' not in os.environ:
     sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
 import gguf
 
-from convert import HfVocab
+from convert import LlamaHfVocab
 
 
 ###### MODEL DEFINITIONS ######
@@ -230,7 +230,7 @@ class Model(ABC):
     def _set_vocab_gpt2(self):
         dir_model = self.dir_model
         hparams = self.hparams
-        tokens: list[bytearray] = []
+        tokens: list[str] = []
         toktypes: list[int] = []
 
         from transformers import AutoTokenizer
@@ -243,8 +243,7 @@ class Model(ABC):
 
         for i in range(vocab_size):
             if i not in reverse_vocab:
-                pad_token = f"[PAD{i}]".encode('utf-8')
-                tokens.append(bytearray(pad_token))
+                tokens.append(f"[PAD{i}]")
                 toktypes.append(gguf.TokenType.USER_DEFINED)
             elif reverse_vocab[i] in added_vocab:
                 tokens.append(reverse_vocab[i])
@@ -266,7 +265,7 @@ class Model(ABC):
     def _set_vocab_qwen(self):
         dir_model = self.dir_model
         hparams = self.hparams
-        tokens: list[bytearray] = []
+        tokens: list[str] = []
         toktypes: list[int] = []
 
         from transformers import AutoTokenizer
@@ -291,8 +290,7 @@ class Model(ABC):
 
         for i in range(vocab_size):
             if i not in reverse_vocab:
-                pad_token = f"[PAD{i}]".encode("utf-8")
-                tokens.append(bytearray(pad_token))
+                tokens.append(f"[PAD{i}]")
                 toktypes.append(gguf.TokenType.USER_DEFINED)
             elif reverse_vocab[i] in added_vocab:
                 tokens.append(reverse_vocab[i])
@@ -372,12 +370,8 @@ class Model(ABC):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
 
-    def _set_vocab_hf(self):
-        path = self.dir_model
-        added_tokens_path = self.dir_model
-        vocab = HfVocab(
-            path, added_tokens_path if added_tokens_path.exists() else None
-        )
+    def _set_vocab_llama_hf(self):
+        vocab = LlamaHfVocab(self.dir_model)
         tokens = []
         scores = []
         toktypes = []
@@ -1099,7 +1093,7 @@ class MiniCPMModel(Model):
         self.gguf_writer.add_file_type(self.ftype)
 
     def set_vocab(self):
-        self._set_vocab_hf()
+        self._set_vocab_llama_hf()
 
     def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
         if n_kv_head is not None and n_head != n_kv_head:
@@ -1700,11 +1694,8 @@ class BertModel(Model):
             self.gguf_writer.add_pooling_type(pooling_type)
 
     def set_vocab(self):
-        path = self.dir_model
-        added_tokens_path = self.dir_model if self.dir_model.exists() else None
-
         # use huggingface vocab to get all tokens
-        vocab = HfVocab(path, added_tokens_path)
+        vocab = LlamaHfVocab(self.dir_model, ignore_nonllama=True)
         tokens, scores, toktypes = zip(*vocab.all_tokens())
         assert len(tokens) == vocab.vocab_size
         self.vocab_size = vocab.vocab_size
index def210531e27b1794597573790b799eff81a4d34..ccb99279e20a8c2bf441dbcfd7e5d292b270afa5 100755 (executable)
@@ -106,12 +106,12 @@ def main():
     tensor_map = gguf.get_tensor_name_map(arch, block_count)
     print(tensor_map)
     for name in tensors.keys():
-        data = tensors[name]
+        data_torch = tensors[name]
         if name.endswith(".self_attention.rotary_emb.inv_freq"):
             continue
-        old_dtype = data.dtype
+        old_dtype = data_torch.dtype
         # TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?)
-        data = data.to(torch.float32).squeeze().numpy()
+        data = data_torch.to(torch.float32).squeeze().numpy()
         new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
         if new_name is None:
             print("Can not map tensor '" + name + "'")
index 817cb66123a8f48af6f2f1081c28dc4de65dcd25..d3a9ccaf21e616ec770b55a2d85706a26b72a017 100755 (executable)
@@ -16,13 +16,14 @@ import re
 import signal
 import struct
 import sys
+import textwrap
 import time
 import zipfile
-from abc import ABCMeta, abstractmethod
+from abc import ABC, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable
 
 import numpy as np
 from sentencepiece import SentencePieceProcessor
@@ -43,6 +44,9 @@ ARCH = gguf.MODEL_ARCH.LLAMA
 
 DEFAULT_CONCURRENCY = 8
 
+ADDED_TOKENS_FILE = 'added_tokens.json'
+FAST_TOKENIZER_FILE = 'tokenizer.json'
+
 #
 # data types
 #
@@ -188,8 +192,10 @@ class Params:
             n_layer = next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model)
 
         if n_layer < 1:
-            raise Exception("failed to guess 'n_layer'. This model is unknown or unsupported.\n"
-                            "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
+            msg = """\
+                failed to guess 'n_layer'. This model is unknown or unsupported.
+                Suggestion: provide 'config.json' of the model in the same directory containing model files."""
+            raise KeyError(textwrap.dedent(msg))
 
         n_head = n_embd // 128 # guessed
         n_mult = 256           # guessed
@@ -211,7 +217,8 @@ class Params:
 
     @staticmethod
     def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
-        config = json.load(open(config_path))
+        with open(config_path) as f:
+            config = json.load(f)
 
         rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None
         rope_scaling = config.get("rope_scaling")
@@ -233,8 +240,10 @@ class Params:
         elif "max_position_embeddings" in config:
             n_ctx = config["max_position_embeddings"]
         else:
-            raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n"
-                            "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
+            msg = """\
+                failed to guess 'n_ctx'. This model is unknown or unsupported.
+                Suggestion: provide 'config.json' of the model in the same directory containing model files."""
+            raise KeyError(textwrap.dedent(msg))
 
         n_experts      = None
         n_experts_used = None
@@ -265,7 +274,8 @@ class Params:
     # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1}
     @staticmethod
     def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
-        config = json.load(open(config_path))
+        with open(config_path) as f:
+            config = json.load(f)
 
         n_experts      = None
         n_experts_used = None
@@ -331,47 +341,86 @@ class Params:
 # vocab
 #
 
-class BpeVocab:
+@runtime_checkable
+class BaseVocab(Protocol):
+    tokenizer_model: ClassVar[str]
+    name: ClassVar[str]
+
+
+class NoVocab(BaseVocab):
+    tokenizer_model = "no_vocab"
+    name = "no_vocab"
+
+    def __repr__(self) -> str:
+        return "<NoVocab for a model without integrated vocabulary>"
+
+
+@runtime_checkable
+class Vocab(BaseVocab, Protocol):
+    vocab_size: int
+    added_tokens_dict: dict[str, int]
+    added_tokens_list: list[str]
+    fname_tokenizer: Path
+
+    def __init__(self, base_path: Path): ...
+    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...
+
+
+class BpeVocab(Vocab):
     tokenizer_model = "gpt2"
     name = "bpe"
 
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
-        self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
-        if isinstance(self.bpe_tokenizer.get('model'), dict):
-            self.vocab = self.bpe_tokenizer["model"]["vocab"]
-        else:
-            self.vocab = self.bpe_tokenizer
-        added_tokens: dict[str, int]
-        if fname_added_tokens is not None:
-            # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
-            added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
+    def __init__(self, base_path: Path):
+        added_tokens: dict[str, int] = {}
+
+        if (fname_tokenizer := base_path / 'vocab.json').exists():
+            # "slow" tokenizer
+            with open(fname_tokenizer, encoding="utf-8") as f:
+                self.vocab = json.load(f)
+
+            try:
+                # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
+                with open(base_path / ADDED_TOKENS_FILE, encoding="utf-8") as f:
+                    added_tokens = json.load(f)
+            except FileNotFoundError:
+                pass
         else:
-            # Fall back to trying to find the added tokens in tokenizer.json
-            tokenizer_json_file = fname_tokenizer.parent / 'tokenizer.json'
-            if not tokenizer_json_file.is_file():
-                added_tokens = {}
-            else:
-                tokenizer_json = json.load(open(tokenizer_json_file, encoding="utf-8"))
-                added_tokens = dict(
-                    (item['content'], item['id'])
-                    for item in tokenizer_json.get('added_tokens', [])
-                    # Added tokens here can be duplicates of the main vocabulary.
-                    if item['content'] not in self.bpe_tokenizer)
-
-        vocab_size: int = len(self.vocab)
-        expected_ids    = list(range(vocab_size, vocab_size + len(added_tokens)))
-        actual_ids      = sorted(added_tokens.values())
+            # "fast" tokenizer
+            fname_tokenizer = base_path / FAST_TOKENIZER_FILE
+
+            # if this fails, FileNotFoundError propagates to caller
+            with open(fname_tokenizer, encoding="utf-8") as f:
+                tokenizer_json = json.load(f)
+
+            tokenizer_model: dict[str, Any] = tokenizer_json['model']
+            if (
+                tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False)
+                or tokenizer_json['decoder']['type'] != 'ByteLevel'
+            ):
+                raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer')
+
+            self.vocab = tokenizer_model["vocab"]
+
+            if (added := tokenizer_json.get('added_tokens')) is not None:
+                # Added tokens here can be duplicates of the main vocabulary.
+                added_tokens = {item['content']: item['id']
+                                for item in added
+                                if item['content'] not in self.vocab}
+
+        vocab_size   = len(self.vocab)
+        expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
+        actual_ids   = sorted(added_tokens.values())
         if expected_ids != actual_ids:
             expected_end_id = vocab_size + len(actual_ids) - 1
-            raise Exception(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}")
+            raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range "
+                             f"{vocab_size} - {expected_end_id}; got {actual_ids}")
 
         items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
         self.added_tokens_dict    = added_tokens
         self.added_tokens_list    = [text for (text, idx) in items]
-        self.vocab_size_base: int = vocab_size
-        self.vocab_size: int      = self.vocab_size_base + len(self.added_tokens_list)
+        self.vocab_size_base      = vocab_size
+        self.vocab_size           = self.vocab_size_base + len(self.added_tokens_list)
         self.fname_tokenizer      = fname_tokenizer
-        self.fname_added_tokens   = fname_added_tokens
 
     def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()}
@@ -392,19 +441,25 @@ class BpeVocab:
         return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
 
-class SentencePieceVocab:
+class SentencePieceVocab(Vocab):
     tokenizer_model = "llama"
     name = "spm"
 
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
-        self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
-        added_tokens: dict[str, int]
-        if fname_added_tokens is not None:
-            added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
-        else:
-            added_tokens = {}
+    def __init__(self, base_path: Path):
+        added_tokens: dict[str, int] = {}
+        if (fname_tokenizer := base_path / 'tokenizer.model').exists():
+            # normal location
+            try:
+                with open(base_path / ADDED_TOKENS_FILE, encoding="utf-8") as f:
+                    added_tokens = json.load(f)
+            except FileNotFoundError:
+                pass
+        elif not (fname_tokenizer := base_path.parent / 'tokenizer.model').exists():
+            # not found in alternate location either
+            raise FileNotFoundError('Cannot find tokenizer.model')
 
-        vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
+        self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
+        vocab_size = self.sentencepiece_tokenizer.vocab_size()
 
         new_tokens       = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
         expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
@@ -414,18 +469,17 @@ class SentencePieceVocab:
             raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
 
         # Token pieces that were added to the base vocabulary.
-        self.added_tokens_dict = added_tokens
+        self.added_tokens_dict  = added_tokens
         self.added_tokens_list  = [new_tokens[id] for id in actual_new_ids]
         self.vocab_size_base    = vocab_size
         self.vocab_size         = self.vocab_size_base + len(self.added_tokens_list)
         self.fname_tokenizer    = fname_tokenizer
-        self.fname_added_tokens = fname_added_tokens
 
     def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         tokenizer = self.sentencepiece_tokenizer
         for i in range(tokenizer.vocab_size()):
             piece = tokenizer.id_to_piece(i)
-            text: bytes = piece.encode("utf-8")
+            text         = piece.encode("utf-8")
             score: float = tokenizer.get_score(i)
 
             toktype = gguf.TokenType.NORMAL
@@ -458,27 +512,42 @@ class SentencePieceVocab:
         return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
 
-class HfVocab:
+class LlamaHfVocab(Vocab):
     tokenizer_model = "llama"
     name = "hfft"
 
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None = None) -> None:
+    def __init__(self, base_path: Path, ignore_nonllama: bool = False):
+        fname_tokenizer = base_path / FAST_TOKENIZER_FILE
+        # if this fails, FileNotFoundError propagates to caller
+        with open(fname_tokenizer, encoding='utf-8') as f:
+            tokenizer_json = json.load(f)
+
+        # pre-check so we know if we need transformers
+        tokenizer_model: dict[str, Any] = tokenizer_json['model']
+        if ignore_nonllama:
+            pass  # workaround incorrect use of this class for WordPiece
+        elif (
+            tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
+            or tokenizer_json['decoder']['type'] != 'Sequence'
+        ):
+            raise FileNotFoundError('Cannot find Llama BPE tokenizer')
+
         try:
             from transformers import AutoTokenizer
         except ImportError as e:
             raise ImportError(
-                "To use HfVocab, please install the `transformers` package. "
+                "To use LlamaHfVocab, please install the `transformers` package. "
                 "You can install it with `pip install transformers`."
             ) from e
 
-        print("fname_tokenizer:", fname_tokenizer)
         # Allow the tokenizer to default to slow or fast versions.
         # Explicitly set tokenizer to use local paths.
         self.tokenizer = AutoTokenizer.from_pretrained(
-            fname_tokenizer,
-            cache_dir=fname_tokenizer,
+            base_path,
+            cache_dir=base_path,
             local_files_only=True,
         )
+        assert self.tokenizer.is_fast  # assume tokenizer.json is used
 
         # Initialize lists and dictionaries for added tokens
         self.added_tokens_list = []
@@ -506,8 +575,7 @@ class HfVocab:
         self.vocab_size_base = self.tokenizer.vocab_size
         self.vocab_size      = self.vocab_size_base + len(self.added_tokens_list)
 
-        self.fname_tokenizer    = fname_tokenizer
-        self.fname_added_tokens = fname_added_tokens
+        self.fname_tokenizer = fname_tokenizer
 
     def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         reverse_vocab = {
@@ -559,18 +627,7 @@ class HfVocab:
         yield from self.added_tokens()
 
     def __repr__(self) -> str:
-        return f"<HfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
-
-
-class NoVocab:
-    tokenizer_model = "no_vocab"
-    name = "no_vocab"
-
-    def __repr__(self) -> str:
-        return "<NoVocab for a model without integrated vocabulary>"
-
-
-Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab | NoVocab"
+        return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
 
 #
@@ -588,7 +645,7 @@ def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
             .reshape(weights.shape))
 
 
-class Tensor(metaclass=ABCMeta):
+class Tensor(ABC):
     data_type: DataType
 
     @abstractmethod
@@ -610,7 +667,7 @@ def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray:
 
 
 class UnquantizedTensor(Tensor):
-    def __init__(self, ndarray: NDArray) -> None:
+    def __init__(self, ndarray: NDArray):
         assert isinstance(ndarray, np.ndarray)
         self.ndarray = ndarray
         self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
@@ -689,7 +746,7 @@ class ModelPlus:
     model: LazyModel
     paths: list[Path]  # Where this was read from.
     format: Literal['ggml', 'torch', 'safetensors', 'none']
-    vocab: Vocab | None  # For GGML models (which have vocab built in), the vocab.
+    vocab: BaseVocab | None  # For GGML models (which have vocab built in), the vocab.
 
 
 def merge_sharded(models: list[LazyModel]) -> LazyModel:
@@ -698,7 +755,7 @@ def merge_sharded(models: list[LazyModel]) -> LazyModel:
     names = {name: None for model in models for name in model}
 
     def convert(name: str) -> LazyTensor:
-        lazy_tensors: list[LazyTensor] = [model[name] for model in models]
+        lazy_tensors = [model[name] for model in models]
         if len(lazy_tensors) == 1:
             # only one file; don't go through this procedure since there might
             # be quantized tensors
@@ -719,7 +776,7 @@ def merge_sharded(models: list[LazyModel]) -> LazyModel:
 
         def load() -> UnquantizedTensor:
             ndarrays = [load_unquantized(tensor) for tensor in lazy_tensors]
-            concatenated: NDArray = np.concatenate(ndarrays, axis=axis)
+            concatenated = np.concatenate(ndarrays, axis=axis)
             return UnquantizedTensor(concatenated)
         description = 'concatenated[[' + '] | ['.join(lt.description for lt in lazy_tensors) + ']]'
         return LazyTensor(load, concatenated_shape, lazy_tensors[0].data_type, description)
@@ -807,10 +864,10 @@ class LazyUnpickler(pickle.Unpickler):
 
         def load(offset: int, elm_count: int) -> NDArray:
             dtype = data_type.dtype
-            fp = self.zip_file.open(info)
-            fp.seek(offset * dtype.itemsize)
-            size = elm_count * dtype.itemsize
-            data = fp.read(size)
+            with self.zip_file.open(info) as fp:
+                fp.seek(offset * dtype.itemsize)
+                size = elm_count * dtype.itemsize
+                data = fp.read(size)
             assert len(data) == size
             return np.frombuffer(data, dtype)
         description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}'
@@ -831,7 +888,7 @@ class LazyUnpickler(pickle.Unpickler):
     def rebuild_from_type_v2(func, new_type, args, state):
         return func(*args)
 
-    CLASSES: dict[tuple[str, str], Any] = {
+    CLASSES = {
         # getattr used here as a workaround for mypy not being smart enough to determine
         # the staticmethods have a __func__ attribute.
         ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
@@ -890,7 +947,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
 def must_read(fp: IO[bytes], length: int) -> bytes:
     ret = fp.read(length)
     if len(ret) < length:
-        raise Exception("unexpectedly reached end of file")
+        raise EOFError("unexpectedly reached end of file")
     return ret
 
 
@@ -948,13 +1005,14 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
             yield result
 
 
-def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None:
+def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False) -> None:
     # Handle special case where the model's vocab size is not set
     if params.n_vocab == -1:
         raise ValueError(
-            f"The model's vocab size is set to -1 in params.json. Please update it manually.{f' Maybe {vocab.vocab_size}?' if hasattr(vocab, 'vocab_size') else ''}"
+            "The model's vocab size is set to -1 in params.json. Please update it manually."
+            + (f" Maybe {vocab.vocab_size}?" if isinstance(vocab, Vocab) else ""),
         )
-    if isinstance(vocab, NoVocab):
+    if not isinstance(vocab, Vocab):
         return  # model has no vocab
 
     # Check for a vocab size mismatch
@@ -979,11 +1037,11 @@ def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> N
     if vocab.vocab_size < params.n_vocab:
         msg += " Add the --pad-vocab option and try again."
 
-    raise Exception(msg)
+    raise ValueError(msg)
 
 
 class OutputFile:
-    def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None:
+    def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
         self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
 
     def add_meta_arch(self, params: Params) -> None:
@@ -1034,8 +1092,6 @@ class OutputFile:
             self.gguf.add_file_type(params.ftype)
 
     def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]:
-        assert not isinstance(vocab, NoVocab)
-
         tokens = []
         scores = []
         toktypes = []
@@ -1135,7 +1191,7 @@ class OutputFile:
 
     @staticmethod
     def write_all(
-        fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab,
+        fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
         concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
         pad_vocab: bool = False,
     ) -> None:
@@ -1145,11 +1201,11 @@ class OutputFile:
 
         # meta data
         of.add_meta_arch(params)
-        if isinstance(vocab, NoVocab):
-            of.gguf.add_tokenizer_model(vocab.tokenizer_model)
-        else:
+        if isinstance(vocab, Vocab):
             of.add_meta_vocab(vocab)
             of.add_meta_special_vocab(svocab)
+        else:  # NoVocab
+            of.gguf.add_tokenizer_model(vocab.tokenizer_model)
 
         # tensor info
         for name, lazy_tensor in model.items():
@@ -1176,7 +1232,7 @@ def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileT
 
     name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
 
-    raise Exception(f"Unexpected combination of types: {name_to_type}")
+    raise ValueError(f"Unexpected combination of types: {name_to_type}")
 
 
 def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
@@ -1186,7 +1242,7 @@ def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyM
 
 def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) -> LazyModel:
     tmap = gguf.TensorNameMap(ARCH, params.n_layer)
-    should_skip: set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))
+    should_skip = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))
 
     tmp = model
 
@@ -1213,8 +1269,7 @@ def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) ->
             if skip_unknown:
                 print(f"Unexpected tensor name: {name} - skipping")
                 continue
-            else:
-                raise Exception(f"Unexpected tensor name: {name}. Use --skip-unknown to ignore it (e.g. LLaVA)")
+            raise ValueError(f"Unexpected tensor name: {name}. Use --skip-unknown to ignore it (e.g. LLaVA)")
 
         if tensor_type in should_skip:
             print(f"skipping tensor {name_new}")
@@ -1231,7 +1286,7 @@ def nth_multifile_path(path: Path, n: int) -> Path | None:
     the nth path in the model.
     '''
     # Support the following patterns:
-    patterns: list[tuple[str, str]] = [
+    patterns = [
         # - x.00.pth, x.01.pth, etc.
         (r'\.[0-9]{2}\.pth$', f'.{n:02}.pth'),
         # - x-00001-of-00002.bin, x-00002-of-00002.bin, etc.
@@ -1277,9 +1332,9 @@ def load_some_model(path: Path) -> ModelPlus:
             globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"]
             files = [file for glob in globs for file in path.glob(glob)]
         if not files:
-            raise Exception(f"Can't find model in directory {path}")
+            raise FileNotFoundError(f"Can't find model in directory {path}")
         if len(files) > 1:
-            raise Exception(f"Found multiple models in {path}, not sure which to pick: {files}")
+            raise ValueError(f"Found multiple models in {path}, not sure which to pick: {files}")
         path = files[0]
 
     paths = find_multifile_paths(path)
@@ -1293,36 +1348,14 @@ def load_some_model(path: Path) -> ModelPlus:
 
 
 class VocabFactory:
-    _FILES = {"spm": "tokenizer.model", "bpe": "vocab.json", "hfft": "tokenizer.json"}
+    _VOCAB_CLASSES: list[type[Vocab]] = [SentencePieceVocab, BpeVocab, LlamaHfVocab]
 
     def __init__(self, path: Path):
         self.path = path
-        self.file_paths = self._detect_files()
-        print(f"Found vocab files: {self.file_paths}")
-
-    def _detect_files(self) -> dict[str, Path | None]:
-        def locate(file: str) -> Path | None:
-            if (path := self.path / file).exists():
-                return path
-            if (path := self.path.parent / file).exists():
-                return path
-            return None
-
-        return {vt: locate(f) for vt, f in self._FILES.items()}
-
-    def _select_file(self, vocab_types: list[str]) -> tuple[str, Path]:
-        for vtype in vocab_types:
-            try:
-                path = self.file_paths[vtype]
-            except KeyError:
-                raise ValueError(f"Unsupported vocabulary type {vtype}") from None
-            if path is not None:
-                return vtype, path
-        raise FileNotFoundError(f"Could not find any of {[self._FILES[vt] for vt in vocab_types]}")
 
-    def _create_special_vocab(self, vocab: Vocab, model_parent_path: Path) -> gguf.SpecialVocab:
+    def _create_special_vocab(self, vocab: BaseVocab, model_parent_path: Path) -> gguf.SpecialVocab:
         load_merges = vocab.name == "bpe"
-        n_vocab = vocab.vocab_size if hasattr(vocab, "vocab_size") else None
+        n_vocab = vocab.vocab_size if isinstance(vocab, Vocab) else None
         return gguf.SpecialVocab(
             model_parent_path,
             load_merges=load_merges,
@@ -1331,27 +1364,29 @@ class VocabFactory:
         )
 
     def _create_vocab_by_path(self, vocab_types: list[str]) -> Vocab:
-        vocab_type, path = self._select_file(vocab_types)
-        print(f"Loading vocab file {path!r}, type {vocab_type!r}")
+        vocab_classes: dict[str, type[Vocab]] = {cls.name: cls for cls in self._VOCAB_CLASSES}
+        selected_vocabs: dict[str, type[Vocab]] = {}
+        for vtype in vocab_types:
+            try:
+                selected_vocabs[vtype] = vocab_classes[vtype]
+            except KeyError:
+                raise ValueError(f"Unsupported vocabulary type {vtype}") from None
 
-        added_tokens_path = path.parent / "added_tokens.json"
-        if vocab_type == "bpe":
-            return BpeVocab(
-                path, added_tokens_path if added_tokens_path.exists() else None
-            )
-        if vocab_type == "spm":
-            return SentencePieceVocab(
-                path, added_tokens_path if added_tokens_path.exists() else None
-            )
-        if vocab_type == "hfft":
-            return HfVocab(
-                path.parent, added_tokens_path if added_tokens_path.exists() else None
-            )
-        raise ValueError(vocab_type)
+        for vtype, cls in selected_vocabs.items():
+            try:
+                vocab = cls(self.path)
+                break
+            except FileNotFoundError:
+                pass  # ignore unavailable tokenizers
+        else:
+            raise FileNotFoundError(f"Could not find a tokenizer matching any of {vocab_types}")
+
+        print(f"Loaded vocab file {vocab.fname_tokenizer!r}, type {vocab.name!r}")
+        return vocab
 
-    def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]:
-        vocab: Vocab
-        if len(vocab_types) == 1 and "no_vocab" in vocab_types:
+    def load_vocab(self, vocab_types: list[str] | None, model_parent_path: Path) -> tuple[BaseVocab, gguf.SpecialVocab]:
+        vocab: BaseVocab
+        if vocab_types is None:
             vocab = NoVocab()
         else:
             vocab = self._create_vocab_by_path(vocab_types)
@@ -1408,10 +1443,8 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--skip-unknown", action="store_true",    help="skip unknown tensor names instead of failing")
 
     args = parser.parse_args(args_in)
-    if args.no_vocab:
-        if args.vocab_only:
-            raise ValueError("no need to specify --vocab-only if using --no-vocab")
-        args.vocab_type = "no_vocab"
+    if args.no_vocab and args.vocab_only:
+        raise ValueError("--vocab-only does not make sense with --no-vocab")
 
     if args.dump_single:
         model_plus = lazy_load_file(args.model)
@@ -1433,10 +1466,12 @@ def main(args_in: list[str] | None = None) -> None:
     params = Params.load(model_plus)
     if params.n_ctx == -1:
         if args.ctx is None:
-            raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n"
-                            "Please specify one with --ctx:\n"
-                            " - LLaMA v1: --ctx 2048\n"
-                            " - LLaMA v2: --ctx 4096\n")
+            msg = """\
+                The model doesn't have a context size, and you didn't specify one with --ctx
+                Please specify one with --ctx:
+                 - LLaMA v1: --ctx 2048
+                 - LLaMA v2: --ctx 4096"""
+            parser.error(textwrap.dedent(msg))
         params.n_ctx = args.ctx
 
     if args.outtype:
@@ -1451,9 +1486,11 @@ def main(args_in: list[str] | None = None) -> None:
     model_parent_path = model_plus.paths[0].parent
     vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
     vocab_factory = VocabFactory(vocab_path)
-    vocab, special_vocab = vocab_factory.load_vocab(args.vocab_type.split(","), model_parent_path)
+    vocab_types = None if args.no_vocab else args.vocab_type.split(",")
+    vocab, special_vocab = vocab_factory.load_vocab(vocab_types, model_parent_path)
 
     if args.vocab_only:
+        assert isinstance(vocab, Vocab)
         if not args.outfile:
             raise ValueError("need --outfile if using --vocab-only")
         outfile = args.outfile
diff --git a/llama.h b/llama.h
index 1fe4af495820f864e2ea986c57c1047b1afe1ba8..f061d014ca8eb05e2f5e2104eb8299903418a28a 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -60,9 +60,9 @@ extern "C" {
 
     enum llama_vocab_type {
         LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
-        LLAMA_VOCAB_TYPE_SPM  = 1, // SentencePiece
-        LLAMA_VOCAB_TYPE_BPE  = 2, // Byte Pair Encoding
-        LLAMA_VOCAB_TYPE_WPM  = 3, // WordPiece
+        LLAMA_VOCAB_TYPE_SPM  = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
+        LLAMA_VOCAB_TYPE_BPE  = 2, // GPT-2 tokenizer based on byte-level BPE
+        LLAMA_VOCAB_TYPE_WPM  = 3, // BERT tokenizer based on WordPiece
     };
 
     // note: these values should be synchronized with ggml_rope